1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-07 20:04:00 +01:00

policy/v2: compile localpart rules into per-user SSHRules

Implement compilation of localpart:*@<domain> entries into per-user
tailcfg.SSHRule instances. Since tailcfg.SSHRule.SSHUsers is a static
map[string]string, localpart must be resolved at compile time into
individual rules where each matching user gets their own rule with
their email local-part in the SSHUsers map and only their node IPs
as principals.

- Add resolveLocalpartRules helper that generates per-user SSH rules
- Modify compileSSHPolicy to separate common userMap from localpart
  rules and emit both through autogroup:self and other destination paths
- Handle autogroup:self filtering for localpart rules
- Skip users with empty or invalid emails
- Add compilation unit tests and end-to-end policy manager tests

Updates #3049
This commit is contained in:
Kristoffer Dalby 2026-02-18 09:47:04 +00:00
parent 537f3b1fa1
commit a71bbd8d79
3 changed files with 593 additions and 29 deletions

View File

@ -1487,6 +1487,146 @@ func TestSSHPolicyRules(t *testing.T) {
}
}
func TestSSHPolicyRules_Localpart(t *testing.T) {
users := []types.User{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
}
nodeAlice := types.Node{
Hostname: "alice-device",
IPv4: ap("100.64.0.1"),
UserID: new(uint(1)),
User: new(users[0]),
}
nodeBob := types.Node{
Hostname: "bob-device",
IPv4: ap("100.64.0.2"),
UserID: new(uint(2)),
User: new(users[1]),
}
nodeTaggedServer := types.Node{
Hostname: "tagged-server",
IPv4: ap("100.64.0.5"),
UserID: new(uint(1)),
User: new(users[0]),
Tags: []string{"tag:server"},
}
tests := []struct {
name string
targetNode types.Node
peers types.Nodes
policy string
validate func(t *testing.T, got *tailcfg.SSHPolicy)
}{
{
name: "localpart-maps-email-to-os-user",
targetNode: nodeTaggedServer,
peers: types.Nodes{&nodeAlice, &nodeBob},
policy: `{
"tagOwners": {
"tag:server": ["alice@example.com"]
},
"ssh": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": ["localpart:*@example.com"]
}
]
}`,
validate: func(t *testing.T, got *tailcfg.SSHPolicy) {
t.Helper()
require.NotNil(t, got)
require.Len(t, got.Rules, 2, "Should have per-user rules for alice and bob")
foundAlice := false
foundBob := false
for _, rule := range got.Rules {
if _, ok := rule.SSHUsers["alice"]; ok {
foundAlice = true
assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers)
require.Len(t, rule.Principals, 1)
assert.Equal(t, "100.64.0.1", rule.Principals[0].NodeIP)
assert.True(t, rule.Action.Accept)
}
if _, ok := rule.SSHUsers["bob"]; ok {
foundBob = true
assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers)
require.Len(t, rule.Principals, 1)
assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP)
}
}
assert.True(t, foundAlice, "Should have alice's localpart rule")
assert.True(t, foundBob, "Should have bob's localpart rule")
},
},
{
name: "localpart-combined-with-root",
targetNode: nodeTaggedServer,
peers: types.Nodes{&nodeAlice},
policy: `{
"tagOwners": {
"tag:server": ["alice@example.com"]
},
"ssh": [
{
"action": "accept",
"src": ["autogroup:member"],
"dst": ["tag:server"],
"users": ["localpart:*@example.com", "root"]
}
]
}`,
validate: func(t *testing.T, got *tailcfg.SSHPolicy) {
t.Helper()
require.NotNil(t, got)
// 1 common rule for root + per-user localpart rules
require.GreaterOrEqual(t, len(got.Rules), 2,
"Should have common root rule and per-user localpart rules")
foundRoot := false
foundAliceWithRoot := false
for _, rule := range got.Rules {
if v, ok := rule.SSHUsers["root"]; ok && v == "root" {
foundRoot = true
}
if _, hasAlice := rule.SSHUsers["alice"]; hasAlice {
if rule.SSHUsers["root"] == "root" {
foundAliceWithRoot = true
}
}
}
assert.True(t, foundRoot, "Should have root mapping")
assert.True(t, foundAliceWithRoot, "Alice's localpart rule should merge root")
},
},
}
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
pm, err := pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
require.NoError(t, err)
got, err := pm.SSHPolicy(tt.targetNode.View())
require.NoError(t, err)
tt.validate(t, got)
})
}
}
}
func TestReduceRoutes(t *testing.T) {
type args struct {
node *types.Node

View File

@ -3,6 +3,7 @@ package v2
import (
"errors"
"fmt"
"maps"
"net/netip"
"slices"
"strconv"
@ -384,21 +385,40 @@ func (pol *Policy) compileSSHPolicy(
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
userMap := make(map[string]string, len(rule.Users))
// Build the "common" userMap for non-localpart entries (root, autogroup:nonroot, specific users).
const rootUser = "root"
commonUserMap := make(map[string]string, len(rule.Users))
if rule.Users.ContainsNonRoot() {
userMap["*"] = "="
commonUserMap["*"] = "="
// by default, we do not allow root unless explicitly stated
userMap["root"] = ""
commonUserMap[rootUser] = ""
}
if rule.Users.ContainsRoot() {
userMap["root"] = "root"
commonUserMap[rootUser] = rootUser
}
for _, u := range rule.Users.NormalUsers() {
userMap[u.String()] = u.String()
commonUserMap[u.String()] = u.String()
}
// Resolve localpart entries into per-user rules.
// Each localpart:*@<domain> entry maps users in that domain to their email local-part.
// Because SSHUsers is a static map per rule, we need a separate rule per user
// to constrain each user to only their own local-part.
localpartRules := resolveLocalpartRules(
rule.Users.LocalpartEntries(),
users,
nodes,
srcIPs,
commonUserMap,
&action,
)
// Determine whether the common userMap has any entries worth emitting.
hasCommonUsers := len(commonUserMap) > 0
// Handle autogroup:self destinations (if any)
// Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
@ -443,19 +463,46 @@ func (pol *Policy) compileSSHPolicy(
}
if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 {
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(filteredSrcSet) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
// Emit common rule if there are non-localpart users
if hasCommonUsers {
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(filteredSrcSet) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: commonUserMap,
Action: &action,
})
}
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
// Emit per-user localpart rules, filtered to autogroup:self sources
for _, lpRule := range localpartRules {
var filteredPrincipals []*tailcfg.SSHPrincipal
for _, p := range lpRule.Principals {
addr, err := netip.ParseAddr(p.NodeIP)
if err != nil {
continue
}
if filteredSrcSet.Contains(addr) {
filteredPrincipals = append(filteredPrincipals, p)
}
}
if len(filteredPrincipals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: filteredPrincipals,
SSHUsers: lpRule.SSHUsers,
Action: lpRule.Action,
})
}
}
}
}
@ -484,21 +531,27 @@ func (pol *Policy) compileSSHPolicy(
// Only create rule if this node is in the destination set
if node.InIPSet(destSet) {
// For non-autogroup:self destinations, use all resolved sources (no filtering)
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(srcIPs) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
// Emit common rule if there are non-localpart users
if hasCommonUsers {
// For non-autogroup:self destinations, use all resolved sources (no filtering)
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(srcIPs) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: commonUserMap,
Action: &action,
})
}
}
if len(principals) > 0 {
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
// Emit per-user localpart rules
rules = append(rules, localpartRules...)
}
}
}
@ -508,6 +561,103 @@ func (pol *Policy) compileSSHPolicy(
}, nil
}
// resolveLocalpartRules generates per-user SSH rules for localpart:*@<domain> entries.
// For each localpart entry, it finds all users whose email is in the specified domain,
// extracts their email local-part, and creates a tailcfg.SSHRule scoped to that user's
// node IPs with an SSHUsers map that only allows their local-part.
// The commonUserMap entries (root, autogroup:nonroot, specific users) are merged into
// each per-user rule so that localpart rules compose with other user entries.
func resolveLocalpartRules(
localpartEntries []SSHUser,
users types.Users,
nodes views.Slice[types.NodeView],
srcIPs *netipx.IPSet,
commonUserMap map[string]string,
action *tailcfg.SSHAction,
) []*tailcfg.SSHRule {
if len(localpartEntries) == 0 {
return nil
}
var rules []*tailcfg.SSHRule
for _, entry := range localpartEntries {
domain, err := entry.ParseLocalpart()
if err != nil {
// Should not happen if validation passed, but skip gracefully.
log.Warn().Err(err).Msgf("skipping invalid localpart entry %q during SSH compilation", entry)
continue
}
// Find users whose email matches *@<domain> and build per-user rules.
for _, user := range users {
if user.Email == "" {
continue
}
atIdx := strings.LastIndex(user.Email, "@")
if atIdx < 0 {
continue
}
emailDomain := user.Email[atIdx+1:]
if !strings.EqualFold(emailDomain, domain) {
continue
}
localPart := user.Email[:atIdx]
// Find this user's non-tagged nodes that are in the source IP set.
var userSrcIPs netipx.IPSetBuilder
for _, n := range nodes.All() {
if n.IsTagged() {
continue
}
if !n.User().Valid() || n.User().ID() != user.ID {
continue
}
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
n.AppendToIPSet(&userSrcIPs)
}
}
userSrcSet, err := userSrcIPs.IPSet()
if err != nil || userSrcSet == nil || len(userSrcSet.Prefixes()) == 0 {
continue
}
var principals []*tailcfg.SSHPrincipal
for addr := range util.IPSetAddrIter(userSrcSet) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: addr.String(),
})
}
if len(principals) == 0 {
continue
}
// Build per-user SSHUsers map: start with the common entries, then add the localpart.
userMap := make(map[string]string, len(commonUserMap)+1)
maps.Copy(userMap, commonUserMap)
userMap[localPart] = localPart
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: action,
})
}
}
return rules
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string

