diff --git a/hscontrol/app.go b/hscontrol/app.go index 02b1ece8..1b9e5f14 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -162,6 +162,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("loading ACL policy: %w", err) } + // TODO(kradalby): There is an circular dependency here, maybe we should + // look at some sort of dependency injection? + // https://github.com/uber-go/dig + // or + // https://github.com/uber-go/fx + // Maybe overkill? + app.db.SetPolicyManager(app.polMan) + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 941b51b2..fc58f2ca 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -153,27 +153,6 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil) } -// canUsePreAuthKey checks if a pre auth key can be used. -func canUsePreAuthKey(pak *types.PreAuthKey) error { - if pak == nil { - return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) - } - if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil) - } - - // we don't need to check if has been used before - if pak.Reusable { - return nil - } - - if pak.Used { - return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil) - } - - return nil -} - func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, @@ -183,30 +162,26 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - return nil, err - } - - err = canUsePreAuthKey(pak) - if err != nil { - return nil, err + return nil, NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) } nodeToRegister := types.Node{ Hostname: regReq.Hostinfo.Hostname, - UserID: pak.User.ID, - User: pak.User, MachineKey: machineKey, NodeKey: regReq.NodeKey, Hostinfo: regReq.Hostinfo, LastSeen: ptr.To(time.Now()), RegisterMethod: util.RegisterMethodAuthKey, - // TODO(kradalby): This should not be set on the node, - // they should be looked up through the key, which is - // attached to the node. - ForcedTags: pak.Proto().GetAclTags(), - AuthKey: pak, - AuthKeyID: &pak.ID, + AuthKey: pak, + AuthKeyID: &pak.ID, + } + + if pak.IsTagged() { + nodeToRegister.Tags = pak.Tags + } else { + nodeToRegister.UserID = pak.UserID + nodeToRegister.User = pak.User } if !regReq.Expiry.IsZero() { @@ -257,7 +232,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(h.polMan, node) - if err := h.db.DB.Save(node).Error; err != nil { + if err := h.db.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 51083145..de803d32 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -65,7 +65,7 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - key := types.APIKey{} + var key types.APIKey if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -75,7 +75,7 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - key := types.APIKey{} + var key types.APIKey if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index bab0061e..69bef3e6 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -15,6 +15,7 @@ import ( "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -44,6 +45,7 @@ type HSDatabase struct { DB *gorm.DB cfg *types.DatabaseConfig regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + polMan policy.PolicyManager baseDomain string } @@ -718,6 +720,36 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + // Migrate node table to make users optional. + // Rename forced_tags to tags + { + ID: "202505211519-node-user-optional-tags", + Migrate: func(tx *gorm.DB) error { + _ = tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags") + + err = tx.AutoMigrate(&types.Node{}) + if err != nil { + return fmt.Errorf("automigrating types.Node: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // Migrate preauthkey table to make users and tags optional. + // Use prefix+hash for keys. + { + ID: "202505231615-preauthkey-user-optional-tags-user", + Migrate: func(tx *gorm.DB) error { + err = tx.AutoMigrate(&types.PreAuthKey{}) + if err != nil { + return fmt.Errorf("automigrating types.PreAuthKey: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) @@ -736,6 +768,10 @@ AND auth_key_id NOT IN ( return &db, err } +func (db *HSDatabase) SetPolicyManager(pol policy.PolicyManager) { + db.polMan = pol +} + func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index f558cdf7..73115876 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -96,7 +96,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -124,7 +124,7 @@ func TestIPAllocatorSequential(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), IPv6: nap("fd7a:115c:a1e0::2"), }) @@ -314,7 +314,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) @@ -339,7 +339,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -364,7 +364,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -388,7 +388,7 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), IPv6: nap("fd7a:115c:a1e0::1"), }) @@ -412,19 +412,19 @@ func TestBackfillIPAddresses(t *testing.T) { db.DB.Save(&user) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.1"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.2"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.3"), }) db.DB.Save(&types.Node{ - User: user, + User: &user, IPv4: nap("100.64.0.4"), }) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index c91687da..57e98d00 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -97,6 +98,23 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } +// SaveNode saves a node to the database. +// It performs checks to validate if the conforms to certain restrictions: +// - A node must be either tagged or owned by a user, not both. +func (hsdb *HSDatabase) SaveNode(node *types.Node) error { + if node.IsTagged() && node.UserID != nil { + return fmt.Errorf("node %q is tagged and has a user ID, has to be either tagged or owned by user", node.Hostname) + } + + if !node.IsTagged() && node.UserID == nil { + return fmt.Errorf("node %q is not tagged and has no user ID, has to be either tagged or owned by user", node.Hostname) + } + + slices.Sort(node.Tags) + node.Tags = slices.Compact(node.Tags) + return hsdb.DB.Save(node).Error +} + func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return getNode(rx, uid, name) @@ -196,18 +214,26 @@ func (hsdb *HSDatabase) SetTags( // SetTags takes a NodeID and update the forced tags. // It will overwrite any tags with the new list. +// If the node has a UserID, it will be unset as a node +// can only have a UserID or tags, not both. func SetTags( tx *gorm.DB, nodeID types.NodeID, tags []string, ) error { + // If no tags are provided, return an error. + // Tailscale does not support removing all tags from a node. + // A node needs to have either a User owner, or be tagged, and + // it is not supported to remove all tags and "return it to a user". if len(tags) == 0 { - // if no tags are provided, we remove all forced tags - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil { - return fmt.Errorf("removing tags: %w", err) - } + return types.ErrCannotRemoveAllTags + } - return nil + // If the node has a UserID, we need to remove it. + // This is because a node can only have a UserID or tags, not both. + // We need to set the UserID to nil. + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", nil).Error; err != nil { + return fmt.Errorf("removing user from tagged node: %w", err) } slices.Sort(tags) @@ -217,14 +243,15 @@ func SetTags( return err } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error; err != nil { return fmt.Errorf("updating tags: %w", err) } return nil } -// SetTags takes a Node struct pointer and update the forced tags. +// SetApprovedRoutes takes a NodeID and a list of routes and updates the +// approved routes for the node. func SetApprovedRoutes( tx *gorm.DB, nodeID types.NodeID, @@ -339,6 +366,30 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( }) } +func checkTags(polMan policy.PolicyManager, node *types.Node, reqTags []string) ([]string, error) { + if len(reqTags) == 0 { + return nil, nil + } + + var tags []string + var invalidTags []string + for _, tag := range reqTags { + if polMan.NodeCanHaveTag(node, tag) { + tags = append(tags, tag) + } else { + invalidTags = append(invalidTags, tag) + } + } + + if len(invalidTags) > 0 { + return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags) + } + + slices.Sort(tags) + tags = slices.Compact(tags) + return tags, nil +} + // HandleNodeFromAuthPath is called from the OIDC or CLI auth path // with a registrationID to register or reauthenticate a node. // If the node found in the registration cache is not already registered, @@ -352,8 +403,9 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( registrationMethod string, ipv4 *netip.Addr, ipv6 *netip.Addr, -) (*types.Node, bool, error) { +) (*types.Node, types.ChangeSet, error) { var newNode bool + cs := types.ChangeSet{} node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { if reg, ok := hsdb.regCache.Get(registrationID); ok { if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil { @@ -376,12 +428,22 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( // Why not always? // Registration of expired node with different user if reg.Node.ID != 0 && - reg.Node.UserID != user.ID { + reg.Node.UserID != nil && + *reg.Node.UserID != user.ID { return nil, ErrDifferentRegisteredUser } - reg.Node.UserID = user.ID - reg.Node.User = *user + if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 { + tags, err := checkTags(hsdb.polMan, ®.Node, reqTags) + if err != nil { + return nil, err + } + reg.Node.Tags = tags + } else { + reg.Node.UserID = &user.ID + reg.Node.User = user + } + reg.Node.RegisterMethod = registrationMethod if nodeExpiry != nil { @@ -405,14 +467,27 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( } close(reg.Registered) - newNode = true + cs.New = true return node, err } else { + if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 { + tags, err := checkTags(hsdb.polMan, ®.Node, reqTags) + if err != nil { + return nil, err + } + err = SetTags(tx, node.ID, tags) + if err != nil { + return nil, err + } + cs.Tags = true + } + // If the node is already registered, this is a refresh. err := NodeSetExpiry(tx, node.ID, *nodeExpiry) if err != nil { return nil, err } + cs.Expiry = true return node, nil } } @@ -420,7 +495,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath( return nil, ErrNodeNotFoundRegistrationCache }) - return node, newNode, err + return node, cs, err } func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) { @@ -435,7 +510,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Username()). Msg("Registering node") // If the a new node is registered with the same machine key, to the same user, @@ -463,7 +537,6 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Username()). Msg("Node authorized again") return &node, nil diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 56c967f1..4f4871d0 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -43,7 +43,7 @@ func (s *Suite) TestGetNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -72,7 +72,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -95,7 +95,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode3", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, } trx := db.DB.Save(&node) @@ -127,7 +127,7 @@ func (s *Suite) TestListPeers(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode" + strconv.Itoa(index), - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -165,7 +165,7 @@ func (s *Suite) TestExpireNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, @@ -206,7 +206,7 @@ func (s *Suite) TestSetTags(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -220,7 +220,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, sTags) + c.Assert(node.Tags, check.DeepEquals, sTags) // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} @@ -229,7 +229,7 @@ func (s *Suite) TestSetTags(c *check.C) { node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( - node.ForcedTags, + node.Tags, check.DeepEquals, []string{"tag:bar", "tag:test", "tag:unknown"}, ) @@ -239,7 +239,7 @@ func (s *Suite) TestSetTags(c *check.C) { c.Assert(err, check.IsNil) node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, []string{}) + c.Assert(node.Tags, check.DeepEquals, []string{}) } func TestHeadscale_generateGivenName(t *testing.T) { @@ -451,7 +451,7 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, @@ -467,13 +467,13 @@ func TestAutoApproveRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "taggednode", - UserID: taggedUser.ID, + UserID: &taggedUser.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, }, - ForcedTags: []string{"tag:exit"}, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + Tags: []string{"tag:exit"}, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), } err = adb.DB.Save(&nodeTagged).Error @@ -612,7 +612,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -622,7 +622,7 @@ func TestListEphemeralNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "ephemeral", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pakEph.ID), } @@ -665,7 +665,7 @@ func TestRenameNode(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -675,7 +675,7 @@ func TestRenameNode(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -765,7 +765,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -775,7 +775,7 @@ func TestListPeers(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -849,7 +849,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test1", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } @@ -859,7 +859,7 @@ func TestListNodes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test2", - UserID: user2.ID, + UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index ee977ae3..58548a66 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -1,16 +1,17 @@ package db import ( - "crypto/rand" - "encoding/hex" "errors" "fmt" + "slices" "strings" "time" + v2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" - "tailscale.com/util/set" ) var ( @@ -19,72 +20,105 @@ var ( ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") ErrUserMismatch = errors.New("user mismatch") ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") + ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey") ) +const authKeyPrefix = "hskey-auth-" +const authKeyPrefixLength = 12 +const authKeyLength = 64 + func (hsdb *HSDatabase) CreatePreAuthKey( - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, - aclTags []string, -) (*types.PreAuthKey, error) { - return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { - return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) + tags []string, +) (string, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (string, error) { + return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, tags) }) } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. +// A PreAuthKey can be tagged or owned by a user, but not both. func CreatePreAuthKey( tx *gorm.DB, - uid types.UserID, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, - aclTags []string, -) (*types.PreAuthKey, error) { - user, err := GetUserByID(tx, uid) - if err != nil { - return nil, err + tags []string, +) (string, error) { + var err error + var user *types.User + var userID *uint + + if uid == nil && len(tags) == 0 { + return "", errors.New("preauthkey must be either tagged or owned by user") } - // Remove duplicates - aclTags = set.SetOf(aclTags).Slice() + if uid != nil && len(tags) > 0 { + return "", errors.New("preauthkey cannot be both tagged and owned by user") + } - // TODO(kradalby): factor out and create a reusable tag validation, - // check if there is one in Tailscale's lib. - for _, tag := range aclTags { - if !strings.HasPrefix(tag, "tag:") { - return nil, fmt.Errorf( - "%w: '%s' did not begin with 'tag:'", - ErrPreAuthKeyACLTagInvalid, - tag, - ) + if uid != nil { + user, err = GetUserByID(tx, *uid) + if err != nil { + return "", err + } + + userID = &user.ID + } + + if len(tags) > 0 { + slices.Sort(tags) + tags = slices.Compact(tags) + + for _, tag := range tags { + t := v2.Tag(tag) + if err := t.Validate(); err != nil { + return "", fmt.Errorf("invalid tag: %w", tag, err) + } } } now := time.Now().UTC() - // TODO(kradalby): unify the key generations spread all over the code. - kstr, err := generateKey() + + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { - return nil, err + return "", err + } + + toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) + if err != nil { + return "", err + } + + // Key to return to user, this will only be visible _once_ + keyStr := authKeyPrefix + "-" + prefix + "-" + toBeHashed + + hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) + if err != nil { + return "", err } key := types.PreAuthKey{ - Key: kstr, - UserID: user.ID, - User: *user, Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, - Tags: aclTags, + UserID: userID, + User: user, + Tags: tags, + Prefix: prefix, + Hash: hash, } if err := tx.Save(&key).Error; err != nil { - return nil, fmt.Errorf("failed to create key in the database: %w", err) + return "", fmt.Errorf("failed to create key in the database: %w", err) } - return &key, nil + return keyStr, nil } func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { @@ -101,28 +135,68 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e } keys := []types.PreAuthKey{} - if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error; err != nil { return nil, err } return keys, nil } -func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { - return GetPreAuthKey(rx, key) +// GetPreAuthKey returns a PreAuthKey by its key string. +// It will return an error if the key is not found, or if it is expired, used or invalid. +func (hsdb *HSDatabase) GetPreAuthKey(keyStr string) (*types.PreAuthKey, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { + return GetPreAuthKey(tx, keyStr) }) } -// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible -// for checking if the key is usable (expired or used). -func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { - pak := types.PreAuthKey{} - if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil { +// GetPreAuthKey returns a PreAuthKey by its key string. +// It will return an error if the key is not found, or if it is expired, used or invalid. +func GetPreAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { + pak, err := findAuthKey(tx, keyStr) + if err != nil { + return nil, err + } + + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { + return nil, ErrPreAuthKeyExpired + } + + if pak.Used { + return nil, ErrSingleUseAuthKeyHasBeenUsed + } + + return pak, nil +} + +func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { + var pak *types.PreAuthKey + _, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix) + + if !found { + if err := tx.Preload("User").First(pak, "key = ?", keyStr).Error; err != nil { + return nil, ErrPreAuthKeyNotFound + } + } else { + prefix, hash, found := strings.Cut(prefixAndHash, "-") + if !found { + return nil, ErrPreAuthKeyFailedToParse + } + + if err := tx.Preload("User").First(pak, "prefix = ?", prefix).Error; err != nil { + return nil, ErrPreAuthKeyNotFound + } + + if err := bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash)); err != nil { + return nil, err + } + } + + if pak == nil { return nil, ErrPreAuthKeyNotFound } - return &pak, nil + return pak, nil } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey @@ -161,13 +235,3 @@ func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { return nil } - -func generateKey() (string, error) { - size := 24 - bytes := make([]byte, size) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - return hex.EncodeToString(bytes), nil -} diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 5ace968a..a90cf41f 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -74,7 +74,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { node := types.Node{ ID: 0, Hostname: "testest", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(key.ID), } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index d7f31e5b..cdc6fc32 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gorm.io/gorm" + "tailscale.com/types/ptr" ) var ( @@ -192,7 +193,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { // ListNodesByUser gets all the nodes in a given user. func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: ptr.To(uint(uid))}).Find(&nodes).Error; err != nil { return nil, err } @@ -211,7 +212,7 @@ func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error { if err != nil { return err } - node.User = *user + node.User = user if result := tx.Save(&node); result.Error != nil { return result.Error } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 6cec2d5a..6bbb77c0 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -52,7 +52,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { node := types.Node{ ID: 0, Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } @@ -110,17 +110,17 @@ func (s *Suite) TestSetMachineUser(c *check.C) { node := types.Node{ ID: 0, Hostname: "testnode", - UserID: oldUser.ID, + UserID: &oldUser.ID, RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: ptr.To(pak.ID), } trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) - c.Assert(node.UserID, check.Equals, oldUser.ID) + c.Assert(*node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) - c.Assert(node.UserID, check.Equals, newUser.ID) + c.Assert(*node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) err = db.AssignNodeToUser(&node, 9584849) @@ -128,6 +128,6 @@ func (s *Suite) TestSetMachineUser(c *check.C) { err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) - c.Assert(node.UserID, check.Equals, newUser.ID) + c.Assert(*node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 7d31e2bb..f24ee8a2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -10,12 +10,10 @@ import ( "os" "slices" "sort" - "strings" "time" "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" - "github.com/samber/lo" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -281,7 +279,7 @@ func (api headscaleV1APIServer) RegisterNode( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(api.h.polMan, node) - if err := api.h.db.DB.Save(node).Error; err != nil { + if err := api.h.db.SaveNode(node); err != nil { return nil, fmt.Errorf("saving auto approved routes to node: %w", err) } @@ -315,15 +313,27 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { + node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) + if err != nil { + return nil, err + } + + var tags []string + var invalidTags []string for _, tag := range request.GetTags() { - err := validateTag(tag) - if err != nil { - return nil, err + if api.h.polMan.NodeCanHaveTag(node, tag) { + tags = append(tags, tag) + } else { + invalidTags = append(invalidTags, tag) } } - node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { - err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags()) + if len(invalidTags) > 0 { + return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags) + } + + node, err = db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.SetTags(tx, types.NodeID(request.GetNodeId()), tags) if err != nil { return nil, err } @@ -395,19 +405,6 @@ func (api headscaleV1APIServer) SetApprovedRoutes( return &v1.SetApprovedRoutesResponse{Node: proto}, nil } -func validateTag(tag string) error { - if strings.Index(tag, "tag:") != 0 { - return errors.New("tag must start with the string 'tag:'") - } - if strings.ToLower(tag) != tag { - return errors.New("tag should be lowercase") - } - if len(strings.Fields(tag)) > 1 { - return errors.New("tag should not contains space") - } - return nil -} - func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, @@ -546,13 +543,8 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty resp.Online = true } - var tags []string - for _, tag := range node.RequestTags() { - if polMan.NodeCanHaveTag(node, tag) { - tags = append(tags, tag) - } - } - resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...)) + // TODO(kradalby): Rename ValidTags, there is only Tags + resp.ValidTags = node.Tags resp.SubnetRoutes = util.PrefixesToString(append(pr.PrimaryRoutes(node.ID), node.ExitRoutes()...)) response[index] = resp } @@ -819,7 +811,7 @@ func (api headscaleV1APIServer) DebugCreateNode( NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), - User: *user, + User: user, Expiry: &time.Time{}, LastSeen: &time.Time{}, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index d7deb0a5..e84b7397 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -104,11 +104,16 @@ func generateUserProfiles( ) []tailcfg.UserProfile { userMap := make(map[uint]*types.User) ids := make([]uint, 0, len(userMap)) - userMap[node.User.ID] = &node.User - ids = append(ids, node.User.ID) - for _, peer := range peers { - userMap[peer.User.ID] = &peer.User - ids = append(ids, peer.User.ID) + var tagged bool + if node.IsUserOwned() { + userMap[node.User.ID] = node.User + ids = append(ids, node.User.ID) + for _, peer := range peers { + userMap[peer.User.ID] = peer.User + ids = append(ids, peer.User.ID) + } + } else { + tagged = true } slices.Sort(ids) @@ -120,6 +125,10 @@ func generateUserProfiles( } } + if tagged { + profiles = append(profiles, types.TaggedDevices.TailscaleUserProfile()) + } + return profiles } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 8d2c60bb..eb76c03a 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -53,8 +53,8 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: userid, - User: types.User{ + UserID: &userid, + User: &types.User{ Name: username, }, } @@ -128,15 +128,15 @@ func Test_fullMapResponse(t *testing.T) { DiscoKey: mustDK( "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", ), - IPv4: iap("100.64.0.1"), - Hostname: "mini", - GivenName: "mini", - UserID: user1.ID, - User: user1, - ForcedTags: []string{}, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, + IPv4: iap("100.64.0.1"), + Hostname: "mini", + GivenName: "mini", + UserID: &user1.ID, + User: &user1, + Tags: []string{}, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), @@ -205,16 +205,16 @@ func Test_fullMapResponse(t *testing.T) { DiscoKey: mustDK( "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", ), - IPv4: iap("100.64.0.2"), - Hostname: "peer1", - GivenName: "peer1", - UserID: user2.ID, - User: user2, - ForcedTags: []string{}, - LastSeen: &lastSeen, - Expiry: &expire, - Hostinfo: &tailcfg.Hostinfo{}, - CreatedAt: created, + IPv4: iap("100.64.0.2"), + Hostname: "peer1", + GivenName: "peer1", + UserID: &user2.ID, + User: &user2, + Tags: []string{}, + LastSeen: &lastSeen, + Expiry: &expire, + Hostinfo: &tailcfg.Hostinfo{}, + CreatedAt: created, } tailPeer1 := &tailcfg.Node{ diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index eae70e96..60e26ee1 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -6,7 +6,6 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" - "github.com/samber/lo" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -72,14 +71,6 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - var tags []string - for _, tag := range node.RequestTags() { - if polMan.NodeCanHaveTag(node, tag) { - tags = append(tags, tag) - } - } - tags = lo.Uniq(append(tags, node.ForcedTags...)) - routes := primaryRouteFunc(node.ID) allowed := append(node.Prefixes(), routes...) allowed = append(allowed, node.ExitRoutes()...) @@ -91,8 +82,6 @@ func tailNode( Name: hostname, Cap: capVer, - User: tailcfg.UserID(node.UserID), - Key: node.NodeKey, KeyExpiry: keyExpiry.UTC(), @@ -109,12 +98,20 @@ func tailNode( Online: node.IsOnline, - Tags: tags, + Tags: node.Tags, MachineAuthorized: !node.IsExpired(), Expired: node.IsExpired(), } + if node.IsUserOwned() { + tNode.User = tailcfg.UserID(*node.UserID) + } + + if node.IsTagged() { + tNode.User = tailcfg.UserID(types.TaggedDevices.ID) + } + tNode.CapMap = tailcfg.NodeCapMap{ tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index cacc4930..9fc19c14 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -70,7 +71,6 @@ func TestTailNode(t *testing.T) { HomeDERP: 0, LegacyDERPString: "127.3.3.40:0", Hostinfo: hiview(tailcfg.Hostinfo{}), - Tags: []string{}, MachineAuthorized: true, CapMap: tailcfg.NodeCapMap{ @@ -97,14 +97,13 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{ + UserID: ptr.To(uint(0)), + User: &types.User{ Name: "mini", }, - ForcedTags: []string{}, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), @@ -156,8 +155,6 @@ func TestTailNode(t *testing.T) { }), Created: created, - Tags: []string{}, - LastSeen: &lastSeen, MachineAuthorized: true, @@ -184,7 +181,6 @@ func TestTailNode(t *testing.T) { HomeDERP: 0, LegacyDERPString: "127.3.3.40:0", Hostinfo: hiview(tailcfg.Hostinfo{}), - Tags: []string{}, MachineAuthorized: true, CapMap: tailcfg.NodeCapMap{ diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index ad2b0fba..978b9d98 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -509,7 +509,7 @@ func (a *AuthProviderOIDC) handleRegistration( return false, err } - node, newNode, err := a.db.HandleNodeFromAuthPath( + node, cs, err := a.db.HandleNodeFromAuthPath( registrationID, types.UserID(user.ID), &expiry, @@ -540,7 +540,7 @@ func (a *AuthProviderOIDC) handleRegistration( // This works, but might be another good candidate for doing some sort of // eventbus. routesChanged := policy.AutoApproveRoutes(a.polMan, node) - if err := a.db.DB.Save(node).Error; err != nil { + if err := a.db.SaveNode(node); err != nil { return false, fmt.Errorf("saving auto approved routes to node: %w", err) } @@ -556,7 +556,7 @@ func (a *AuthProviderOIDC) handleRegistration( a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID) } - return newNode, nil + return cs.New, nil } // TODO(kradalby): diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 83d69eb8..b904572e 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -17,6 +17,7 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/types/ptr" "tailscale.com/util/must" ) @@ -142,15 +143,17 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: users[0], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + User: &users[0], + UserID: &users[0].ID, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: users[0], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), + User: &users[0], + UserID: &users[0].ID, }, }, want: []tailcfg.FilterRule{}, @@ -189,9 +192,10 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -200,9 +204,10 @@ func TestReduceFilterRules(t *testing.T) { }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[1], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -279,21 +284,24 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[2], + UserID: &users[2].ID, }, // "internal" exit node &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -340,23 +348,26 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[2], + UserID: &users[2].ID, }, &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -447,23 +458,26 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[2], + UserID: &users[2].ID, }, &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -557,23 +571,26 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[2], + UserID: &users[2].ID, }, &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -645,23 +662,26 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[2], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[2], + UserID: &users[2].ID, }, &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -725,19 +745,21 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.100"), - IPv6: ap("fd7a:115c:a1e0::100"), - User: users[3], + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: &users[3], + UserID: &users[3].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, - ForcedTags: []string{"tag:access-servers"}, + Tags: []string{"tag:access-servers"}, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, }, }, want: []tailcfg.FilterRule{ @@ -791,15 +813,17 @@ func TestReduceFilterRules(t *testing.T) { } `, node: &types.Node{ - IPv4: ap("100.64.0.2"), - IPv6: ap("fd7a:115c:a1e0::2"), - User: users[3], + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &users[3], + UserID: &users[3].ID, }, peers: types.Nodes{ &types.Node{ - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: users[1], + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &users[1], + UserID: &users[1].ID, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, }, @@ -846,19 +870,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ @@ -870,21 +897,24 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: types.Nodes{ &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, }, @@ -893,19 +923,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -917,16 +950,18 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: types.Nodes{ &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, }, @@ -935,19 +970,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -959,16 +997,18 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: types.Nodes{ &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, }, @@ -977,19 +1017,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -1001,16 +1044,18 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: types.Nodes{ &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + UserID: ptr.To(uint(2)), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, }, }, }, @@ -1019,19 +1064,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -1043,21 +1091,24 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: types.Nodes{ &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, }, @@ -1066,19 +1117,22 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -1090,21 +1144,24 @@ func TestReduceNodes(t *testing.T) { }, }, node: &types.Node{ // current nodes - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: types.Nodes{ &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, }, @@ -1113,27 +1170,31 @@ func TestReduceNodes(t *testing.T) { args: args{ nodes: types.Nodes{ // list of all nodes in the database &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "joe"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ - ID: 3, - IPv4: ap("100.64.0.3"), - User: types.User{Name: "mickael"}, + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael", Model: gorm.Model{ID: 3}}, + UserID: ptr.To(uint(3)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered }, node: &types.Node{ // current nodes - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "marc"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: nil, @@ -1151,28 +1212,32 @@ func TestReduceNodes(t *testing.T) { Hostname: "ts-head-upcrmb", IPv4: ap("100.64.0.3"), IPv6: ap("fd7a:115c:a1e0::3"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ ID: 2, Hostname: "ts-unstable-rlwpvr", IPv4: ap("100.64.0.4"), IPv6: ap("fd7a:115c:a1e0::4"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ ID: 3, Hostname: "ts-head-8w6paa", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, &types.Node{ ID: 4, Hostname: "ts-unstable-lys2ib", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ // list of all ACLRules registered @@ -1194,7 +1259,8 @@ func TestReduceNodes(t *testing.T) { Hostname: "ts-head-8w6paa", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: types.Nodes{ @@ -1203,14 +1269,16 @@ func TestReduceNodes(t *testing.T) { Hostname: "ts-head-upcrmb", IPv4: ap("100.64.0.3"), IPv6: ap("fd7a:115c:a1e0::3"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, &types.Node{ ID: 2, Hostname: "ts-unstable-rlwpvr", IPv4: ap("100.64.0.4"), IPv6: ap("fd7a:115c:a1e0::4"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, }, @@ -1222,13 +1290,15 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.2"), Hostname: "peer1", - User: types.User{Name: "mini"}, + User: &types.User{Name: "mini", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 2, IPv4: ap("100.64.0.3"), Hostname: "peer2", - User: types.User{Name: "peer2"}, + User: &types.User{Name: "peer2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ @@ -1244,7 +1314,8 @@ func TestReduceNodes(t *testing.T) { ID: 0, IPv4: ap("100.64.0.1"), Hostname: "mini", - User: types.User{Name: "mini"}, + User: &types.User{Name: "mini", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: []*types.Node{ @@ -1252,7 +1323,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.3"), Hostname: "peer2", - User: types.User{Name: "peer2"}, + User: &types.User{Name: "peer2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, }, @@ -1264,19 +1336,22 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.2"), Hostname: "user1-2", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 0, IPv4: ap("100.64.0.1"), Hostname: "user1-1", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 3, IPv4: ap("100.64.0.4"), Hostname: "user2-2", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ @@ -1313,7 +1388,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.3"), Hostname: "user-2-1", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: []*types.Node{ @@ -1321,19 +1397,22 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.2"), Hostname: "user1-2", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 0, IPv4: ap("100.64.0.1"), Hostname: "user1-1", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 3, IPv4: ap("100.64.0.4"), Hostname: "user2-2", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, }, @@ -1345,19 +1424,22 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.2"), Hostname: "user1-2", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 2, IPv4: ap("100.64.0.3"), Hostname: "user-2-1", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, { ID: 3, IPv4: ap("100.64.0.4"), Hostname: "user2-2", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ @@ -1394,7 +1476,8 @@ func TestReduceNodes(t *testing.T) { ID: 0, IPv4: ap("100.64.0.1"), Hostname: "user1-1", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: []*types.Node{ @@ -1402,19 +1485,22 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.2"), Hostname: "user1-2", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 2, IPv4: ap("100.64.0.3"), - Hostname: "user-2-1", - User: types.User{Name: "user2"}, + Hostname: "user-2-1", + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, { ID: 3, IPv4: ap("100.64.0.4"), Hostname: "user2-2", - User: types.User{Name: "user2"}, + User: &types.User{Name: "user2", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, }, @@ -1426,13 +1512,15 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "user1", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, }, @@ -1453,7 +1541,8 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "user1", - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, }, want: []*types.Node{ @@ -1461,7 +1550,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, }, @@ -1477,7 +1567,8 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, }, @@ -1487,7 +1578,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "node", - User: types.User{Name: "node"}, + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ @@ -1504,7 +1596,8 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, }, @@ -1516,7 +1609,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "node", - User: types.User{Name: "node"}, + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, }, @@ -1528,7 +1622,8 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, }, @@ -1538,7 +1633,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "node", - User: types.User{Name: "node"}, + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, rules: []tailcfg.FilterRule{ @@ -1555,7 +1651,8 @@ func TestReduceNodes(t *testing.T) { ID: 2, IPv4: ap("100.64.0.2"), Hostname: "node", - User: types.User{Name: "node"}, + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, }, want: []*types.Node{ @@ -1563,7 +1660,8 @@ func TestReduceNodes(t *testing.T) { ID: 1, IPv4: ap("100.64.0.1"), Hostname: "router", - User: types.User{Name: "router"}, + User: &types.User{Name: "router", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, }, @@ -1599,28 +1697,28 @@ func TestSSHPolicyRules(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: ap("100.64.0.1"), - UserID: 1, - User: users[0], + User: &users[0], + UserID: ptr.To(users[0].ID), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: 2, - User: users[1], + User: &users[1], + UserID: ptr.To(users[1].ID), } taggedServer := types.Node{ - Hostname: "tagged-server", - IPv4: ap("100.64.0.3"), - UserID: 3, - User: users[2], - ForcedTags: []string{"tag:server"}, + Hostname: "tagged-server", + IPv4: ap("100.64.0.3"), + User: &users[2], + UserID: ptr.To(users[2].ID), + Tags: []string{"tag:server"}, } taggedClient := types.Node{ - Hostname: "tagged-client", - IPv4: ap("100.64.0.4"), - UserID: 2, - User: users[1], - ForcedTags: []string{"tag:client"}, + Hostname: "tagged-client", + IPv4: ap("100.64.0.4"), + User: &users[1], + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:client"}, } tests := []struct { @@ -1984,9 +2082,10 @@ func TestReduceRoutes(t *testing.T) { name: "node-can-access-all-routes", args: args{ node: &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - User: types.User{Name: "user1"}, + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2014,7 +2113,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 1, IPv4: ap("100.64.0.1"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2040,7 +2139,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 1, IPv4: ap("100.64.0.1"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2068,7 +2167,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 1, IPv4: ap("100.64.0.1"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2095,7 +2194,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 1, IPv4: ap("100.64.0.1"), - User: types.User{Name: "user1"}, + User: &types.User{Name: "user1"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2117,10 +2216,11 @@ func TestReduceRoutes(t *testing.T) { name: "node-with-both-ipv4-and-ipv6", args: args{ node: &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), - IPv6: ap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + ID: 1, + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &types.User{Name: "user1", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.0.0.0/24"), @@ -2151,9 +2251,10 @@ func TestReduceRoutes(t *testing.T) { name: "router-with-multiple-routes-and-node-with-specific-access", args: args{ node: &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), // Node IP - User: types.User{Name: "node"}, + ID: 2, + IPv4: ap("100.64.0.2"), // Node IP + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2183,9 +2284,10 @@ func TestReduceRoutes(t *testing.T) { name: "node-with-access-to-one-subnet-and-partial-overlap", args: args{ node: &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), - User: types.User{Name: "node"}, + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2212,7 +2314,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 2, IPv4: ap("100.64.0.2"), - User: types.User{Name: "node"}, + User: &types.User{Name: "node"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2240,7 +2342,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 2, IPv4: ap("100.64.0.2"), - User: types.User{Name: "node"}, + User: &types.User{Name: "node"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2278,7 +2380,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 2, IPv4: ap("100.64.0.2"), // node with IP 100.64.0.2 - User: types.User{Name: "node"}, + User: &types.User{Name: "node"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2311,9 +2413,10 @@ func TestReduceRoutes(t *testing.T) { args: args{ // When testing from router node's perspective node: &types.Node{ - ID: 1, - IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1 - User: types.User{Name: "router"}, + ID: 1, + IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1 + User: &types.User{Name: "router", Model: gorm.Model{ID: 1}}, + UserID: ptr.To(uint(1)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2353,9 +2456,10 @@ func TestReduceRoutes(t *testing.T) { name: "acl-specific-port-ranges-for-subnets", args: args{ node: &types.Node{ - ID: 2, - IPv4: ap("100.64.0.2"), // node - User: types.User{Name: "node"}, + ID: 2, + IPv4: ap("100.64.0.2"), // node + User: &types.User{Name: "node", Model: gorm.Model{ID: 2}}, + UserID: ptr.To(uint(2)), }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), @@ -2389,7 +2493,7 @@ func TestReduceRoutes(t *testing.T) { node: &types.Node{ ID: 2, IPv4: ap("100.64.0.2"), // node - User: types.User{Name: "node"}, + User: &types.User{Name: "node"}, }, routes: []netip.Prefix{ netip.MustParsePrefix("10.10.10.0/24"), diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 19d61d82..d671e5bd 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" + "tailscale.com/types/ptr" ) func TestNodeCanApproveRoute(t *testing.T) { @@ -24,34 +25,34 @@ func TestNodeCanApproveRoute(t *testing.T) { ID: 1, Hostname: "user1-device", IPv4: ap("100.64.0.1"), - UserID: 1, - User: users[0], + UserID: ptr.To(uint(1)), + User: &users[0], } exitNode := types.Node{ ID: 2, Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: 2, - User: users[1], + UserID: ptr.To(uint(2)), + User: &users[1], } taggedNode := types.Node{ - ID: 3, - Hostname: "tagged-server", - IPv4: ap("100.64.0.3"), - UserID: 3, - User: users[2], - ForcedTags: []string{"tag:router"}, + ID: 3, + Hostname: "tagged-server", + IPv4: ap("100.64.0.3"), + UserID: ptr.To(uint(3)), + User: &users[2], + Tags: []string{"tag:router"}, } multiTagNode := types.Node{ - ID: 4, - Hostname: "multi-tag-node", - IPv4: ap("100.64.0.4"), - UserID: 2, - User: users[1], - ForcedTags: []string{"tag:router", "tag:server"}, + ID: 4, + Hostname: "multi-tag-node", + IPv4: ap("100.64.0.4"), + UserID: ptr.To(uint(2)), + User: &users[1], + Tags: []string{"tag:router", "tag:server"}, } tests := []struct { diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index b5f08164..7095f666 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -5,14 +5,11 @@ import ( "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" - "gorm.io/gorm" "tailscale.com/tailcfg" ) func TestParsing(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "testuser"}, - } + tu := GetTestUsers() tests := []struct { name string format string @@ -352,17 +349,19 @@ func TestParsing(t *testing.T) { } rules, err := pol.compileFilterRules( - users, + tu.FilteredSlice("testuser"), types.Nodes{ &types.Node{ IPv4: ap("100.100.100.100"), }, &types.Node{ IPv4: ap("200.200.200.200"), - User: users[0], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{}, }, - }) + }, + ) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index b61c5758..26735ae3 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -1,13 +1,13 @@ package v2 import ( - "github.com/juanfont/headscale/hscontrol/policy/matcher" "testing" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/require" - "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -17,17 +17,14 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) Hostname: name, IPv4: ap(ipv4), IPv6: ap(ipv6), - User: user, - UserID: user.ID, + User: &user, + UserID: &user.ID, Hostinfo: hostinfo, } } func TestPolicyManager(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, - {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, - } + tu := GetTestUsers() tests := []struct { name string @@ -47,7 +44,7 @@ func TestPolicyManager(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + pm, err := NewPolicyManager([]byte(tt.pol), tu.FilteredSlice("testuser", "otheruser"), tt.nodes) require.NoError(t, err) filter, matchers := pm.Filter() diff --git a/hscontrol/policy/v2/testusers_helper.go b/hscontrol/policy/v2/testusers_helper.go new file mode 100644 index 00000000..63d692f6 --- /dev/null +++ b/hscontrol/policy/v2/testusers_helper.go @@ -0,0 +1,96 @@ +package v2 + +import ( + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "gorm.io/gorm" +) + +// TestUsers provides a convenient way to manage test users across tests +type TestUsers struct { + users map[string]*types.User + once sync.Once +} + +var defaultTestUsers TestUsers + +// GetTestUsers returns a singleton instance of TestUsers with predefined test users +func GetTestUsers() *TestUsers { + defaultTestUsers.once.Do(func() { + defaultTestUsers.users = map[string]*types.User{ + "testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"}, + "groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"}, + "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, + "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, + "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, + "user1": {Model: gorm.Model{ID: 6}, Name: "user1"}, + "user2": {Model: gorm.Model{ID: 7}, Name: "user2"}, + "user3": {Model: gorm.Model{ID: 8}, Name: "user3"}, + "otheruser": {Model: gorm.Model{ID: 9}, Name: "otheruser", Email: "otheruser@headscale.net"}, + "mickael": {Model: gorm.Model{ID: 10}, Name: "mickael"}, + "user100": {Model: gorm.Model{ID: 11}, Name: "user100"}, + } + }) + return &defaultTestUsers +} + +// User returns a copy of the User with the given name +func (tu *TestUsers) User(name string) types.User { + if user, ok := tu.users[name]; ok { + return *user + } + // Return empty user if not found + return types.User{} +} + +// UserPtr returns a pointer to the User with the given name +func (tu *TestUsers) UserPtr(name string) *types.User { + return tu.users[name] +} + +// ID returns the ID for the given user name +func (tu *TestUsers) ID(name string) uint { + if user, ok := tu.users[name]; ok { + return user.ID + } + return 0 +} + +// IDPtr returns a pointer to the ID for the given user name +func (tu *TestUsers) IDPtr(name string) *uint { + if user, ok := tu.users[name]; ok { + id := user.ID + return &id + } + return nil +} + +// AsMap returns all users as a map +func (tu *TestUsers) AsMap() map[string]types.User { + result := make(map[string]types.User, len(tu.users)) + for name, user := range tu.users { + result[name] = *user + } + return result +} + +// AsSlice returns all users as a slice +func (tu *TestUsers) AsSlice() types.Users { + result := make(types.Users, 0, len(tu.users)) + for _, user := range tu.users { + result = append(result, *user) + } + return result +} + +// FilteredSlice returns a slice with only the specified user names +func (tu *TestUsers) FilteredSlice(names ...string) types.Users { + result := make(types.Users, 0, len(names)) + for _, name := range names { + if user, ok := tu.users[name]; ok { + result = append(result, *user) + } + } + return result +} \ No newline at end of file diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 941a645b..e7bd7738 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -189,7 +189,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net } for _, node := range nodes { - if node.IsTagged() { + if !node.IsUserOwned() { continue } @@ -205,10 +205,14 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net type Group string func (g Group) Validate() error { - if isGroup(string(g)) { - return nil + if !isGroup(string(g)) { + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } - return fmt.Errorf(`Group has to start with "group:", got: %q`, g) + + // Group name is everything after "group:" + groupName := string(g)[len("group:"):] + + return validateNameFormat(groupName, "Group", g) } func (g *Group) UnmarshalJSON(b []byte) error { @@ -266,10 +270,14 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx type Tag string func (t Tag) Validate() error { - if isTag(string(t)) { - return nil + if !isTag(string(t)) { + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } - return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) + + // Tag name is everything after "tag:" + tagName := string(t)[len("tag:"):] + + return validateNameFormat(tagName, "Tag", t) } func (t *Tag) UnmarshalJSON(b []byte) error { @@ -515,7 +523,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n for _, node := range nodes { // Skip if node has forced tags - if len(node.ForcedTags) != 0 { + if len(node.Tags) != 0 { continue } @@ -548,7 +556,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n for _, node := range nodes { // Include if node has forced tags - if len(node.ForcedTags) != 0 { + if len(node.Tags) != 0 { node.AppendToIPSet(&build) continue } @@ -654,6 +662,36 @@ func isTag(str string) bool { return strings.HasPrefix(str, "tag:") } +// validateNameFormat checks if a name follows the required format: +// - Must start with an ASCII letter (a-z, A-Z) +// - Can only contain ASCII letters, numbers, or dashes +func validateNameFormat(name string, typeLabel string, original interface{}) error { + // Check if empty + if len(name) == 0 { + return fmt.Errorf(`%s names cannot be empty, got: %q`, typeLabel, original) + } + + // Check if first character is an ASCII letter + firstChar := name[0] + if !((firstChar >= 'a' && firstChar <= 'z') || (firstChar >= 'A' && firstChar <= 'Z')) { + return fmt.Errorf(`%s names must start with a letter, got: %q`, typeLabel, original) + } + + // Check if all characters are ASCII letters, numbers, or dashes + for i := 0; i < len(name); i++ { + char := name[i] + isAsciiLetter := (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') + isDigit := char >= '0' && char <= '9' + isDash := char == '-' + + if !isAsciiLetter && !isDigit && !isDash { + return fmt.Errorf(`%s names can only contain ASCII letters, numbers, or dashes, got: %q`, typeLabel, original) + } + } + + return nil +} + func isAutoGroup(str string) bool { return strings.HasPrefix(str, "autogroup:") } @@ -1077,6 +1115,39 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { // TagOwners are a map of Tag to a list of the UserEntities that own the tag. type TagOwners map[Tag]Owners +// UnmarshalJSON overrides the default JSON unmarshalling for TagOwners to ensure +// that each tag name is validated using the isTag function and character validation rules. +// This ensures that all tag names conform to the expected format and character rules. +func (to *TagOwners) UnmarshalJSON(b []byte) error { + var rawTagOwners map[string][]string + if err := json.Unmarshal(b, &rawTagOwners); err != nil { + return err + } + + *to = make(TagOwners) + for key, value := range rawTagOwners { + tag := Tag(key) + if err := tag.Validate(); err != nil { + return err + } + + var owners Owners + + for _, o := range value { + owner, err := parseOwner(o) + if err != nil { + return err + } + + owners = append(owners, owner) + } + + (*to)[tag] = owners + } + + return nil +} + func (to TagOwners) Contains(tagOwner *Tag) error { if tagOwner == nil { return nil diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index ac2fc3b1..6ac0d7e8 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -6,17 +6,16 @@ import ( "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go4.org/netipx" - xmaps "golang.org/x/exp/maps" - "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/ptr" @@ -68,7 +67,7 @@ func TestMarshalJSON(t *testing.T) { // Marshal the policy to JSON marshalled, err := json.MarshalIndent(policy, "", " ") require.NoError(t, err) - + // Make sure all expected fields are present in the JSON jsonString := string(marshalled) assert.Contains(t, jsonString, "group:example") @@ -79,21 +78,21 @@ func TestMarshalJSON(t *testing.T) { assert.Contains(t, jsonString, "accept") assert.Contains(t, jsonString, "tcp") assert.Contains(t, jsonString, "80") - + // Unmarshal back to verify round trip var roundTripped Policy err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) - + // Compare the original and round-tripped policies - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), cmpopts.EquateEmpty(), ) - + if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) } @@ -390,6 +389,150 @@ func TestUnmarshalPolicy(t *testing.T) { // wantErr: `Username has to contain @, got: "group:inner"`, wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`, }, + { + name: "invalid-group-name-special-chars", + input: ` +{ + "groups": { + "group:example@invalid": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names can only contain ASCII letters, numbers, or dashes, got: "group:example@invalid"`, + }, + { + name: "invalid-group-name-starting-with-number", + input: ` +{ + "groups": { + "group:123example": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names must start with a letter, got: "group:123example"`, + }, + { + name: "invalid-group-name-scandinavian-characters", + input: ` +{ + "groups": { + "group:æøå-example": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names must start with a letter, got: "group:æøå-example"`, + }, + { + name: "invalid-group-name-cyrillic-characters", + input: ` +{ + "groups": { + "group:группа": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names must start with a letter, got: "group:группа"`, + }, + { + name: "invalid-group-name-emoji", + input: ` +{ + "groups": { + "group:dev-😊": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names can only contain ASCII letters, numbers, or dashes, got: "group:dev-😊"`, + }, + { + name: "invalid-group-name-other-special-chars", + input: ` +{ + "groups": { + "group:dev_team": [ + "valid@example.com", + ], + }, +} +`, + wantErr: `Group names can only contain ASCII letters, numbers, or dashes, got: "group:dev_team"`, + }, + { + name: "invalid-tag-name-special-chars", + input: ` +{ + "tagOwners": { + "tag:test@invalid": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names can only contain ASCII letters, numbers, or dashes, got: "tag:test@invalid"`, + }, + { + name: "invalid-tag-name-starting-with-number", + input: ` +{ + "tagOwners": { + "tag:123test": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names must start with a letter, got: "tag:123test"`, + }, + { + name: "invalid-tag-name-scandinavian-characters", + input: ` +{ + "tagOwners": { + "tag:æøå-test": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names must start with a letter, got: "tag:æøå-test"`, + }, + { + name: "invalid-tag-name-cyrillic-characters", + input: ` +{ + "tagOwners": { + "tag:тест": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names must start with a letter, got: "tag:тест"`, + }, + { + name: "invalid-tag-name-emoji", + input: ` +{ + "tagOwners": { + "tag:test-😊": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names can only contain ASCII letters, numbers, or dashes, got: "tag:test-😊"`, + }, + { + name: "invalid-tag-name-other-special-chars", + input: ` +{ + "tagOwners": { + "tag:test_underscore": ["valid@example.com"], + }, +} +`, + wantErr: `Tag names can only contain ASCII letters, numbers, or dashes, got: "tag:test_underscore"`, + }, { name: "invalid-addr", input: ` @@ -958,13 +1101,13 @@ func TestUnmarshalPolicy(t *testing.T) { }, } - cmps := append(util.Comparers, + cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { return x == y }), cmpopts.IgnoreUnexported(Policy{}), ) - + // For round-trip testing, we'll normalize the policies before comparing for _, tt := range tests { @@ -1001,9 +1144,9 @@ func TestUnmarshalPolicy(t *testing.T) { if err != nil { t.Fatalf("round-trip unmarshalling: %v", err) } - + // Add EquateEmpty to handle nil vs empty maps/slices - roundTripCmps := append(cmps, + roundTripCmps := append(cmps, cmpopts.EquateEmpty(), cmpopts.IgnoreUnexported(Policy{}), ) @@ -1028,13 +1171,7 @@ func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) } func p(pref string) Prefix { return Prefix(mp(pref)) } func TestResolvePolicy(t *testing.T) { - users := map[string]types.User{ - "testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"}, - "groupuser": {Model: gorm.Model{ID: 2}, Name: "groupuser"}, - "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, - "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, - "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, - } + tu := GetTestUsers() tests := []struct { name string nodes types.Nodes @@ -1064,33 +1201,40 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], - IPv4: ap("100.100.101.1"), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), + IPv4: ap("100.100.101.1"), }, - // Not matching forced tags + // Not matching tags, usernames are ignored if a node is tagged { - User: users["testuser"], - ForcedTags: []string{"tag:anything"}, - IPv4: ap("100.100.101.2"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.2"), }, // not matchin pak tag + // since 0.27.0, tags are only considered when + // set directly on the node, not via pak. { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), AuthKey: &types.PreAuthKey{ Tags: []string{"alsotagged"}, }, IPv4: ap("100.100.101.3"), }, { - User: users["testuser"], - IPv4: ap("100.100.101.103"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + IPv4: ap("100.100.101.103"), }, { - User: users["testuser"], - IPv4: ap("100.100.101.104"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + IPv4: ap("100.100.101.104"), }, }, - want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")}, + want: []netip.Prefix{mp("100.100.101.3/32"), mp("100.100.101.103/32"), mp("100.100.101.104/32")}, }, { name: "group", @@ -1098,30 +1242,46 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], - IPv4: ap("100.100.101.4"), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), + IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: users["groupuser"], - ForcedTags: []string{"tag:anything"}, - IPv4: ap("100.100.101.5"), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.5"), }, // not matchin pak tag + // since 0.27.0, tags are only considered when + // set directly on the node, not via pak. { - User: users["groupuser"], + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), AuthKey: &types.PreAuthKey{ Tags: []string{"tag:alsotagged"}, }, IPv4: ap("100.100.101.6"), }, { - User: users["groupuser"], - IPv4: ap("100.100.101.203"), + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), + IPv4: ap("100.100.101.203"), + }, + // not matchin username because tagged + // since 0.27.0, tags are only considered when + // set directly on the node, not via pak. + { + User: tu.UserPtr("groupuser"), + UserID: tu.IDPtr("groupuser"), + Tags: []string{"tag:taggg"}, + IPv4: ap("100.100.101.209"), }, { - User: users["groupuser"], - IPv4: ap("100.100.101.204"), + User: tu.UserPtr("groupuser1"), + UserID: tu.IDPtr("groupuser1"), + IPv4: ap("100.100.101.204"), }, }, pol: &Policy{ @@ -1130,7 +1290,7 @@ func TestResolvePolicy(t *testing.T) { "group:othergroup": Usernames{"notmetoo"}, }, }, - want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")}, + want: []netip.Prefix{mp("100.100.101.6/32"), mp("100.100.101.203/32")}, }, { name: "tag", @@ -1138,13 +1298,14 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: users["notme"], - IPv4: ap("100.100.101.9"), + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), + IPv4: ap("100.100.101.9"), }, // Not matching forced tags { - ForcedTags: []string{"tag:anything"}, - IPv4: ap("100.100.101.10"), + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.10"), }, // not matchin pak tag { @@ -1153,12 +1314,21 @@ func TestResolvePolicy(t *testing.T) { }, IPv4: ap("100.100.101.11"), }, - // Not matching forced tags + // matching forced tags { - ForcedTags: []string{"tag:test"}, - IPv4: ap("100.100.101.234"), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), + }, + // matching tag with user (user is ignored) + { + User: tu.UserPtr("notme"), + UserID: tu.IDPtr("notme"), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.109"), }, // not matchin pak tag + // since 0.27.0, tags are only considered when + // set directly on the node, not via pak. { AuthKey: &types.PreAuthKey{ Tags: []string{"tag:test"}, @@ -1168,7 +1338,7 @@ func TestResolvePolicy(t *testing.T) { }, // TODO(kradalby): tests handling TagOwners + hostinfo pol: &Policy{}, - want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")}, + want: []netip.Prefix{mp("100.100.101.109/32"), mp("100.100.101.234/32")}, }, { name: "empty-policy", @@ -1191,12 +1361,14 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Group("group:testgroup")), nodes: types.Nodes{ { - User: users["groupuser1"], - IPv4: ap("100.100.101.203"), + User: tu.UserPtr("groupuser1"), + UserID: tu.IDPtr("groupuser1"), + IPv4: ap("100.100.101.203"), }, { - User: users["groupuser2"], - IPv4: ap("100.100.101.204"), + User: tu.UserPtr("groupuser2"), + UserID: tu.IDPtr("groupuser2"), + IPv4: ap("100.100.101.204"), }, }, pol: &Policy{ @@ -1216,8 +1388,9 @@ func TestResolvePolicy(t *testing.T) { toResolve: ptr.To(Username("invaliduser@")), nodes: types.Nodes{ { - User: users["testuser"], - IPv4: ap("100.100.101.103"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + IPv4: ap("100.100.101.103"), }, }, wantErr: `user with token "invaliduser@" not found`, @@ -1227,8 +1400,8 @@ func TestResolvePolicy(t *testing.T) { toResolve: tp("tag:invalid"), nodes: types.Nodes{ { - ForcedTags: []string{"tag:test"}, - IPv4: ap("100.100.101.234"), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), }, }, }, @@ -1248,18 +1421,21 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be included) { - User: users["testuser"], - IPv4: ap("100.100.101.1"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + IPv4: ap("100.100.101.1"), }, // Node with forced tags (should be excluded) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, - IPv4: ap("100.100.101.2"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be excluded) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1267,7 +1443,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be included) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1275,7 +1452,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be excluded) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1283,7 +1461,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be included) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1307,18 +1486,21 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Node with no tags (should be excluded) { - User: users["testuser"], - IPv4: ap("100.100.101.1"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + IPv4: ap("100.100.101.1"), }, // Node with forced tag (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test"}, - IPv4: ap("100.100.101.2"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.2"), }, // Node with allowed requested tag (should be included) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test"}, }, @@ -1326,7 +1508,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with non-allowed requested tag (should be excluded) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed"}, }, @@ -1334,7 +1517,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, one allowed (should be included) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:test", "tag:notallowed"}, }, @@ -1342,7 +1526,8 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple requested tags, none allowed (should be excluded) { - User: users["testuser"], + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), Hostinfo: &tailcfg.Hostinfo{ RequestTags: []string{"tag:notallowed1", "tag:notallowed2"}, }, @@ -1350,9 +1535,10 @@ func TestResolvePolicy(t *testing.T) { }, // Node with multiple forced tags (should be included) { - User: users["testuser"], - ForcedTags: []string{"tag:test", "tag:other"}, - IPv4: ap("100.100.101.7"), + User: tu.UserPtr("testuser"), + UserID: tu.IDPtr("testuser"), + Tags: []string{"tag:test", "tag:other"}, + IPv4: ap("100.100.101.7"), }, }, pol: &Policy{ @@ -1376,7 +1562,7 @@ func TestResolvePolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ips, err := tt.toResolve.Resolve(tt.pol, - xmaps.Values(users), + tu.AsSlice(), tt.nodes) if tt.wantErr == "" { if err != nil { @@ -1405,32 +1591,31 @@ func TestResolvePolicy(t *testing.T) { } func TestResolveAutoApprovers(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { - IPv4: ap("100.64.0.1"), - User: users[0], + IPv4: ap("100.64.0.1"), + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { - IPv4: ap("100.64.0.2"), - User: users[1], + IPv4: ap("100.64.0.2"), + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { - IPv4: ap("100.64.0.3"), - User: users[2], + IPv4: ap("100.64.0.3"), + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), }, { - IPv4: ap("100.64.0.4"), - ForcedTags: []string{"tag:testtag"}, + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, { - IPv4: ap("100.64.0.5"), - ForcedTags: []string{"tag:exittest"}, + IPv4: ap("100.64.0.5"), + Tags: []string{"tag:exittest"}, }, } @@ -1557,7 +1742,7 @@ func TestResolveAutoApprovers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes) + got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes) if (err != nil) != tt.wantErr { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return @@ -1595,24 +1780,27 @@ func ipSetComparer(x, y *netipx.IPSet) bool { } func TestNodeCanApproveRoute(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { - IPv4: ap("100.64.0.1"), - User: users[0], + IPv4: ap("100.64.0.1"), + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { - IPv4: ap("100.64.0.2"), - User: users[1], + IPv4: ap("100.64.0.2"), + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { - IPv4: ap("100.64.0.3"), - User: users[2], + IPv4: ap("100.64.0.3"), + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1716,7 +1904,7 @@ func TestNodeCanApproveRoute(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes) require.NoErrorf(t, err, "NewPolicyManager() error = %v", err) got := pm.NodeCanApproveRoute(tt.node, tt.route) @@ -1728,24 +1916,27 @@ func TestNodeCanApproveRoute(t *testing.T) { } func TestResolveTagOwners(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { - IPv4: ap("100.64.0.1"), - User: users[0], + IPv4: ap("100.64.0.1"), + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { - IPv4: ap("100.64.0.2"), - User: users[1], + IPv4: ap("100.64.0.2"), + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { - IPv4: ap("100.64.0.3"), - User: users[2], + IPv4: ap("100.64.0.3"), + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1800,7 +1991,7 @@ func TestResolveTagOwners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := resolveTagOwners(tt.policy, users, nodes) + got, err := resolveTagOwners(tt.policy, tu.FilteredSlice("user1", "user2", "user3"), nodes) if (err != nil) != tt.wantErr { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return @@ -1813,24 +2004,27 @@ func TestResolveTagOwners(t *testing.T) { } func TestNodeCanHaveTag(t *testing.T) { - users := types.Users{ - {Model: gorm.Model{ID: 1}, Name: "user1"}, - {Model: gorm.Model{ID: 2}, Name: "user2"}, - {Model: gorm.Model{ID: 3}, Name: "user3"}, - } + tu := GetTestUsers() nodes := types.Nodes{ { - IPv4: ap("100.64.0.1"), - User: users[0], + IPv4: ap("100.64.0.1"), + User: tu.UserPtr("user1"), + UserID: tu.IDPtr("user1"), }, { - IPv4: ap("100.64.0.2"), - User: users[1], + IPv4: ap("100.64.0.2"), + User: tu.UserPtr("user2"), + UserID: tu.IDPtr("user2"), }, { - IPv4: ap("100.64.0.3"), - User: users[2], + IPv4: ap("100.64.0.3"), + User: tu.UserPtr("user3"), + UserID: tu.IDPtr("user3"), + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, }, } @@ -1911,7 +2105,7 @@ func TestNodeCanHaveTag(t *testing.T) { b, err := json.Marshal(tt.policy) require.NoError(t, err) - pm, err := NewPolicyManager(b, users, nodes) + pm, err := NewPolicyManager(b, tu.FilteredSlice("user1", "user2", "user3"), nodes) if tt.wantErr != "" { require.ErrorContains(t, err, tt.wantErr) return diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 763ab85b..31eeca37 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -495,7 +495,7 @@ func (m *mapSession) handleEndpointUpdate() { // the hostname change. m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo) - if err := m.h.db.DB.Save(m.node).Error; err != nil { + if err := m.h.db.SaveNode(m.node); err != nil { m.errf(err, "Failed to persist/update node in the database") http.Error(m.w, "", http.StatusInternalServerError) mapResponseEndpointUpdates.WithLabelValues("error").Inc() diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index c4cc8a2e..da39e695 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -196,3 +196,14 @@ type RegisterNode struct { Node Node Registered chan *Node } + +// TODO(kradalby): Not sure if this is a good idea, +// but ran into this problem in HandleNodeFromAuthPath +// describing what has changed in the node... +// ChangeSet described changes that has happend to a node +type ChangeSet struct { + NodeID NodeID + New bool + Tags bool + Expiry bool +} diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index da185563..0833afde 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -5,7 +5,6 @@ import ( "fmt" "net/netip" "slices" - "sort" "strconv" "strings" "time" @@ -51,6 +50,7 @@ func (id NodeID) String() string { } // Node is a Headscale client. +// A node is "owned" by either a user or a tag. type Node struct { ID NodeID `gorm:"primary_key"` @@ -76,21 +76,22 @@ type Node struct { // GivenName is the name used in all DNS related // parts of headscale. GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string - // ForcedTags are tags set by CLI/API. It is not considered - // the source of truth, but is one of the sources from - // which a tag might originate. - // ForcedTags are _always_ applied to the node. - ForcedTags []string `gorm:"column:forced_tags;serializer:json"` + // UserID defines the user that owns the node. + // It is a foreign key to the User table. + // It is not set if the node is not owned by a user and is consider a tagged node. + UserID *uint `sql:"DEFAULT:NULL"` + User *User + + // Tags is a list of tags associated with the node. + // If not non-empty, the node is tagged. + // For historic reason, if the node is owned by a user and the tags + // are defined, then the node is considered a tagged node and the + // user is ignored. + Tags []string `gorm:"column:tags;serializer:json"` - // When a node has been created with a PreAuthKey, we need to - // prevent the preauthkey from being deleted before the node. - // The preauthkey can define "tags" of the node so we need it - // around. AuthKeyID *uint64 `sql:"DEFAULT:NULL"` AuthKey *PreAuthKey @@ -162,27 +163,39 @@ func (node *Node) HasIP(i netip.Addr) bool { return false } -// IsTagged reports if a device is tagged -// and therefore should not be treated as a -// user owned device. -// Currently, this function only handles tags set -// via CLI ("forced tags" and preauthkeys) -func (node *Node) IsTagged() bool { - if len(node.ForcedTags) > 0 { - return true - } - - if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 { - return true - } - - if node.Hostinfo == nil { +// IsUserOwned reports if a node is owned by a user. +func (node *Node) IsUserOwned() bool { + // For historic reason, if the node is owned by a user and the tags + // are defined, then the node is considered a tagged node and the + // user is ignored. + if node.IsTagged() { return false } - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. + if node.UserID == nil { + return false + } + + if node.User == nil { + return false + } + + return true +} + +var ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node") + +// IsTagged reports if a device is tagged +// and therefore should not be treated as a +// user owned device. +func (node *Node) IsTagged() bool { + if node.Tags == nil { + return false + } + + if len(node.Tags) > 0 { + return true + } return false } @@ -191,31 +204,12 @@ func (node *Node) IsTagged() bool { // Currently, this function only handles tags set // via CLI ("forced tags" and preauthkeys) func (node *Node) HasTag(tag string) bool { - return slices.Contains(node.Tags(), tag) -} - -func (node *Node) Tags() []string { - var tags []string - - if node.AuthKey != nil { - tags = append(tags, node.AuthKey.Tags...) - } - - // TODO(kradalby): Figure out how tagging should work - // and hostinfo.requestedtags. - // Do this in other work. - // #2417 - - tags = append(tags, node.ForcedTags...) - sort.Strings(tags) - tags = slices.Compact(tags) - - return tags + return slices.Contains(node.Tags, tag) } func (node *Node) RequestTags() []string { if node.Hostinfo == nil { - return []string{} + return nil } return node.Hostinfo.RequestTags @@ -341,7 +335,7 @@ func (node *Node) Proto() *v1.Node { Name: node.Hostname, GivenName: node.GivenName, User: node.User.Proto(), - ForcedTags: node.ForcedTags, + ForcedTags: node.Tags, // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has // to be populated manually with PrimaryRoute, to ensure it includes the @@ -573,8 +567,10 @@ func (nodes Nodes) DebugString() string { func (node Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID) - fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) - fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags()) + if node.IsUserOwned() { + fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } + fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags) fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs()) fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index c7261587..12c3bdbb 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -2,11 +2,12 @@ package types import ( "fmt" - "github.com/juanfont/headscale/hscontrol/policy/matcher" "net/netip" "strings" "testing" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -139,7 +140,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig-with-username", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -150,7 +151,7 @@ func TestNodeFQDN(t *testing.T) { name: "all-set", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, @@ -160,7 +161,7 @@ func TestNodeFQDN(t *testing.T) { { name: "no-given-name", node: Node{ - User: User{ + User: &User{ Name: "user", }, }, @@ -179,7 +180,7 @@ func TestNodeFQDN(t *testing.T) { name: "no-dnsconfig", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 3e4441dd..3c252d74 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -9,19 +9,30 @@ import ( // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { - ID uint64 `gorm:"primary_key"` - Key string - UserID uint - User User `gorm:"constraint:OnDelete:SET NULL;"` + ID uint64 `gorm:"primary_key"` + + // Old Key, for backwards compatibility + Key string + + // Encrypted key + Prefix string + Hash []byte + Reusable bool Ephemeral bool `gorm:"default:false"` Used bool `gorm:"default:false"` + // UserID if set, is the owner of the key. + // If a node is authenticated with this key, the node + // is assigned to this user. + UserID *uint `sql:"DEFAULT:NULL"` + User *User + // Tags are always applied to the node and is one of // the sources of tags a node might have. They are copied // from the PreAuthKey when the node logs in the first time, // and ignored after. - Tags []string `gorm:"serializer:json"` + Tags []string `gorm:"column:tags;serializer:json"` CreatedAt *time.Time Expiration *time.Time @@ -48,3 +59,16 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { return &protoKey } + +// IsTagged reports if a key is tagged. +func (key *PreAuthKey) IsTagged() bool { + if key.Tags == nil { + return false + } + + if len(key.Tags) > 0 { + return true + } + + return false +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 6cd2c41a..c3982cbe 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -18,6 +18,16 @@ import ( "tailscale.com/tailcfg" ) +// TaggedDevices is a special user that is used to +// populate the tagged devices in the Tailscale MapResponse. +var TaggedDevices = User{ + // This ID is arbitrarily chosen, it is naively high to avoid + // and conflicts with other IDs. + Model: gorm.Model{ID: 2147455555}, + Name: "tagged-devices", + DisplayName: "Tagged Devices", +} + type UserID uint64 type Users []User @@ -273,7 +283,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, part) } } - + if len(cleanParts) == 0 { u.Path = "" } else {