diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 8785bed0..b876dd69 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -43,6 +43,7 @@ var ( ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged") ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)") ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination") + ErrInvalidLocalpart = errors.New("invalid localpart format, must be localpart:*@") ) // ACL validation errors. @@ -1953,6 +1954,14 @@ func (p *Policy) validate() error { continue } } + + if user.IsLocalpart() { + _, err := user.ParseLocalpart() + if err != nil { + errs = append(errs, err) + continue + } + } } for _, src := range ssh.Sources { @@ -2255,6 +2264,11 @@ type SSHDstAliases []Alias type SSHUsers []SSHUser +// SSHUserLocalpartPrefix is the prefix for localpart SSH user entries. +// Format: localpart:*@ +// See: https://tailscale.com/docs/features/tailscale-ssh#users +const SSHUserLocalpartPrefix = "localpart:" + func (u SSHUsers) ContainsRoot() bool { return slices.Contains(u, "root") } @@ -2263,9 +2277,25 @@ func (u SSHUsers) ContainsNonRoot() bool { return slices.Contains(u, SSHUser(AutoGroupNonRoot)) } +// ContainsLocalpart returns true if any entry has the localpart: prefix. +func (u SSHUsers) ContainsLocalpart() bool { + return slices.ContainsFunc(u, func(user SSHUser) bool { + return user.IsLocalpart() + }) +} + +// NormalUsers returns all SSH users that are not root, autogroup:nonroot, +// or localpart: entries. func (u SSHUsers) NormalUsers() []SSHUser { return slicesx.Filter(nil, u, func(user SSHUser) bool { - return user != "root" && user != SSHUser(AutoGroupNonRoot) + return user != "root" && user != SSHUser(AutoGroupNonRoot) && !user.IsLocalpart() + }) +} + +// LocalpartEntries returns only the localpart: prefixed entries. +func (u SSHUsers) LocalpartEntries() []SSHUser { + return slicesx.Filter(nil, u, func(user SSHUser) bool { + return user.IsLocalpart() }) } @@ -2275,6 +2305,41 @@ func (u SSHUser) String() string { return string(u) } +// IsLocalpart returns true if the SSHUser has the localpart: prefix. +func (u SSHUser) IsLocalpart() bool { + return strings.HasPrefix(string(u), SSHUserLocalpartPrefix) +} + +// ParseLocalpart validates and extracts the domain from a localpart: entry. +// The expected format is localpart:*@. +// Returns the domain part or an error if the format is invalid. +func (u SSHUser) ParseLocalpart() (string, error) { + if !u.IsLocalpart() { + return "", fmt.Errorf("%w: missing prefix %q in %q", ErrInvalidLocalpart, SSHUserLocalpartPrefix, u) + } + + pattern := strings.TrimPrefix(string(u), SSHUserLocalpartPrefix) + + // Must be *@ + atIdx := strings.LastIndex(pattern, "@") + if atIdx < 0 { + return "", fmt.Errorf("%w: missing @ in %q", ErrInvalidLocalpart, u) + } + + localPart := pattern[:atIdx] + domain := pattern[atIdx+1:] + + if localPart != "*" { + return "", fmt.Errorf("%w: local part must be *, got %q in %q", ErrInvalidLocalpart, localPart, u) + } + + if domain == "" { + return "", fmt.Errorf("%w: empty domain in %q", ErrInvalidLocalpart, u) + } + + return domain, nil +} + // MarshalJSON marshals the SSHUser to JSON. func (u SSHUser) MarshalJSON() ([]byte, error) { return json.Marshal(string(u)) diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index acea9c28..95899b17 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1766,6 +1766,105 @@ func TestUnmarshalPolicy(t *testing.T) { }, }, }, + { + name: "ssh-localpart-valid", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + }, + }, + }, + { + name: "ssh-localpart-with-other-users", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@example.com", "root", "autogroup:nonroot"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{tp("tag:prod")}, + Users: []SSHUser{SSHUser("localpart:*@example.com"), "root", SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-localpart-invalid-no-at-sign", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:foo"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-non-wildcard", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:alice@example.com"] + }] +} +`, + wantErr: "invalid localpart format", + }, + { + name: "ssh-localpart-invalid-empty-domain", + input: ` +{ + "tagOwners": {"tag:prod": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod"], + "users": ["localpart:*@"] + }] +} +`, + wantErr: "invalid localpart format", + }, } cmps := append(util.Comparers, @@ -2576,6 +2675,154 @@ func TestSSHUsers_ContainsNonRoot(t *testing.T) { } } +func TestSSHUsers_ContainsLocalpart(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected bool + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: false, + }, + { + name: "contains localpart", + users: SSHUsers{SSHUser("localpart:*@example.com")}, + expected: true, + }, + { + name: "does not contain localpart", + users: SSHUsers{"ubuntu", "admin", "root"}, + expected: false, + }, + { + name: "contains localpart among others", + users: SSHUsers{"ubuntu", SSHUser("localpart:*@example.com"), "admin"}, + expected: true, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.ContainsLocalpart() + assert.Equal(t, tt.expected, result, "ContainsLocalpart() should return expected result") + }) + } +} + +func TestSSHUsers_LocalpartEntries(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected []SSHUser + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: []SSHUser{}, + }, + { + name: "no localpart entries", + users: SSHUsers{"root", "ubuntu", SSHUser(AutoGroupNonRoot)}, + expected: []SSHUser{}, + }, + { + name: "single localpart entry", + users: SSHUsers{"root", SSHUser("localpart:*@example.com"), "ubuntu"}, + expected: []SSHUser{SSHUser("localpart:*@example.com")}, + }, + { + name: "multiple localpart entries", + users: SSHUsers{SSHUser("localpart:*@a.com"), "root", SSHUser("localpart:*@b.com")}, + expected: []SSHUser{SSHUser("localpart:*@a.com"), SSHUser("localpart:*@b.com")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.LocalpartEntries() + assert.ElementsMatch(t, tt.expected, result, "LocalpartEntries() should return expected entries") + }) + } +} + +func TestSSHUsers_NormalUsers_ExcludesLocalpart(t *testing.T) { + users := SSHUsers{ + "ubuntu", + "root", + SSHUser(AutoGroupNonRoot), + SSHUser("localpart:*@example.com"), + "admin", + } + + result := users.NormalUsers() + assert.ElementsMatch(t, []SSHUser{"ubuntu", "admin"}, result, + "NormalUsers() should exclude root, autogroup:nonroot, and localpart entries") +} + +func TestSSHUser_ParseLocalpart(t *testing.T) { + tests := []struct { + name string + user SSHUser + expectedDomain string + expectErr bool + }{ + { + name: "valid localpart", + user: SSHUser("localpart:*@example.com"), + expectedDomain: "example.com", + }, + { + name: "valid localpart with subdomain", + user: SSHUser("localpart:*@corp.example.com"), + expectedDomain: "corp.example.com", + }, + { + name: "missing prefix", + user: SSHUser("ubuntu"), + expectErr: true, + }, + { + name: "missing @ sign", + user: SSHUser("localpart:foo"), + expectErr: true, + }, + { + name: "non-wildcard local part", + user: SSHUser("localpart:alice@example.com"), + expectErr: true, + }, + { + name: "empty domain", + user: SSHUser("localpart:*@"), + expectErr: true, + }, + { + name: "just prefix", + user: SSHUser("localpart:"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domain, err := tt.user.ParseLocalpart() + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedDomain, domain) + } + }) + } +} + func mustIPSet(prefixes ...string) *netipx.IPSet { var builder netipx.IPSetBuilder for _, p := range prefixes {