diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..d5aa84d3 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,6 +253,7 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf + - TestSSHLocalpart - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9c97e39c..af8e6fbd 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1487,6 +1487,146 @@ func TestSSHPolicyRules(t *testing.T) { } } +func TestSSHPolicyRules_Localpart(t *testing.T) { + users := []types.User{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + {Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}}, + } + + nodeAlice := types.Node{ + Hostname: "alice-device", + IPv4: ap("100.64.0.1"), + UserID: new(uint(1)), + User: new(users[0]), + } + nodeBob := types.Node{ + Hostname: "bob-device", + IPv4: ap("100.64.0.2"), + UserID: new(uint(2)), + User: new(users[1]), + } + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: ap("100.64.0.5"), + UserID: new(uint(1)), + User: new(users[0]), + Tags: []string{"tag:server"}, + } + + tests := []struct { + name string + targetNode types.Node + peers types.Nodes + policy string + validate func(t *testing.T, got *tailcfg.SSHPolicy) + }{ + { + name: "localpart-maps-email-to-os-user", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeAlice, &nodeBob}, + policy: `{ + "tagOwners": { + "tag:server": ["alice@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:server"], + "users": ["localpart:*@example.com"] + } + ] + }`, + validate: func(t *testing.T, got *tailcfg.SSHPolicy) { + t.Helper() + require.NotNil(t, got) + require.Len(t, got.Rules, 2, "Should have per-user rules for alice and bob") + + foundAlice := false + foundBob := false + + for _, rule := range got.Rules { + if _, ok := rule.SSHUsers["alice"]; ok { + foundAlice = true + + assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers) + require.Len(t, rule.Principals, 1) + assert.Equal(t, "100.64.0.1", rule.Principals[0].NodeIP) + assert.True(t, rule.Action.Accept) + } + + if _, ok := rule.SSHUsers["bob"]; ok { + foundBob = true + + assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers) + require.Len(t, rule.Principals, 1) + assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP) + } + } + + assert.True(t, foundAlice, "Should have alice's localpart rule") + assert.True(t, foundBob, "Should have bob's localpart rule") + }, + }, + { + name: "localpart-combined-with-root", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeAlice}, + policy: `{ + "tagOwners": { + "tag:server": ["alice@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:server"], + "users": ["localpart:*@example.com", "root"] + } + ] + }`, + validate: func(t *testing.T, got *tailcfg.SSHPolicy) { + t.Helper() + require.NotNil(t, got) + // 1 common rule for root + per-user localpart rules + require.GreaterOrEqual(t, len(got.Rules), 2, + "Should have common root rule and per-user localpart rules") + + foundRoot := false + foundAliceWithRoot := false + + for _, rule := range got.Rules { + if v, ok := rule.SSHUsers["root"]; ok && v == "root" { + foundRoot = true + } + + if _, hasAlice := rule.SSHUsers["alice"]; hasAlice { + if rule.SSHUsers["root"] == "root" { + foundAliceWithRoot = true + } + } + } + + assert.True(t, foundRoot, "Should have root mapping") + assert.True(t, foundAliceWithRoot, "Alice's localpart rule should merge root") + }, + }, + } + + for _, tt := range tests { + for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + pm, err := pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) + require.NoError(t, err) + + got, err := pm.SSHPolicy(tt.targetNode.View()) + require.NoError(t, err) + tt.validate(t, got) + }) + } + } +} + func TestReduceRoutes(t *testing.T) { type args struct { node *types.Node diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9c2c5f17..08646d0c 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -3,6 +3,7 @@ package v2 import ( "errors" "fmt" + "maps" "net/netip" "slices" "strconv" @@ -384,21 +385,40 @@ func (pol *Policy) compileSSHPolicy( return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } - userMap := make(map[string]string, len(rule.Users)) + // Build the "common" userMap for non-localpart entries (root, autogroup:nonroot, specific users). + const rootUser = "root" + + commonUserMap := make(map[string]string, len(rule.Users)) if rule.Users.ContainsNonRoot() { - userMap["*"] = "=" + commonUserMap["*"] = "=" // by default, we do not allow root unless explicitly stated - userMap["root"] = "" + commonUserMap[rootUser] = "" } if rule.Users.ContainsRoot() { - userMap["root"] = "root" + commonUserMap[rootUser] = rootUser } for _, u := range rule.Users.NormalUsers() { - userMap[u.String()] = u.String() + commonUserMap[u.String()] = u.String() } + // Resolve localpart entries into per-user rules. + // Each localpart:*@ entry maps users in that domain to their email local-part. + // Because SSHUsers is a static map per rule, we need a separate rule per user + // to constrain each user to only their own local-part. + localpartRules := resolveLocalpartRules( + rule.Users.LocalpartEntries(), + users, + nodes, + srcIPs, + commonUserMap, + &action, + ) + + // Determine whether the common userMap has any entries worth emitting. + hasCommonUsers := len(commonUserMap) > 0 + // Handle autogroup:self destinations (if any) // Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes if len(autogroupSelfDests) > 0 && !node.IsTagged() { @@ -443,19 +463,46 @@ func (pol *Policy) compileSSHPolicy( } if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 { - var principals []*tailcfg.SSHPrincipal - for addr := range util.IPSetAddrIter(filteredSrcSet) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) + // Emit common rule if there are non-localpart users + if hasCommonUsers { + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(filteredSrcSet) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: commonUserMap, + Action: &action, + }) + } } - if len(principals) > 0 { - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) + // Emit per-user localpart rules, filtered to autogroup:self sources + for _, lpRule := range localpartRules { + var filteredPrincipals []*tailcfg.SSHPrincipal + + for _, p := range lpRule.Principals { + addr, err := netip.ParseAddr(p.NodeIP) + if err != nil { + continue + } + + if filteredSrcSet.Contains(addr) { + filteredPrincipals = append(filteredPrincipals, p) + } + } + + if len(filteredPrincipals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: filteredPrincipals, + SSHUsers: lpRule.SSHUsers, + Action: lpRule.Action, + }) + } } } } @@ -484,21 +531,27 @@ func (pol *Policy) compileSSHPolicy( // Only create rule if this node is in the destination set if node.InIPSet(destSet) { - // For non-autogroup:self destinations, use all resolved sources (no filtering) - var principals []*tailcfg.SSHPrincipal - for addr := range util.IPSetAddrIter(srcIPs) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) + // Emit common rule if there are non-localpart users + if hasCommonUsers { + // For non-autogroup:self destinations, use all resolved sources (no filtering) + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(srcIPs) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: commonUserMap, + Action: &action, + }) + } } - if len(principals) > 0 { - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) - } + // Emit per-user localpart rules + rules = append(rules, localpartRules...) } } } @@ -508,6 +561,103 @@ func (pol *Policy) compileSSHPolicy( }, nil } +// resolveLocalpartRules generates per-user SSH rules for localpart:*@ entries. +// For each localpart entry, it finds all users whose email is in the specified domain, +// extracts their email local-part, and creates a tailcfg.SSHRule scoped to that user's +// node IPs with an SSHUsers map that only allows their local-part. +// The commonUserMap entries (root, autogroup:nonroot, specific users) are merged into +// each per-user rule so that localpart rules compose with other user entries. +func resolveLocalpartRules( + localpartEntries []SSHUser, + users types.Users, + nodes views.Slice[types.NodeView], + srcIPs *netipx.IPSet, + commonUserMap map[string]string, + action *tailcfg.SSHAction, +) []*tailcfg.SSHRule { + if len(localpartEntries) == 0 { + return nil + } + + var rules []*tailcfg.SSHRule + + for _, entry := range localpartEntries { + domain, err := entry.ParseLocalpart() + if err != nil { + // Should not happen if validation passed, but skip gracefully. + log.Warn().Err(err).Msgf("skipping invalid localpart entry %q during SSH compilation", entry) + + continue + } + + // Find users whose email matches *@ and build per-user rules. + for _, user := range users { + if user.Email == "" { + continue + } + + atIdx := strings.LastIndex(user.Email, "@") + if atIdx < 0 { + continue + } + + emailDomain := user.Email[atIdx+1:] + if !strings.EqualFold(emailDomain, domain) { + continue + } + + localPart := user.Email[:atIdx] + + // Find this user's non-tagged nodes that are in the source IP set. + var userSrcIPs netipx.IPSetBuilder + + for _, n := range nodes.All() { + if n.IsTagged() { + continue + } + + if !n.User().Valid() || n.User().ID() != user.ID { + continue + } + + if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { + n.AppendToIPSet(&userSrcIPs) + } + } + + userSrcSet, err := userSrcIPs.IPSet() + if err != nil || userSrcSet == nil || len(userSrcSet.Prefixes()) == 0 { + continue + } + + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(userSrcSet) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) == 0 { + continue + } + + // Build per-user SSHUsers map: start with the common entries, then add the localpart. + userMap := make(map[string]string, len(commonUserMap)+1) + maps.Copy(userMap, commonUserMap) + + userMap[localPart] = localPart + + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: action, + }) + } + } + + return rules +} + func ipSetToPrefixStringList(ips *netipx.IPSet) []string { var out []string diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index cdf7c131..519d8b5c 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -647,6 +647,280 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { } } +func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) { + users := types.Users{ + {Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}}, + {Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}}, + {Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}}, + {Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email + } + + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.1"), + UserID: new(users[0].ID), + User: new(users[0]), + Tags: []string{"tag:server"}, + } + nodeAlice := types.Node{ + Hostname: "alice-device", + IPv4: createAddr("100.64.0.2"), + UserID: new(users[0].ID), + User: new(users[0]), + } + nodeBob := types.Node{ + Hostname: "bob-device", + IPv4: createAddr("100.64.0.3"), + UserID: new(users[1].ID), + User: new(users[1]), + } + nodeCharlie := types.Node{ + Hostname: "charlie-device", + IPv4: createAddr("100.64.0.4"), + UserID: new(users[2].ID), + User: new(users[2]), + } + nodeDave := types.Node{ + Hostname: "dave-device", + IPv4: createAddr("100.64.0.5"), + UserID: new(users[3].ID), + User: new(users[3]), + } + + nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave} + + t.Run("localpart only", func(t *testing.T) { + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + + // Should get 2 rules: one for alice, one for bob (both @example.com) + // charlie@other.com and dave (no email) should not match + require.Len(t, sshPolicy.Rules, 2, "Should have per-user rules for alice and bob") + + // Collect all rules' SSHUsers and principals + foundAlice := false + foundBob := false + + for _, rule := range sshPolicy.Rules { + if _, ok := rule.SSHUsers["alice"]; ok { + foundAlice = true + + assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers) + require.Len(t, rule.Principals, 1) + assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP) + } + + if _, ok := rule.SSHUsers["bob"]; ok { + foundBob = true + + assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers) + require.Len(t, rule.Principals, 1) + assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP) + } + } + + assert.True(t, foundAlice, "Should have a rule for alice's localpart") + assert.True(t, foundBob, "Should have a rule for bob's localpart") + }) + + t.Run("localpart with root", func(t *testing.T) { + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"}, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + + // Should get 3 rules: common rule for root + per-user for alice + per-user for bob + require.Len(t, sshPolicy.Rules, 3, "Should have common root rule + per-user localpart rules") + + // Find the common rule (has root but no per-user localpart) + foundCommon := false + + for _, rule := range sshPolicy.Rules { + if v, ok := rule.SSHUsers["root"]; ok && v == "root" && len(rule.SSHUsers) == 1 { + foundCommon = true + } + } + + assert.True(t, foundCommon, "Should have a common rule with root mapping") + + // Per-user rules should also include root + for _, rule := range sshPolicy.Rules { + if _, hasAlice := rule.SSHUsers["alice"]; hasAlice { + assert.Equal(t, "root", rule.SSHUsers["root"], "Alice's localpart rule should merge root") + } + + if _, hasBob := rule.SSHUsers["bob"]; hasBob { + assert.Equal(t, "root", rule.SSHUsers["root"], "Bob's localpart rule should merge root") + } + } + }) + + t.Run("localpart no matching users in domain", func(t *testing.T) { + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")}, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + assert.Empty(t, sshPolicy.Rules, "Should have no rules when no users match domain") + }) + + t.Run("localpart with special chars in email", func(t *testing.T) { + specialUsers := types.Users{ + {Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}}, + } + nodeSpecial := types.Node{ + Hostname: "special-device", + IPv4: createAddr("100.64.0.10"), + UserID: new(specialUsers[0].ID), + User: new(specialUsers[0]), + } + + specialNodes := types.Nodes{&nodeTaggedServer, &nodeSpecial} + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("dave+sshuser@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(specialUsers, nodeTaggedServer.View(), specialNodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + // Per Tailscale docs: "if the login is dave+sshuser@example.com, + // Tailscale will map this to the SSH user dave+sshuser" + assert.Equal(t, map[string]string{"dave+sshuser": "dave+sshuser"}, rule.SSHUsers) + }) + + t.Run("localpart excludes CLI users without email", func(t *testing.T) { + // dave has no email, should be excluded from localpart matching + cliOnlyUsers := types.Users{ + {Name: "dave", Model: gorm.Model{ID: 4}}, + } + nodeDaveCli := types.Node{ + Hostname: "dave-cli-device", + IPv4: createAddr("100.64.0.5"), + UserID: new(cliOnlyUsers[0].ID), + User: new(cliOnlyUsers[0]), + } + + cliNodes := types.Nodes{&nodeTaggedServer, &nodeDaveCli} + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("dave@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(cliOnlyUsers, nodeTaggedServer.View(), cliNodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + assert.Empty(t, sshPolicy.Rules, "CLI users without email should not match localpart rules") + }) + + t.Run("localpart with multiple domains", func(t *testing.T) { + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("alice@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{ + SSHUser("localpart:*@example.com"), + SSHUser("localpart:*@other.com"), + }, + }, + }, + } + require.NoError(t, policy.validate()) + + sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + + // alice@example.com, bob@example.com, charlie@other.com should all match + require.Len(t, sshPolicy.Rules, 3, "Should have rules for alice, bob, and charlie") + + localparts := make(map[string]bool) + + for _, rule := range sshPolicy.Rules { + for k := range rule.SSHUsers { + localparts[k] = true + } + } + + assert.True(t, localparts["alice"], "Should have alice's localpart rule") + assert.True(t, localparts["bob"], "Should have bob's localpart rule") + assert.True(t, localparts["charlie"], "Should have charlie's localpart rule") + }) +} + func TestCompileSSHPolicy_CheckAction(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 8785bed0..b876dd69 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -43,6 +43,7 @@ var ( ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged") ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)") ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination") + ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@") ) // ACL validation errors. @@ -1953,6 +1954,14 @@ func (p *Policy) validate() error { continue } } + + if user.IsLocalpart() { + _, err := user.ParseLocalpart() + if err != nil { + errs = append(errs, err) + continue + } + } } for _, src := range ssh.Sources { @@ -2255,6 +2264,11 @@ type SSHDstAliases []Alias type SSHUsers []SSHUser +// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries. +// Format: localpart:*@ +// See: https://tailscale.com/docs/features/tailscale-ssh#users +const SSHUserLocalpartPrefix = "localpart:" + func (u SSHUsers) ContainsRoot() bool { return slices.Contains(u, "root") } @@ -2263,9 +2277,25 @@ func (u SSHUsers) ContainsNonRoot() bool { return slices.Contains(u, SSHUser(AutoGroupNonRoot)) } +// ContainsLocalpart returns true if any entry has the localpart: prefix. +func (u SSHUsers) ContainsLocalpart() bool { + return slices.ContainsFunc(u, func(user SSHUser) bool { + return user.IsLocalpart() + }) +} + +// NormalUsers returns all SSH users that are not root, autogroup:nonroot, +// or localpart: entries. func (u SSHUsers) NormalUsers() []SSHUser { return slicesx.Filter(nil, u, func(user SSHUser) bool { - return user != "root" && user != SSHUser(AutoGroupNonRoot) + return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart() + }) +} + +// LocalpartEntries returns only the localpart: prefixed entries. +func (u SSHUsers) LocalpartEntries() []SSHUser { + return slicesx.Filter(nil, u, func(user SSHUser) bool { + return user.IsLocalpart() }) } @@ -2275,6 +2305,41 @@ func (u SSHUser) String() string { return string(u) } +// IsLocalpart returns true if the SSHUser has the localpart: prefix. +func (u SSHUser) IsLocalpart() bool { + return strings.HasPrefix(string(u), SSHUserLocalpartPrefix) +} + +// ParseLocalpart validates and extracts the domain from a localpart: entry. +// The expected format is localpart:*@. +// Returns the domain part or an error if the format is invalid. +func (u SSHUser) ParseLocalpart() (string, error) { + if !u.IsLocalpart() { + return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u) + } + + pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix) + + // Must be *@ + atIdx := strings.LastIndex(pattern, "@") + if atIdx < 0 { + return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u) + } + + localPart := pattern[:atIdx] + domain := pattern[atIdx+1:] + + if localPart != "*" { + return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u) + } + + if domain == "" { + return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u) + } + + return domain, nil +} + // MarshalJSON marshals the SSHUser to JSON. func (u SSHUser) MarshalJSON() ([]byte, error) { return json.Marshal(string(u)) diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index acea9c28..95899b17 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1766,6 +1766,105 @@ func TestUnmarshalPolicy(t *testing.T) { }, }, }, + { + name: "ssh-localpart-valid", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + }, + { + name: "ssh-localpart-with-other-users", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com", "root", "autogroup:nonroot"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-localpart-invalid-no-at-sign", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:foo"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-non-wildcard", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:alice@example.com"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-empty-domain", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@"] + }] +} +`, + wantErr: "invalid localpart format", + }, } cmps := append(util.Comparers, @@ -2576,6 +2675,154 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) { } } +func TestSSHUsers_ContainsLocalpart(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected bool + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: false, + }, + { + name: "contains localpart", + users: SSHUsers{SSHUser("localpart:*@example.com")}, + expected: true, + }, + { + name: "does not contain localpart", + users: SSHUsers{"ubuntu", "admin", "root"}, + expected: false, + }, + { + name: "contains localpart among others", + users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"}, + expected: true, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.ContainsLocalpart() + assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result") + }) + } +} + +func TestSSHUsers_LocalpartEntries(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected []SSHUser + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: []SSHUser{}, + }, + { + name: "no localpart entries", + users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)}, + expected: []SSHUser{}, + }, + { + name: "single localpart entry", + users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"}, + expected: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")}, + expected: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.LocalpartEntries() + assert.ElementsMatch(t, tt.expected, result, "LocalpartEntries() should return expected entries") + }) + } +} + +func TestSSHUsers_NormalUsers_ExcludesLocalpart(t *testing.T) { + users := SSHUsers{ + "ubuntu", + "root", + SSHUser(AutoGroupNonRoot), + SSHUser("localpart:*@example.com"), + "admin", + } + + result := users.NormalUsers() + assert.ElementsMatch(t, []SSHUser{"ubuntu", "admin"}, result, + "NormalUsers() should exclude root, autogroup:nonroot, and localpart entries") +} + +func TestSSHUser_ParseLocalpart(t *testing.T) { + tests := []struct { + name string + user SSHUser + expectedDomain string + expectErr bool + }{ + { + name: "valid localpart", + user: SSHUser("localpart:*@example.com"), + expectedDomain: "example.com", + }, + { + name: "valid localpart with subdomain", + user: SSHUser("localpart:*@corp.example.com"), + expectedDomain: "corp.example.com", + }, + { + name: "missing prefix", + user: SSHUser("ubuntu"), + expectErr: true, + }, + { + name: "missing @ sign", + user: SSHUser("localpart:foo"), + expectErr: true, + }, + { + name: "non-wildcard local part", + user: SSHUser("localpart:alice@example.com"), + expectErr: true, + }, + { + name: "empty domain", + user: SSHUser("localpart:*@"), + expectErr: true, + }, + { + name: "just prefix", + user: SSHUser("localpart:"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domain, err := tt.user.ParseLocalpart() + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedDomain, domain) + } + }) + } +} + func mustIPSet(prefixes ...string) *netipx.IPSet { var builder netipx.IPSetBuilder for _, p := range prefixes { diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 45bc2dc7..2cb468d7 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -10,6 +10,7 @@ import ( policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" @@ -20,7 +21,8 @@ func isSSHNoAccessStdError(stderr string) bool { // Since https://github.com/tailscale/tailscale/pull/14853 strings.Contains(stderr, "failed to evaluate SSH policy") || // Since https://github.com/tailscale/tailscale/pull/16127 - strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node") + // Covers both "to this node" and "as user " variants. + strings.Contains(stderr, "tailnet policy does not permit you to SSH") } func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { @@ -420,15 +422,27 @@ func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClien func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) { t.Helper() + return doSSHWithRetryAsUser(t, client, peer, "ssh-it-user", retry) +} + +func doSSHWithRetryAsUser( + t *testing.T, + client TailscaleClient, + peer TailscaleClient, + sshUser string, + retry bool, +) (string, string, error) { + t.Helper() + peerFQDN, _ := peer.FQDN() command := []string{ "/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=1", - fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN), + fmt.Sprintf("%s@%s", sshUser, peerFQDN), "'hostname'", } - log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname()) + log.Printf("Running from %s to %s as %s", client.Hostname(), peer.Hostname(), sshUser) log.Printf("Command: %s", strings.Join(command, " ")) var ( @@ -499,6 +513,31 @@ func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) { } } +func doSSHAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) (string, string, error) { + t.Helper() + + return doSSHWithRetryAsUser(t, client, peer, sshUser, true) +} + +func assertSSHHostnameAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) { + t.Helper() + + result, _, err := doSSHAsUser(t, client, peer, sshUser) + require.NoError(t, err) + + require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) +} + +func assertSSHPermissionDeniedAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) { + t.Helper() + + result, stderr, err := doSSHWithRetryAsUser(t, client, peer, sshUser, false) + + assert.Empty(t, result) + + assertSSHNoAccessStdError(t, err, stderr) +} + // TestSSHAutogroupSelf tests that SSH with autogroup:self works correctly: // - Users can SSH to their own devices // - Users cannot SSH to other users' devices. @@ -579,3 +618,234 @@ func TestSSHAutogroupSelf(t *testing.T) { } } } + +// TestSSHLocalpart tests that SSH with localpart:*@ works correctly. +// localpart maps the local-part of each user's OIDC email to an OS user, +// so user1@headscale.net can SSH as local user "user1". +// This requires OIDC login so that users have real email addresses. +func TestSSHLocalpart(t *testing.T) { + IntegrationSkip(t) + + baseACLs := []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + } + + tests := []struct { + name string + policy *policyv2.Policy + testFn func(t *testing.T, scenario *Scenario) + }{ + { + name: "MemberAndTagged", + policy: &policyv2.Policy{ + ACLs: baseACLs, + SSHs: []policyv2.SSH{ + { + Action: "accept", + Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{"localpart:*@headscale.net"}, + }, + }, + }, + testFn: func(t *testing.T, scenario *Scenario) { + t.Helper() + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + // user1 can SSH to user2's nodes as "user1" (localpart of user1@headscale.net) + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHHostnameAsUser(t, client, peer, "user1") + } + } + + // user2 can SSH to user1's nodes as "user2" (localpart of user2@headscale.net) + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHHostnameAsUser(t, client, peer, "user2") + } + } + + // user1 CANNOT SSH as "user2" — no rule maps user1's IPs to user2 + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHPermissionDeniedAsUser(t, client, peer, "user2") + } + } + + // user2 CANNOT SSH as "user1" — no rule maps user2's IPs to user1 + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHPermissionDeniedAsUser(t, client, peer, "user1") + } + } + }, + }, + { + name: "AutogroupSelf", + policy: &policyv2.Policy{ + ACLs: baseACLs, + SSHs: []policyv2.SSH{ + { + Action: "accept", + Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)}, + Destinations: policyv2.SSHDstAliases{new(policyv2.AutoGroupSelf)}, + Users: []policyv2.SSHUser{"localpart:*@headscale.net"}, + }, + }, + }, + testFn: func(t *testing.T, scenario *Scenario) { + t.Helper() + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + // With autogroup:self, cross-user SSH should be denied regardless of localpart. + // user1 cannot SSH to user2's nodes as "user1" + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHPermissionDeniedAsUser(t, client, peer, "user1") + } + } + + // user2 cannot SSH to user1's nodes as "user2" + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHPermissionDeniedAsUser(t, client, peer, "user2") + } + } + + // user1 also cannot SSH to user2's nodes as "user2" + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHPermissionDeniedAsUser(t, client, peer, "user2") + } + } + }, + }, + { + name: "LocalpartPlusRoot", + policy: &policyv2.Policy{ + ACLs: baseACLs, + SSHs: []policyv2.SSH{ + { + Action: "accept", + Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{ + "localpart:*@headscale.net", + "root", + }, + }, + }, + }, + testFn: func(t *testing.T, scenario *Scenario) { + t.Helper() + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + // localpart works: user1 can SSH to user2's nodes as "user1" + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHHostnameAsUser(t, client, peer, "user1") + } + } + + // root also works: user1 can SSH to user2's nodes as "root" + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHHostnameAsUser(t, client, peer, "root") + } + } + + // user2 can SSH as "user2" (localpart) + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHHostnameAsUser(t, client, peer, "user2") + } + } + + // user2 can SSH as "root" + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHHostnameAsUser(t, client, peer, "root") + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", true), + }, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser user1", "adduser user2"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithTestName("sshlocalpart"), + hsic.WithACLPolicy(tt.policy), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + ) + requireNoErrHeadscaleEnv(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + tt.testFn(t, scenario) + }) + } +}