mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-07 20:04:00 +01:00
Merge b235da08be into 13ebea192c
This commit is contained in:
commit
841cf72cc0
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
@ -253,6 +253,7 @@ jobs:
|
||||
- TestSSHIsBlockedInACL
|
||||
- TestSSHUserOnlyIsolation
|
||||
- TestSSHAutogroupSelf
|
||||
- TestSSHLocalpart
|
||||
- TestTagsAuthKeyWithTagRequestDifferentTag
|
||||
- TestTagsAuthKeyWithTagNoAdvertiseFlag
|
||||
- TestTagsAuthKeyWithTagCannotAddViaCLI
|
||||
|
||||
@ -1487,6 +1487,146 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHPolicyRules_Localpart(t *testing.T) {
|
||||
users := []types.User{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
|
||||
}
|
||||
|
||||
nodeAlice := types.Node{
|
||||
Hostname: "alice-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeBob := types.Node{
|
||||
Hostname: "bob-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
}
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: ap("100.64.0.5"),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
targetNode types.Node
|
||||
peers types.Nodes
|
||||
policy string
|
||||
validate func(t *testing.T, got *tailcfg.SSHPolicy)
|
||||
}{
|
||||
{
|
||||
name: "localpart-maps-email-to-os-user",
|
||||
targetNode: nodeTaggedServer,
|
||||
peers: types.Nodes{&nodeAlice, &nodeBob},
|
||||
policy: `{
|
||||
"tagOwners": {
|
||||
"tag:server": ["alice@example.com"]
|
||||
},
|
||||
"ssh": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:server"],
|
||||
"users": ["localpart:*@example.com"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
validate: func(t *testing.T, got *tailcfg.SSHPolicy) {
|
||||
t.Helper()
|
||||
require.NotNil(t, got)
|
||||
require.Len(t, got.Rules, 2, "Should have per-user rules for alice and bob")
|
||||
|
||||
foundAlice := false
|
||||
foundBob := false
|
||||
|
||||
for _, rule := range got.Rules {
|
||||
if _, ok := rule.SSHUsers["alice"]; ok {
|
||||
foundAlice = true
|
||||
|
||||
assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers)
|
||||
require.Len(t, rule.Principals, 1)
|
||||
assert.Equal(t, "100.64.0.1", rule.Principals[0].NodeIP)
|
||||
assert.True(t, rule.Action.Accept)
|
||||
}
|
||||
|
||||
if _, ok := rule.SSHUsers["bob"]; ok {
|
||||
foundBob = true
|
||||
|
||||
assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers)
|
||||
require.Len(t, rule.Principals, 1)
|
||||
assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP)
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundAlice, "Should have alice's localpart rule")
|
||||
assert.True(t, foundBob, "Should have bob's localpart rule")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "localpart-combined-with-root",
|
||||
targetNode: nodeTaggedServer,
|
||||
peers: types.Nodes{&nodeAlice},
|
||||
policy: `{
|
||||
"tagOwners": {
|
||||
"tag:server": ["alice@example.com"]
|
||||
},
|
||||
"ssh": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:server"],
|
||||
"users": ["localpart:*@example.com", "root"]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
validate: func(t *testing.T, got *tailcfg.SSHPolicy) {
|
||||
t.Helper()
|
||||
require.NotNil(t, got)
|
||||
// 1 common rule for root + per-user localpart rules
|
||||
require.GreaterOrEqual(t, len(got.Rules), 2,
|
||||
"Should have common root rule and per-user localpart rules")
|
||||
|
||||
foundRoot := false
|
||||
foundAliceWithRoot := false
|
||||
|
||||
for _, rule := range got.Rules {
|
||||
if v, ok := rule.SSHUsers["root"]; ok && v == "root" {
|
||||
foundRoot = true
|
||||
}
|
||||
|
||||
if _, hasAlice := rule.SSHUsers["alice"]; hasAlice {
|
||||
if rule.SSHUsers["root"] == "root" {
|
||||
foundAliceWithRoot = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundRoot, "Should have root mapping")
|
||||
assert.True(t, foundAliceWithRoot, "Alice's localpart rule should merge root")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
pm, err := pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := pm.SSHPolicy(tt.targetNode.View())
|
||||
require.NoError(t, err)
|
||||
tt.validate(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReduceRoutes(t *testing.T) {
|
||||
type args struct {
|
||||
node *types.Node
|
||||
|
||||
@ -3,6 +3,7 @@ package v2
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
@ -384,21 +385,40 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
}
|
||||
|
||||
userMap := make(map[string]string, len(rule.Users))
|
||||
// Build the "common" userMap for non-localpart entries (root, autogroup:nonroot, specific users).
|
||||
const rootUser = "root"
|
||||
|
||||
commonUserMap := make(map[string]string, len(rule.Users))
|
||||
if rule.Users.ContainsNonRoot() {
|
||||
userMap["*"] = "="
|
||||
commonUserMap["*"] = "="
|
||||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
commonUserMap[rootUser] = ""
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
commonUserMap[rootUser] = rootUser
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
commonUserMap[u.String()] = u.String()
|
||||
}
|
||||
|
||||
// Resolve localpart entries into per-user rules.
|
||||
// Each localpart:*@<domain> entry maps users in that domain to their email local-part.
|
||||
// Because SSHUsers is a static map per rule, we need a separate rule per user
|
||||
// to constrain each user to only their own local-part.
|
||||
localpartRules := resolveLocalpartRules(
|
||||
rule.Users.LocalpartEntries(),
|
||||
users,
|
||||
nodes,
|
||||
srcIPs,
|
||||
commonUserMap,
|
||||
&action,
|
||||
)
|
||||
|
||||
// Determine whether the common userMap has any entries worth emitting.
|
||||
hasCommonUsers := len(commonUserMap) > 0
|
||||
|
||||
// Handle autogroup:self destinations (if any)
|
||||
// Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes
|
||||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
@ -443,19 +463,46 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}
|
||||
|
||||
if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 {
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(filteredSrcSet) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
// Emit common rule if there are non-localpart users
|
||||
if hasCommonUsers {
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(filteredSrcSet) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: commonUserMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
// Emit per-user localpart rules, filtered to autogroup:self sources
|
||||
for _, lpRule := range localpartRules {
|
||||
var filteredPrincipals []*tailcfg.SSHPrincipal
|
||||
|
||||
for _, p := range lpRule.Principals {
|
||||
addr, err := netip.ParseAddr(p.NodeIP)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if filteredSrcSet.Contains(addr) {
|
||||
filteredPrincipals = append(filteredPrincipals, p)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredPrincipals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: filteredPrincipals,
|
||||
SSHUsers: lpRule.SSHUsers,
|
||||
Action: lpRule.Action,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -484,21 +531,27 @@ func (pol *Policy) compileSSHPolicy(
|
||||
|
||||
// Only create rule if this node is in the destination set
|
||||
if node.InIPSet(destSet) {
|
||||
// For non-autogroup:self destinations, use all resolved sources (no filtering)
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
// Emit common rule if there are non-localpart users
|
||||
if hasCommonUsers {
|
||||
// For non-autogroup:self destinations, use all resolved sources (no filtering)
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: commonUserMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(principals) > 0 {
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
// Emit per-user localpart rules
|
||||
rules = append(rules, localpartRules...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -508,6 +561,103 @@ func (pol *Policy) compileSSHPolicy(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveLocalpartRules generates per-user SSH rules for localpart:*@<domain> entries.
|
||||
// For each localpart entry, it finds all users whose email is in the specified domain,
|
||||
// extracts their email local-part, and creates a tailcfg.SSHRule scoped to that user's
|
||||
// node IPs with an SSHUsers map that only allows their local-part.
|
||||
// The commonUserMap entries (root, autogroup:nonroot, specific users) are merged into
|
||||
// each per-user rule so that localpart rules compose with other user entries.
|
||||
func resolveLocalpartRules(
|
||||
localpartEntries []SSHUser,
|
||||
users types.Users,
|
||||
nodes views.Slice[types.NodeView],
|
||||
srcIPs *netipx.IPSet,
|
||||
commonUserMap map[string]string,
|
||||
action *tailcfg.SSHAction,
|
||||
) []*tailcfg.SSHRule {
|
||||
if len(localpartEntries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for _, entry := range localpartEntries {
|
||||
domain, err := entry.ParseLocalpart()
|
||||
if err != nil {
|
||||
// Should not happen if validation passed, but skip gracefully.
|
||||
log.Warn().Err(err).Msgf("skipping invalid localpart entry %q during SSH compilation", entry)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Find users whose email matches *@<domain> and build per-user rules.
|
||||
for _, user := range users {
|
||||
if user.Email == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
atIdx := strings.LastIndex(user.Email, "@")
|
||||
if atIdx < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
emailDomain := user.Email[atIdx+1:]
|
||||
if !strings.EqualFold(emailDomain, domain) {
|
||||
continue
|
||||
}
|
||||
|
||||
localPart := user.Email[:atIdx]
|
||||
|
||||
// Find this user's non-tagged nodes that are in the source IP set.
|
||||
var userSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.IsTagged() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !n.User().Valid() || n.User().ID() != user.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
n.AppendToIPSet(&userSrcIPs)
|
||||
}
|
||||
}
|
||||
|
||||
userSrcSet, err := userSrcIPs.IPSet()
|
||||
if err != nil || userSrcSet == nil || len(userSrcSet.Prefixes()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for addr := range util.IPSetAddrIter(userSrcSet) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
}
|
||||
|
||||
if len(principals) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build per-user SSHUsers map: start with the common entries, then add the localpart.
|
||||
userMap := make(map[string]string, len(commonUserMap)+1)
|
||||
maps.Copy(userMap, commonUserMap)
|
||||
|
||||
userMap[localPart] = localPart
|
||||
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: action,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
var out []string
|
||||
|
||||
|
||||
@ -647,6 +647,280 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileSSHPolicy_LocalpartMapping(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "alice", Email: "alice@example.com", Model: gorm.Model{ID: 1}},
|
||||
{Name: "bob", Email: "bob@example.com", Model: gorm.Model{ID: 2}},
|
||||
{Name: "charlie", Email: "charlie@other.com", Model: gorm.Model{ID: 3}},
|
||||
{Name: "dave", Model: gorm.Model{ID: 4}}, // CLI user, no email
|
||||
}
|
||||
|
||||
nodeTaggedServer := types.Node{
|
||||
Hostname: "tagged-server",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
Tags: []string{"tag:server"},
|
||||
}
|
||||
nodeAlice := types.Node{
|
||||
Hostname: "alice-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeBob := types.Node{
|
||||
Hostname: "bob-device",
|
||||
IPv4: createAddr("100.64.0.3"),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
nodeCharlie := types.Node{
|
||||
Hostname: "charlie-device",
|
||||
IPv4: createAddr("100.64.0.4"),
|
||||
UserID: new(users[2].ID),
|
||||
User: new(users[2]),
|
||||
}
|
||||
nodeDave := types.Node{
|
||||
Hostname: "dave-device",
|
||||
IPv4: createAddr("100.64.0.5"),
|
||||
UserID: new(users[3].ID),
|
||||
User: new(users[3]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeTaggedServer, &nodeAlice, &nodeBob, &nodeCharlie, &nodeDave}
|
||||
|
||||
t.Run("localpart only", func(t *testing.T) {
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
|
||||
// Should get 2 rules: one for alice, one for bob (both @example.com)
|
||||
// charlie@other.com and dave (no email) should not match
|
||||
require.Len(t, sshPolicy.Rules, 2, "Should have per-user rules for alice and bob")
|
||||
|
||||
// Collect all rules' SSHUsers and principals
|
||||
foundAlice := false
|
||||
foundBob := false
|
||||
|
||||
for _, rule := range sshPolicy.Rules {
|
||||
if _, ok := rule.SSHUsers["alice"]; ok {
|
||||
foundAlice = true
|
||||
|
||||
assert.Equal(t, map[string]string{"alice": "alice"}, rule.SSHUsers)
|
||||
require.Len(t, rule.Principals, 1)
|
||||
assert.Equal(t, "100.64.0.2", rule.Principals[0].NodeIP)
|
||||
}
|
||||
|
||||
if _, ok := rule.SSHUsers["bob"]; ok {
|
||||
foundBob = true
|
||||
|
||||
assert.Equal(t, map[string]string{"bob": "bob"}, rule.SSHUsers)
|
||||
require.Len(t, rule.Principals, 1)
|
||||
assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP)
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundAlice, "Should have a rule for alice's localpart")
|
||||
assert.True(t, foundBob, "Should have a rule for bob's localpart")
|
||||
})
|
||||
|
||||
t.Run("localpart with root", func(t *testing.T) {
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root"},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
|
||||
// Should get 3 rules: common rule for root + per-user for alice + per-user for bob
|
||||
require.Len(t, sshPolicy.Rules, 3, "Should have common root rule + per-user localpart rules")
|
||||
|
||||
// Find the common rule (has root but no per-user localpart)
|
||||
foundCommon := false
|
||||
|
||||
for _, rule := range sshPolicy.Rules {
|
||||
if v, ok := rule.SSHUsers["root"]; ok && v == "root" && len(rule.SSHUsers) == 1 {
|
||||
foundCommon = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundCommon, "Should have a common rule with root mapping")
|
||||
|
||||
// Per-user rules should also include root
|
||||
for _, rule := range sshPolicy.Rules {
|
||||
if _, hasAlice := rule.SSHUsers["alice"]; hasAlice {
|
||||
assert.Equal(t, "root", rule.SSHUsers["root"], "Alice's localpart rule should merge root")
|
||||
}
|
||||
|
||||
if _, hasBob := rule.SSHUsers["bob"]; hasBob {
|
||||
assert.Equal(t, "root", rule.SSHUsers["root"], "Bob's localpart rule should merge root")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("localpart no matching users in domain", func(t *testing.T) {
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@nonexistent.com")},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
assert.Empty(t, sshPolicy.Rules, "Should have no rules when no users match domain")
|
||||
})
|
||||
|
||||
t.Run("localpart with special chars in email", func(t *testing.T) {
|
||||
specialUsers := types.Users{
|
||||
{Name: "dave+sshuser", Email: "dave+sshuser@example.com", Model: gorm.Model{ID: 10}},
|
||||
}
|
||||
nodeSpecial := types.Node{
|
||||
Hostname: "special-device",
|
||||
IPv4: createAddr("100.64.0.10"),
|
||||
UserID: new(specialUsers[0].ID),
|
||||
User: new(specialUsers[0]),
|
||||
}
|
||||
|
||||
specialNodes := types.Nodes{&nodeTaggedServer, &nodeSpecial}
|
||||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("dave+sshuser@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(specialUsers, nodeTaggedServer.View(), specialNodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
require.Len(t, sshPolicy.Rules, 1)
|
||||
|
||||
rule := sshPolicy.Rules[0]
|
||||
// Per Tailscale docs: "if the login is dave+sshuser@example.com,
|
||||
// Tailscale will map this to the SSH user dave+sshuser"
|
||||
assert.Equal(t, map[string]string{"dave+sshuser": "dave+sshuser"}, rule.SSHUsers)
|
||||
})
|
||||
|
||||
t.Run("localpart excludes CLI users without email", func(t *testing.T) {
|
||||
// dave has no email, should be excluded from localpart matching
|
||||
cliOnlyUsers := types.Users{
|
||||
{Name: "dave", Model: gorm.Model{ID: 4}},
|
||||
}
|
||||
nodeDaveCli := types.Node{
|
||||
Hostname: "dave-cli-device",
|
||||
IPv4: createAddr("100.64.0.5"),
|
||||
UserID: new(cliOnlyUsers[0].ID),
|
||||
User: new(cliOnlyUsers[0]),
|
||||
}
|
||||
|
||||
cliNodes := types.Nodes{&nodeTaggedServer, &nodeDaveCli}
|
||||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("dave@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(cliOnlyUsers, nodeTaggedServer.View(), cliNodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
assert.Empty(t, sshPolicy.Rules, "CLI users without email should not match localpart rules")
|
||||
})
|
||||
|
||||
t.Run("localpart with multiple domains", func(t *testing.T) {
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{up("alice@example.com")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:server")},
|
||||
Users: []SSHUser{
|
||||
SSHUser("localpart:*@example.com"),
|
||||
SSHUser("localpart:*@other.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, policy.validate())
|
||||
|
||||
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sshPolicy)
|
||||
|
||||
// alice@example.com, bob@example.com, charlie@other.com should all match
|
||||
require.Len(t, sshPolicy.Rules, 3, "Should have rules for alice, bob, and charlie")
|
||||
|
||||
localparts := make(map[string]bool)
|
||||
|
||||
for _, rule := range sshPolicy.Rules {
|
||||
for k := range rule.SSHUsers {
|
||||
localparts[k] = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, localparts["alice"], "Should have alice's localpart rule")
|
||||
assert.True(t, localparts["bob"], "Should have bob's localpart rule")
|
||||
assert.True(t, localparts["charlie"], "Should have charlie's localpart rule")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
|
||||
@ -43,6 +43,7 @@ var (
|
||||
ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged")
|
||||
ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)")
|
||||
ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination")
|
||||
ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@<domain>")
|
||||
)
|
||||
|
||||
// ACL validation errors.
|
||||
@ -1953,6 +1954,14 @@ func (p *Policy) validate() error {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if user.IsLocalpart() {
|
||||
_, err := user.ParseLocalpart()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, src := range ssh.Sources {
|
||||
@ -2255,6 +2264,11 @@ type SSHDstAliases []Alias
|
||||
|
||||
type SSHUsers []SSHUser
|
||||
|
||||
// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries.
|
||||
// Format: localpart:*@<domain>
|
||||
// See: https://tailscale.com/docs/features/tailscale-ssh#users
|
||||
const SSHUserLocalpartPrefix = "localpart:"
|
||||
|
||||
func (u SSHUsers) ContainsRoot() bool {
|
||||
return slices.Contains(u, "root")
|
||||
}
|
||||
@ -2263,9 +2277,25 @@ func (u SSHUsers) ContainsNonRoot() bool {
|
||||
return slices.Contains(u, SSHUser(AutoGroupNonRoot))
|
||||
}
|
||||
|
||||
// ContainsLocalpart returns true if any entry has the localpart: prefix.
|
||||
func (u SSHUsers) ContainsLocalpart() bool {
|
||||
return slices.ContainsFunc(u, func(user SSHUser) bool {
|
||||
return user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
// NormalUsers returns all SSH users that are not root, autogroup:nonroot,
|
||||
// or localpart: entries.
|
||||
func (u SSHUsers) NormalUsers() []SSHUser {
|
||||
return slicesx.Filter(nil, u, func(user SSHUser) bool {
|
||||
return user != "root" && user != SSHUser(AutoGroupNonRoot)
|
||||
return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
// LocalpartEntries returns only the localpart: prefixed entries.
|
||||
func (u SSHUsers) LocalpartEntries() []SSHUser {
|
||||
return slicesx.Filter(nil, u, func(user SSHUser) bool {
|
||||
return user.IsLocalpart()
|
||||
})
|
||||
}
|
||||
|
||||
@ -2275,6 +2305,41 @@ func (u SSHUser) String() string {
|
||||
return string(u)
|
||||
}
|
||||
|
||||
// IsLocalpart returns true if the SSHUser has the localpart: prefix.
|
||||
func (u SSHUser) IsLocalpart() bool {
|
||||
return strings.HasPrefix(string(u), SSHUserLocalpartPrefix)
|
||||
}
|
||||
|
||||
// ParseLocalpart validates and extracts the domain from a localpart: entry.
|
||||
// The expected format is localpart:*@<domain>.
|
||||
// Returns the domain part or an error if the format is invalid.
|
||||
func (u SSHUser) ParseLocalpart() (string, error) {
|
||||
if !u.IsLocalpart() {
|
||||
return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u)
|
||||
}
|
||||
|
||||
pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix)
|
||||
|
||||
// Must be *@<domain>
|
||||
atIdx := strings.LastIndex(pattern, "@")
|
||||
if atIdx < 0 {
|
||||
return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u)
|
||||
}
|
||||
|
||||
localPart := pattern[:atIdx]
|
||||
domain := pattern[atIdx+1:]
|
||||
|
||||
if localPart != "*" {
|
||||
return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u)
|
||||
}
|
||||
|
||||
if domain == "" {
|
||||
return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u)
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the SSHUser to JSON.
|
||||
func (u SSHUser) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(u))
|
||||
|
||||
@ -1766,6 +1766,105 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-valid",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@example.com"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
want: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:prod"): Owners{up("admin@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:prod")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-with-other-users",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@example.com", "root", "autogroup:nonroot"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
want: &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:prod"): Owners{up("admin@")},
|
||||
},
|
||||
SSHs: []SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: SSHSrcAliases{agp("autogroup:member")},
|
||||
Destinations: SSHDstAliases{tp("tag:prod")},
|
||||
Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-no-at-sign",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:foo"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-non-wildcard",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:alice@example.com"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
{
|
||||
name: "ssh-localpart-invalid-empty-domain",
|
||||
input: `
|
||||
{
|
||||
"tagOwners": {"tag:prod": ["admin@"]},
|
||||
"ssh": [{
|
||||
"action": "accept",
|
||||
"src": ["autogroup:member"],
|
||||
"dst": ["tag:prod"],
|
||||
"users": ["localpart:*@"]
|
||||
}]
|
||||
}
|
||||
`,
|
||||
wantErr: "invalid localpart format",
|
||||
},
|
||||
}
|
||||
|
||||
cmps := append(util.Comparers,
|
||||
@ -2576,6 +2675,154 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUsers_ContainsLocalpart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
users SSHUsers
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "contains localpart",
|
||||
users: SSHUsers{SSHUser("localpart:*@example.com")},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "does not contain localpart",
|
||||
users: SSHUsers{"ubuntu", "admin", "root"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "contains localpart among others",
|
||||
users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple localpart entries",
|
||||
users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.users.ContainsLocalpart()
|
||||
assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUsers_LocalpartEntries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
users SSHUsers
|
||||
expected []SSHUser
|
||||
}{
|
||||
{
|
||||
name: "empty users",
|
||||
users: SSHUsers{},
|
||||
expected: []SSHUser{},
|
||||
},
|
||||
{
|
||||
name: "no localpart entries",
|
||||
users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)},
|
||||
expected: []SSHUser{},
|
||||
},
|
||||
{
|
||||
name: "single localpart entry",
|
||||
users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"},
|
||||
expected: []SSHUser{SSHUser("localpart:*@example.com")},
|
||||
},
|
||||
{
|
||||
name: "multiple localpart entries",
|
||||
users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")},
|
||||
expected: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.users.LocalpartEntries()
|
||||
assert.ElementsMatch(t, tt.expected, result, "LocalpartEntries() should return expected entries")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHUsers_NormalUsers_ExcludesLocalpart(t *testing.T) {
|
||||
users := SSHUsers{
|
||||
"ubuntu",
|
||||
"root",
|
||||
SSHUser(AutoGroupNonRoot),
|
||||
SSHUser("localpart:*@example.com"),
|
||||
"admin",
|
||||
}
|
||||
|
||||
result := users.NormalUsers()
|
||||
assert.ElementsMatch(t, []SSHUser{"ubuntu", "admin"}, result,
|
||||
"NormalUsers() should exclude root, autogroup:nonroot, and localpart entries")
|
||||
}
|
||||
|
||||
func TestSSHUser_ParseLocalpart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user SSHUser
|
||||
expectedDomain string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid localpart",
|
||||
user: SSHUser("localpart:*@example.com"),
|
||||
expectedDomain: "example.com",
|
||||
},
|
||||
{
|
||||
name: "valid localpart with subdomain",
|
||||
user: SSHUser("localpart:*@corp.example.com"),
|
||||
expectedDomain: "corp.example.com",
|
||||
},
|
||||
{
|
||||
name: "missing prefix",
|
||||
user: SSHUser("ubuntu"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing @ sign",
|
||||
user: SSHUser("localpart:foo"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-wildcard local part",
|
||||
user: SSHUser("localpart:alice@example.com"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty domain",
|
||||
user: SSHUser("localpart:*@"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "just prefix",
|
||||
user: SSHUser("localpart:"),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
domain, err := tt.user.ParseLocalpart()
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedDomain, domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustIPSet(prefixes ...string) *netipx.IPSet {
|
||||
var builder netipx.IPSetBuilder
|
||||
for _, p := range prefixes {
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/oauth2-proxy/mockoidc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -20,7 +21,8 @@ func isSSHNoAccessStdError(stderr string) bool {
|
||||
// Since https://github.com/tailscale/tailscale/pull/14853
|
||||
strings.Contains(stderr, "failed to evaluate SSH policy") ||
|
||||
// Since https://github.com/tailscale/tailscale/pull/16127
|
||||
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
|
||||
// Covers both "to this node" and "as user <name>" variants.
|
||||
strings.Contains(stderr, "tailnet policy does not permit you to SSH")
|
||||
}
|
||||
|
||||
func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario {
|
||||
@ -420,15 +422,27 @@ func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClien
|
||||
func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) {
|
||||
t.Helper()
|
||||
|
||||
return doSSHWithRetryAsUser(t, client, peer, "ssh-it-user", retry)
|
||||
}
|
||||
|
||||
func doSSHWithRetryAsUser(
|
||||
t *testing.T,
|
||||
client TailscaleClient,
|
||||
peer TailscaleClient,
|
||||
sshUser string,
|
||||
retry bool,
|
||||
) (string, string, error) {
|
||||
t.Helper()
|
||||
|
||||
peerFQDN, _ := peer.FQDN()
|
||||
|
||||
command := []string{
|
||||
"/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=1",
|
||||
fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN),
|
||||
fmt.Sprintf("%s@%s", sshUser, peerFQDN),
|
||||
"'hostname'",
|
||||
}
|
||||
|
||||
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
|
||||
log.Printf("Running from %s to %s as %s", client.Hostname(), peer.Hostname(), sshUser)
|
||||
log.Printf("Command: %s", strings.Join(command, " "))
|
||||
|
||||
var (
|
||||
@ -499,6 +513,31 @@ func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) {
|
||||
}
|
||||
}
|
||||
|
||||
func doSSHAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) (string, string, error) {
|
||||
t.Helper()
|
||||
|
||||
return doSSHWithRetryAsUser(t, client, peer, sshUser, true)
|
||||
}
|
||||
|
||||
func assertSSHHostnameAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) {
|
||||
t.Helper()
|
||||
|
||||
result, _, err := doSSHAsUser(t, client, peer, sshUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", ""))
|
||||
}
|
||||
|
||||
func assertSSHPermissionDeniedAsUser(t *testing.T, client TailscaleClient, peer TailscaleClient, sshUser string) {
|
||||
t.Helper()
|
||||
|
||||
result, stderr, err := doSSHWithRetryAsUser(t, client, peer, sshUser, false)
|
||||
|
||||
assert.Empty(t, result)
|
||||
|
||||
assertSSHNoAccessStdError(t, err, stderr)
|
||||
}
|
||||
|
||||
// TestSSHAutogroupSelf tests that SSH with autogroup:self works correctly:
|
||||
// - Users can SSH to their own devices
|
||||
// - Users cannot SSH to other users' devices.
|
||||
@ -579,3 +618,234 @@ func TestSSHAutogroupSelf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHLocalpart tests that SSH with localpart:*@<domain> works correctly.
|
||||
// localpart maps the local-part of each user's OIDC email to an OS user,
|
||||
// so user1@headscale.net can SSH as local user "user1".
|
||||
// This requires OIDC login so that users have real email addresses.
|
||||
func TestSSHLocalpart(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
baseACLs := []policyv2.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Sources: []policyv2.Alias{wildcard()},
|
||||
Destinations: []policyv2.AliasWithPorts{
|
||||
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *policyv2.Policy
|
||||
testFn func(t *testing.T, scenario *Scenario)
|
||||
}{
|
||||
{
|
||||
name: "MemberAndTagged",
|
||||
policy: &policyv2.Policy{
|
||||
ACLs: baseACLs,
|
||||
SSHs: []policyv2.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)},
|
||||
Destinations: policyv2.SSHDstAliases{
|
||||
new(policyv2.AutoGroupMember),
|
||||
new(policyv2.AutoGroupTagged),
|
||||
},
|
||||
Users: []policyv2.SSHUser{"localpart:*@headscale.net"},
|
||||
},
|
||||
},
|
||||
},
|
||||
testFn: func(t *testing.T, scenario *Scenario) {
|
||||
t.Helper()
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
// user1 can SSH to user2's nodes as "user1" (localpart of user1@headscale.net)
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "user1")
|
||||
}
|
||||
}
|
||||
|
||||
// user2 can SSH to user1's nodes as "user2" (localpart of user2@headscale.net)
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "user2")
|
||||
}
|
||||
}
|
||||
|
||||
// user1 CANNOT SSH as "user2" — no rule maps user1's IPs to user2
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHPermissionDeniedAsUser(t, client, peer, "user2")
|
||||
}
|
||||
}
|
||||
|
||||
// user2 CANNOT SSH as "user1" — no rule maps user2's IPs to user1
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
assertSSHPermissionDeniedAsUser(t, client, peer, "user1")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AutogroupSelf",
|
||||
policy: &policyv2.Policy{
|
||||
ACLs: baseACLs,
|
||||
SSHs: []policyv2.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)},
|
||||
Destinations: policyv2.SSHDstAliases{new(policyv2.AutoGroupSelf)},
|
||||
Users: []policyv2.SSHUser{"localpart:*@headscale.net"},
|
||||
},
|
||||
},
|
||||
},
|
||||
testFn: func(t *testing.T, scenario *Scenario) {
|
||||
t.Helper()
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
// With autogroup:self, cross-user SSH should be denied regardless of localpart.
|
||||
// user1 cannot SSH to user2's nodes as "user1"
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHPermissionDeniedAsUser(t, client, peer, "user1")
|
||||
}
|
||||
}
|
||||
|
||||
// user2 cannot SSH to user1's nodes as "user2"
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
assertSSHPermissionDeniedAsUser(t, client, peer, "user2")
|
||||
}
|
||||
}
|
||||
|
||||
// user1 also cannot SSH to user2's nodes as "user2"
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHPermissionDeniedAsUser(t, client, peer, "user2")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "LocalpartPlusRoot",
|
||||
policy: &policyv2.Policy{
|
||||
ACLs: baseACLs,
|
||||
SSHs: []policyv2.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: policyv2.SSHSrcAliases{new(policyv2.AutoGroupMember)},
|
||||
Destinations: policyv2.SSHDstAliases{
|
||||
new(policyv2.AutoGroupMember),
|
||||
new(policyv2.AutoGroupTagged),
|
||||
},
|
||||
Users: []policyv2.SSHUser{
|
||||
"localpart:*@headscale.net",
|
||||
"root",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
testFn: func(t *testing.T, scenario *Scenario) {
|
||||
t.Helper()
|
||||
|
||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||
requireNoErrListClients(t, err)
|
||||
|
||||
// localpart works: user1 can SSH to user2's nodes as "user1"
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "user1")
|
||||
}
|
||||
}
|
||||
|
||||
// root also works: user1 can SSH to user2's nodes as "root"
|
||||
for _, client := range user1Clients {
|
||||
for _, peer := range user2Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "root")
|
||||
}
|
||||
}
|
||||
|
||||
// user2 can SSH as "user2" (localpart)
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "user2")
|
||||
}
|
||||
}
|
||||
|
||||
// user2 can SSH as "root"
|
||||
for _, client := range user2Clients {
|
||||
for _, peer := range user1Clients {
|
||||
assertSSHHostnameAsUser(t, client, peer, "root")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: 1,
|
||||
Users: []string{"user1", "user2"},
|
||||
OIDCUsers: []mockoidc.MockUser{
|
||||
oidcMockUser("user1", true),
|
||||
oidcMockUser("user2", true),
|
||||
},
|
||||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
oidcMap := map[string]string{
|
||||
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
|
||||
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
|
||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnvWithLoginURL(
|
||||
[]tsic.Option{
|
||||
tsic.WithSSH(),
|
||||
tsic.WithNetfilter("off"),
|
||||
tsic.WithPackages("openssh"),
|
||||
tsic.WithExtraCommands("adduser user1", "adduser user2"),
|
||||
tsic.WithDockerWorkdir("/"),
|
||||
},
|
||||
hsic.WithTestName("sshlocalpart"),
|
||||
hsic.WithACLPolicy(tt.policy),
|
||||
hsic.WithConfigEnv(oidcMap),
|
||||
hsic.WithTLS(),
|
||||
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())),
|
||||
)
|
||||
requireNoErrHeadscaleEnv(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||
requireNoErrListFQDN(t, err)
|
||||
|
||||
tt.testFn(t, scenario)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user