diff --git a/hscontrol/app.go b/hscontrol/app.go index d62acb34..3ba25d89 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -163,6 +163,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("loading ACL policy: %w", err) } + // TODO(kradalby): There is an circular dependency here, maybe we should + // look at some sort of dependency injection? + // https://github.com/uber-go/dig + // or + // https://github.com/uber-go/fx + // Maybe overkill? + app.db.SetPolicyManager(app.polMan) + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 49964173..fc58f2ca 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -153,27 +153,6 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) } -// canUsePreAuthKey checks if a pre auth key can be used. -func canUsePreAuthKey(pak *types.PreAuthKey) error { - if pak == nil { - return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) - } - if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil) - } - - // we don't need to check if has been used before - if pak.Reusable { - return nil - } - - if pak.Used { - return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil) - } - - return nil -} - func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, @@ -183,32 +162,28 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - return nil, err - } - - err = canUsePreAuthKey(pak) - if err != nil { - return nil, err + return nil, NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) } nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, - UserID: ptr.To(pak.User.ID), - User: ptr.To(pak.User), MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), RegisterMethod: util.RegisterMethodAuthKey, - // TODO(kradalby): This should not be set on the node, - // they should be looked up through the key, which is - // attached to the node. - Tags: pak.Proto().GetAclTags(), AuthKey: pak, AuthKeyID: &pak.ID, } + if pak.IsTagged() { + nodeToRegister.Tags = pak.Tags + } else { + nodeToRegister.UserID = pak.UserID + nodeToRegister.User = pak.User + } + if !regReq.Expiry.IsZero() { nodeToRegister.Expiry = ®Req.Expiry } @@ -257,7 +232,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(h.polMan, node) - if err := h.db.DB.Save(node).Error; err != nil { + if err := h.db.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 51083145..de803d32 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -65,7 +65,7 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - key := types.APIKey{} + var key types.APIKey if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -75,7 +75,7 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - key := types.APIKey{} + var key types.APIKey if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 14b72767..69bef3e6 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -15,6 +15,7 @@ import ( "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -44,6 +45,7 @@ type HSDatabase struct { DB *gorm.DB cfg *types.DatabaseConfig regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + polMan policy.PolicyManager baseDomain string } @@ -766,6 +768,10 @@ AND auth_key_id NOT IN ( return &db, err } +func (db *HSDatabase) SetPolicyManager(pol policy.PolicyManager) { + db.polMan = pol +} + func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 9b372f50..57e98d00 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -97,6 +98,23 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } +// SaveNode saves a node to the database. +// It performs checks to validate if the conforms to certain restrictions: +// - A node must be either tagged or owned by a user, not both. +func (hsdb *HSDatabase) SaveNode(node *types.Node) error { + if node.IsTagged() && node.UserID != nil { + return fmt.Errorf("node %q is tagged and has a user ID, has to be either tagged or owned by user", node.Hostname) + } + + if !node.IsTagged() && node.UserID == nil { + return fmt.Errorf("node %q is not tagged and has no user ID, has to be either tagged or owned by user", node.Hostname) + } + + slices.Sort(node.Tags) + node.Tags = slices.Compact(node.Tags) + return hsdb.DB.Save(node).Error +} + func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return getNode(rx, uid, name) @@ -196,18 +214,26 @@ func (hsdb *HSDatabase) SetTags( // SetTags takes a NodeID and update the forced tags. // It will overwrite any tags with the new list. +// If the node has a UserID, it will be unset as a node +// can only have a UserID or tags, not both. func SetTags( tx *gorm.DB, nodeID types.NodeID, tags []string, ) error { + // If no tags are provided, return an error. + // Tailscale does not support removing all tags from a node. + // A node needs to have either a User owner, or be tagged, and + // it is not supported to remove all tags and "return it to a user". if len(tags) == 0 { - // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error; err != nil { - return fmt.Errorf("removing tags: %w", err) - } + return types.ErrCannotRemoveAllTags + } - return nil + // If the node has a UserID, we need to remove it. + // This is because a node can only have a UserID or tags, not both. + // We need to set the UserID to nil. + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", nil).Error; err != nil { + return fmt.Errorf("removing user from tagged node: %w", err) } slices.Sort(tags) @@ -224,7 +250,8 @@ func SetTags( return nil } -// SetTags takes a Node struct pointer and update the forced tags. +// SetApprovedRoutes takes a NodeID and a list of routes and updates the +// approved routes for the node. func SetApprovedRoutes( tx *gorm.DB, nodeID types.NodeID, @@ -339,6 +366,30 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( }) } +func checkTags(polMan policy.PolicyManager, node *types.Node, reqTags []string) ([]string, error) { + if len(reqTags) == 0 { + return nil, nil + } + + var tags []string + var invalidTags []string + for _, tag := range reqTags { + if polMan.NodeCanHaveTag(node, tag) { + tags = append(tags, tag) + } else { + invalidTags = append(invalidTags, tag) + } + } + + if len(invalidTags) > 0 { + return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags) + } + + slices.Sort(tags) + tags = slices.Compact(tags) + return tags, nil +} + // HandleNodeFromAuthPath is called from the OIDC or CLI auth path // with a registrationID to register or reauthenticate a node. // If the node found in the registration cache is not already registered, @@ -352,8 +403,9 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( registrationMethod string, ipv4 *netip.Addr, ipv6 *netip.Addr, -) (*types.Node, bool, error) { +) (*types.Node, types.ChangeSet, error) { var newNode bool + cs := types.ChangeSet{} node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { if reg, ok := hsdb.regCache.Get(registrationID); ok { if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { @@ -381,8 +433,17 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, ErrDifferentRegisteredUser } - reg.Node.UserID = &user.ID - reg.Node.User = user + if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 { + tags, err := checkTags(hsdb.polMan, ®.Node, reqTags) + if err != nil { + return nil, err + } + reg.Node.Tags = tags + } else { + reg.Node.UserID = &user.ID + reg.Node.User = user + } + reg.Node.RegisterMethod = registrationMethod if nodeExpiry != nil { @@ -406,14 +467,27 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( } close(reg.Registered) - newNode = true + cs.New = true return node, err } else { + if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 { + tags, err := checkTags(hsdb.polMan, ®.Node, reqTags) + if err != nil { + return nil, err + } + err = SetTags(tx, node.ID, tags) + if err != nil { + return nil, err + } + cs.Tags = true + } + // If the node is already registered, this is a refresh. err := NodeSetExpiry(tx, node.ID, *nodeExpiry) if err != nil { return nil, err } + cs.Expiry = true return node, nil } } @@ -421,7 +495,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, ErrNodeNotFoundRegistrationCache }) - return node, newNode, err + return node, cs, err } func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index fce992ba..79e699fe 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -10,12 +10,10 @@ import ( "os" "slices" "sort" - "strings" "time" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" - "github.com/samber/lo" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -281,7 +279,7 @@ func (api headscaleV1APIServer) RegisterNode( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(api.h.polMan, node) - if err := api.h.db.DB.Save(node).Error; err != nil { + if err := api.h.db.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } @@ -315,15 +313,27 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) + if err != nil { + return nil, err + } + + var tags []string + var invalidTags []string for _, tag := range request.GetTags() { - err := validateTag(tag) - if err != nil { - return nil, err + if api.h.polMan.NodeCanHaveTag(node, tag) { + tags = append(tags, tag) + } else { + invalidTags = append(invalidTags, tag) } } - node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { - err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags()) + if len(invalidTags) > 0 { + return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags) + } + + node, err = db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.SetTags(tx, types.NodeID(request.GetNodeId()), tags) if err != nil { return nil, err } @@ -395,19 +405,6 @@ func (api headscaleV1APIServer) SetApprovedRoutes( return &v1.SetApprovedRoutesResponse{Node: proto}, nil } -func validateTag(tag string) error { - if strings.Index(tag, "tag:") != 0 { - return errors.New("tag must start with the string 'tag:'") - } - if strings.ToLower(tag) != tag { - return errors.New("tag should be lowercase") - } - if len(strings.Fields(tag)) > 1 { - return errors.New("tag should not contains space") - } - return nil -} - func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, @@ -546,13 +543,8 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty resp.Online = true } - var tags []string - for _, tag := range node.RequestTags() { - if polMan.NodeCanHaveTag(node, tag) { - tags = append(tags, tag) - } - } - resp.ValidTags = lo.Uniq(append(tags, node.Tags...)) + // TODO(kradalby): Rename ValidTags, there is only Tags + resp.ValidTags = node.Tags resp.SubnetRoutes = util.PrefixesToString(append(pr.PrimaryRoutes(node.ID), node.ExitRoutes()...)) response[index] = resp } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index ad2b0fba..978b9d98 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -509,7 +509,7 @@ func (a *AuthProviderOIDC) handleRegistration( return false, err } - node, newNode, err := a.db.HandleNodeFromAuthPath( + node, cs, err := a.db.HandleNodeFromAuthPath( registrationID, types.UserID(user.ID), &expiry, @@ -540,7 +540,7 @@ func (a *AuthProviderOIDC) handleRegistration( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(a.polMan, node) - if err := a.db.DB.Save(node).Error; err != nil { + if err := a.db.SaveNode(node); err != nil { return false, fmt.Errorf("saving auto approved routes to node: %w", err) } @@ -556,7 +556,7 @@ func (a *AuthProviderOIDC) handleRegistration( a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) } - return newNode, nil + return cs.New, nil } // TODO(kradalby): diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 763ab85b..31eeca37 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -495,7 +495,7 @@ func (m *mapSession) handleEndpointUpdate() { // the hostname change. m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo) - if err := m.h.db.DB.Save(m.node).Error; err != nil { + if err := m.h.db.SaveNode(m.node); err != nil { m.errf(err, "Failed to persist/update node in the database") http.Error(m.w, "", http.StatusInternalServerError) mapResponseEndpointUpdates.WithLabelValues("error").Inc() diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index c4cc8a2e..da39e695 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -196,3 +196,14 @@ type RegisterNode struct { Node Node Registered chan *Node } + +// TODO(kradalby): Not sure if this is a good idea, +// but ran into this problem in HandleNodeFromAuthPath +// describing what has changed in the node... +// ChangeSet described changes that has happend to a node +type ChangeSet struct { + NodeID NodeID + New bool + Tags bool + Expiry bool +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 91162f94..0833afde 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -183,11 +183,11 @@ func (node *Node) IsUserOwned() bool { return true } +var ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node") + // IsTagged reports if a device is tagged // and therefore should not be treated as a // user owned device. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) func (node *Node) IsTagged() bool { if node.Tags == nil { return false