View File

@ -647,6 +647,280 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
}
}
func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) {
users := types.Users{
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
{Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}},
{Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email
}
nodeTaggedServer := types.Node{
Hostname: "tagged-server",
IPv4: createAddr("100.64.0.1"),
UserID: new(users[0].ID),
User: new(users[0]),
Tags: []string{"tag:server"},
}
nodeAlice := types.Node{
Hostname: "alice-device",
IPv4: createAddr("100.64.0.2"),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodeBob := types.Node{
Hostname: "bob-device",
IPv4: createAddr("100.64.0.3"),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodeCharlie := types.Node{
Hostname: "charlie-device",
IPv4: createAddr("100.64.0.4"),
UserID: new(users[2].ID),
User: new(users[2]),
}
nodeDave := types.Node{
Hostname: "dave-device",
IPv4: createAddr("100.64.0.5"),
UserID: new(users[3].ID),
User: new(users[3]),
}
nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave}
t.Run("localpart only", func(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
// Should get 2 rules: one for alice, one for bob (both @example.com)
// charlie@other.com and dave (no email) should not match
require.Len(t, sshPolicy.Rules, 2, "Should have per-user rules for alice and bob")
// Collect all rules' SSHUsers and principals
foundAlice := false
foundBob := false
for _, rule := range sshPolicy.Rules {
if _, ok := rule.SSHUsers["alice"]; ok {
foundAlice = true
assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers)
require.Len(t, rule.Principals, 1)
assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP)
}
if _, ok := rule.SSHUsers["bob"]; ok {
foundBob = true
assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers)
require.Len(t, rule.Principals, 1)
assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP)
}
}
assert.True(t, foundAlice, "Should have a rule for alice's localpart")
assert.True(t, foundBob, "Should have a rule for bob's localpart")
})
t.Run("localpart with root", func(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
// Should get 3 rules: common rule for root + per-user for alice + per-user for bob
require.Len(t, sshPolicy.Rules, 3, "Should have common root rule + per-user localpart rules")
// Find the common rule (has root but no per-user localpart)
foundCommon := false
for _, rule := range sshPolicy.Rules {
if v, ok := rule.SSHUsers["root"]; ok && v == "root" && len(rule.SSHUsers) == 1 {
foundCommon = true
}
}
assert.True(t, foundCommon, "Should have a common rule with root mapping")
// Per-user rules should also include root
for _, rule := range sshPolicy.Rules {
if _, hasAlice := rule.SSHUsers["alice"]; hasAlice {
assert.Equal(t, "root", rule.SSHUsers["root"], "Alice's localpart rule should merge root")
}
if _, hasBob := rule.SSHUsers["bob"]; hasBob {
assert.Equal(t, "root", rule.SSHUsers["root"], "Bob's localpart rule should merge root")
}
}
})
t.Run("localpart no matching users in domain", func(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
assert.Empty(t, sshPolicy.Rules, "Should have no rules when no users match domain")
})
t.Run("localpart with special chars in email", func(t *testing.T) {
specialUsers := types.Users{
{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}},
}
nodeSpecial := types.Node{
Hostname: "special-device",
IPv4: createAddr("100.64.0.10"),
UserID: new(specialUsers[0].ID),
User: new(specialUsers[0]),
}
specialNodes := types.Nodes{&nodeTaggedServer, &nodeSpecial}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("dave+sshuser@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(specialUsers, nodeTaggedServer.View(), specialNodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
rule := sshPolicy.Rules[0]
// Per Tailscale docs: "if the login is dave+sshuser@example.com,
// Tailscale will map this to the SSH user dave+sshuser"
assert.Equal(t, map[string]string{"dave+sshuser": "dave+sshuser"}, rule.SSHUsers)
})
t.Run("localpart excludes CLI users without email", func(t *testing.T) {
// dave has no email, should be excluded from localpart matching
cliOnlyUsers := types.Users{
{Name: "dave", Model: gorm.Model{ID: 4}},
}
nodeDaveCli := types.Node{
Hostname: "dave-cli-device",
IPv4: createAddr("100.64.0.5"),
UserID: new(cliOnlyUsers[0].ID),
User: new(cliOnlyUsers[0]),
}
cliNodes := types.Nodes{&nodeTaggedServer, &nodeDaveCli}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("dave@")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{SSHUser("localpart:*@example.com")},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(cliOnlyUsers, nodeTaggedServer.View(), cliNodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
assert.Empty(t, sshPolicy.Rules, "CLI users without email should not match localpart rules")
})
t.Run("localpart with multiple domains", func(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{up("alice@example.com")},
},
SSHs: []SSH{
{
Action: "accept",
Sources: SSHSrcAliases{agp("autogroup:member")},
Destinations: SSHDstAliases{tp("tag:server")},
Users: []SSHUser{
SSHUser("localpart:*@example.com"),
SSHUser("localpart:*@other.com"),
},
},
},
}
require.NoError(t, policy.validate())
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
// alice@example.com, bob@example.com, charlie@other.com should all match
require.Len(t, sshPolicy.Rules, 3, "Should have rules for alice, bob, and charlie")
localparts := make(map[string]bool)
for _, rule := range sshPolicy.Rules {
for k := range rule.SSHUsers {
localparts[k] = true
}
}
assert.True(t, localparts["alice"], "Should have alice's localpart rule")
assert.True(t, localparts["bob"], "Should have bob's localpart rule")
assert.True(t, localparts["charlie"], "Should have charlie's localpart rule")
})
}
func TestCompileSSHPolicy_CheckAction(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},