From 72049e94b1c0d10da322ca7b1064f34c0ce48c55 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 26 Feb 2025 19:31:08 +0100 Subject: [PATCH] introduce policy v2 package policy v2 is built from the ground up to be stricter and follow the same pattern for all types of resolvers. TODO introduce aliass resolver Signed-off-by: Kristoffer Dalby --- hscontrol/policy/v2/filter.go | 169 ++++ hscontrol/policy/v2/filter_test.go | 378 +++++++++ hscontrol/policy/v2/policy.go | 283 +++++++ hscontrol/policy/v2/policy_test.go | 58 ++ hscontrol/policy/v2/types.go | 1005 ++++++++++++++++++++++++ hscontrol/policy/v2/types_test.go | 1162 ++++++++++++++++++++++++++++ hscontrol/policy/v2/utils.go | 164 ++++ hscontrol/policy/v2/utils_test.go | 102 +++ 8 files changed, 3321 insertions(+) create mode 100644 hscontrol/policy/v2/filter.go create mode 100644 hscontrol/policy/v2/filter_test.go create mode 100644 hscontrol/policy/v2/policy.go create mode 100644 hscontrol/policy/v2/policy_test.go create mode 100644 hscontrol/policy/v2/types.go create mode 100644 hscontrol/policy/v2/types_test.go create mode 100644 hscontrol/policy/v2/utils.go create mode 100644 hscontrol/policy/v2/utils_test.go diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go new file mode 100644 index 00000000..2d6c3f12 --- /dev/null +++ b/hscontrol/policy/v2/filter.go @@ -0,0 +1,169 @@ +package v2 + +import ( + "errors" + "fmt" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +var ( + ErrInvalidAction = errors.New("invalid action") +) + +// compileFilterRules takes a set of nodes and an ACLPolicy and generates a +// set of Tailscale compatible FilterRules used to allow traffic on clients. 
+func (pol *Policy) compileFilterRules( + users types.Users, + nodes types.Nodes, +) ([]tailcfg.FilterRule, error) { + if pol == nil { + return tailcfg.FilterAllowAll, nil + } + + var rules []tailcfg.FilterRule + + for _, acl := range pol.ACLs { + if acl.Action != "accept" { + return nil, ErrInvalidAction + } + + srcIPs, err := acl.Sources.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving source ips") + } + + if len(srcIPs.Prefixes()) == 0 { + continue + } + + // TODO(kradalby): integrate type into schema + // TODO(kradalby): figure out the _ is wildcard stuff + protocols, _, err := parseProtocol(acl.Protocol) + if err != nil { + return nil, fmt.Errorf("parsing policy, protocol err: %w ", err) + } + + var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { + ips, err := dest.Alias.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving destination ips") + } + + for _, pref := range ips.Prefixes() { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + if len(destPorts) == 0 { + continue + } + + rules = append(rules, tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcIPs), + DstPorts: destPorts, + IPProto: protocols, + }) + } + + return rules, nil +} + +func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { + return tailcfg.SSHAction{ + Reject: !accept, + Accept: accept, + SessionDuration: duration, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + } +} + +func (pol *Policy) compileSSHPolicy( + users types.Users, + node *types.Node, + nodes types.Nodes, +) (*tailcfg.SSHPolicy, error) { + if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { + return nil, nil + } + + var rules []*tailcfg.SSHRule + + for index, rule := range pol.SSHs { + var dest netipx.IPSetBuilder + for _, src := range rule.Destinations { + ips, err := src.Resolve(pol, 
users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving destination ips") + } + dest.AddSet(ips) + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + if !node.InIPSet(destSet) { + continue + } + + var action tailcfg.SSHAction + switch rule.Action { + case "accept": + action = sshAction(true, 0) + case "check": + action = sshAction(true, rule.CheckPeriod) + default: + return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) + } + + var principals []*tailcfg.SSHPrincipal + srcIPs, err := rule.Sources.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving source ips") + } + + for addr := range util.IPSetAddrIter(srcIPs) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + userMap := make(map[string]string, len(rule.Users)) + for _, user := range rule.Users { + userMap[user.String()] = "=" + } + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + + return &tailcfg.SSHPolicy{ + Rules: rules, + }, nil +} + +func ipSetToPrefixStringList(ips *netipx.IPSet) []string { + var out []string + + for _, pref := range ips.Prefixes() { + out = append(out, pref.String()) + } + return out +} diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go new file mode 100644 index 00000000..e0b12520 --- /dev/null +++ b/hscontrol/policy/v2/filter_test.go @@ -0,0 +1,378 @@ +package v2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestParsing(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser"}, + } + tests := []struct { + name string + format string + acl string + want []tailcfg.FilterRule + wantErr bool + }{ + { + name: "invalid-hujson", + format: "hujson", + acl: ` +{ + `, + want: 
[]tailcfg.FilterRule{}, + wantErr: true, + }, + // The new parser will ignore all that is irrelevant + // { + // name: "valid-hujson-invalid-content", + // format: "hujson", + // acl: ` + // { + // "valid_json": true, + // "but_a_policy_though": false + // } + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + // { + // name: "invalid-cidr", + // format: "hujson", + // acl: ` + // {"example-host-1": "100.100.100.100/42"} + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + { + name: "basic-rule", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], +} + `, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "parse-protocol", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], +}`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP}, + }, + { 
+ SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{protocolUDP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolICMP, protocolIPv6ICMP}, + }, + }, + wantErr: false, + }, + { + name: "port-wildcard", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-range", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.100.100.100/32", + Ports: tailcfg.PortRange{First: 5400, Last: 5500}, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "port-group", + format: "hujson", + acl: ` +{ + "groups": { + "group:example": [ + "testuser@", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "port-user", + format: "hujson", + acl: ` +{ + "hosts": { 
+ "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser@", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "ipv6", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100/32", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol, err := policyFromBytes([]byte(tt.acl)) + if tt.wantErr && err == nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } else if !tt.wantErr && err != nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if err != nil { + return + } + + rules, err := pol.compileFilterRules( + users, + types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), + }, + &types.Node{ + IPv4: ap("200.200.200.200"), + User: users[0], + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) + + if (err != nil) != tt.wantErr { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, rules); diff != "" { + t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go new file mode 100644 index 00000000..41f51487 --- /dev/null +++ b/hscontrol/policy/v2/policy.go @@ -0,0 +1,283 @@ +package v2 + +import ( + "encoding/json" + "fmt" + "net/netip" + "strings" + "sync" + + 
"github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/util/deephash" +) + +type PolicyManager struct { + mu sync.Mutex + pol *Policy + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule + + tagOwnerMapHash deephash.Sum + tagOwnerMap map[Tag]*netipx.IPSet + + autoApproveMapHash deephash.Sum + autoApproveMap map[netip.Prefix]*netipx.IPSet + + // Lazy map of SSH policies + sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy +} + +// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. +// It returns an error if the policy file is invalid. +// The policy manager will update the filter rules based on the users and nodes. +func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { + policy, err := policyFromBytes(b) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManager{ + pol: policy, + users: users, + nodes: nodes, + sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)), + } + + _, err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManager) updateLocked() (bool, error) { + filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + filterChanged := filterHash == pm.filterHash + pm.filter = filter + pm.filterHash = filterHash + + // Order matters, tags might be used in autoapprovers, so we need to ensure + // that the map for tag owners is resolved before resolving autoapprovers. 
+ // TODO(kradalby): Order might not matter after #2417 + tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("resolving tag owners map: %w", err) + } + + tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash + pm.tagOwnerMap = tagMap + pm.tagOwnerMapHash = tagOwnerMapHash + + autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("resolving auto approvers map: %w", err) + } + + autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash + pm.autoApproveMap = autoMap + pm.autoApproveMapHash = autoApproveMapHash + + // If neither of the calculated values changed, no need to update nodes + if !filterChanged && !tagOwnerChanged && !autoApproveChanged { + return false, nil + } + + // Clear the SSH policy map to ensure it's recalculated with the new policy. + // TODO(kradalby): This could potentially be optimized by only clearing the + // policies for nodes that have changed. Particularly if the only difference is + // that nodes has been added or removed. 
+ clear(pm.sshPolicyMap) + + return true, nil +} + +func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + if sshPol, ok := pm.sshPolicyMap[node.ID]; ok { + return sshPol, nil + } + + sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + if err != nil { + return nil, fmt.Errorf("compiling SSH policy: %w", err) + } + pm.sshPolicyMap[node.ID] = sshPol + + return sshPol, nil +} + +func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { + if len(polB) == 0 { + return false, nil + } + + pol, err := policyFromBytes(polB) + if err != nil { + return false, fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// Filter returns the current filter rules for the entire tailnet. +func (pm *PolicyManager) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. 
+func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool { + if pm == nil { + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { + for _, nodeAddr := range node.IPs() { + if ips.Contains(nodeAddr) { + return true + } + } + } + + return false +} + +func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool { + if pm == nil { + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // The fast path is that a node requests to approve a prefix + // where there is an exact entry, e.g. 10.0.0.0/8, then + // check and return quickly + if _, ok := pm.autoApproveMap[route]; ok { + for _, nodeAddr := range node.IPs() { + if pm.autoApproveMap[route].Contains(nodeAddr) { + return true + } + } + } + + // The slow path is that the node tries to approve + // 10.0.10.0/24, which is a part of 10.0.0.0/8, then we + // cannot just lookup in the prefix map and have to check + // if there is a "parent" prefix available. + for prefix, approveAddrs := range pm.autoApproveMap { + // We do not want the exit node entry to approve all + // sorts of routes. The logic here is that it would be + // unexpected behaviour to have specific routes approved + // just because the node is allowed to designate itself as + // an exit. 
+ if tsaddr.IsExitRoute(prefix) { + continue + } + + // Check if prefix is larger (so containing) and then overlaps + // the route to see if the node can approve a subset of an autoapprover + if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { + for _, nodeAddr := range node.IPs() { + if approveAddrs.Contains(nodeAddr) { + return true + } + } + } + } + + return false +} + +func (pm *PolicyManager) Version() int { + return 2 +} + +func (pm *PolicyManager) DebugString() string { + var sb strings.Builder + + fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version()) + + sb.WriteString("\n\n") + + if pm.pol != nil { + pol, err := json.MarshalIndent(pm.pol, "", " ") + if err == nil { + sb.WriteString("Policy:\n") + sb.Write(pol) + sb.WriteString("\n\n") + } + } + + fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { + fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { + fmt.Fprintf(&sb, "\t\t%s\n", iprange) + } + } + + sb.WriteString("\n\n") + + fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { + fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { + fmt.Fprintf(&sb, "\t\t%s\n", iprange) + } + } + + sb.WriteString("\n\n") + if pm.filter != nil { + filter, err := json.MarshalIndent(pm.filter, "", " ") + if err == nil { + sb.WriteString("Compiled filter:\n") + sb.Write(filter) + sb.WriteString("\n\n") + } + } + + return sb.String() +} diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go new file mode 100644 index 00000000..ee26c596 --- /dev/null +++ b/hscontrol/policy/v2/policy_test.go @@ -0,0 +1,58 @@ +package v2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func node(name, ipv4, ipv6 string, 
user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { + return &types.Node{ + ID: 0, + Hostname: name, + IPv4: ap(ipv4), + IPv6: ap(ipv6), + User: user, + UserID: user.ID, + Hostinfo: hostinfo, + } +} + +func TestPolicyManager(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, + } + + tests := []struct { + name string + pol string + nodes types.Nodes + wantFilter []tailcfg.FilterRule + }{ + { + name: "empty-policy", + pol: "{}", + nodes: types.Nodes{}, + wantFilter: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) + require.NoError(t, err) + + filter := pm.Filter() + if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { + t.Errorf("Filter() mismatch (-want +got):\n%s", diff) + } + + // TODO(kradalby): Test SSH Policy + }) + } +} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go new file mode 100644 index 00000000..6e644539 --- /dev/null +++ b/hscontrol/policy/v2/types.go @@ -0,0 +1,1005 @@ +package v2 + +import ( + "bytes" + "encoding/json" + "fmt" + "net/netip" + "strings" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/tailscale/hujson" + "go4.org/netipx" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" + "tailscale.com/util/multierr" +) + +const Wildcard = Asterix(0) + +type Asterix int + +func (a Asterix) Validate() error { + return nil +} + +func (a Asterix) String() string { + return "*" +} + +func (a Asterix) UnmarshalJSON(b []byte) error { + return nil +} + +func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + // TODO(kradalby): + // Should this actually only be the CGNAT spaces? 
I do not think so, because + // we also want to include subnet routers right? + ips.AddPrefix(tsaddr.AllIPv4()) + ips.AddPrefix(tsaddr.AllIPv6()) + + return ips.IPSet() +} + +// Username is a string that represents a username, it must contain an @. +type Username string + +func (u Username) Validate() error { + if isUser(string(u)) { + return nil + } + return fmt.Errorf("Username has to contain @, got: %q", u) +} + +func (u *Username) String() string { + return string(*u) +} + +func (u *Username) UnmarshalJSON(b []byte) error { + *u = Username(strings.Trim(string(b), `"`)) + if err := u.Validate(); err != nil { + return err + } + return nil +} + +func (u Username) CanBeTagOwner() bool { + return true +} + +func (u Username) CanBeAutoApprover() bool { + return true +} + +// resolveUser attempts to find a user in the provided [types.Users] slice that matches the Username. +// It prioritizes matching the ProviderIdentifier, and if not found, it falls back to matching the Email or Name. +// If no matching user is found, it returns an error indicating no user matching. +// If multiple matching users are found, it returns an error indicating multiple users matching. +// It returns the matched types.User and a nil error if exactly one match is found. +func (u Username) resolveUser(users types.Users) (types.User, error) { + var potentialUsers types.Users + + // At parsetime, we require all usernames to contain an "@" character, if the + // username token does not naturally do so (like email), the user have to + // add it to the end of the username. We strip it here as we do not expect the + // usernames to be stored with the "@". 
+ uTrimmed := strings.TrimSuffix(u.String(), "@") + + for _, user := range users { + if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == uTrimmed { + // Prioritize ProviderIdentifier match and exit early + return user, nil + } + + if user.Email == uTrimmed || user.Name == uTrimmed { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) == 0 { + return types.User{}, fmt.Errorf("user with token %q not found", u.String()) + } + + if len(potentialUsers) > 1 { + return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String()) + } + + return potentialUsers[0], nil +} + +func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + user, err := u.resolveUser(users) + if err != nil { + errs = append(errs, err) + } + + for _, node := range nodes { + if node.IsTagged() { + continue + } + + if node.User.ID == user.ID { + node.AppendToIPSet(&ips) + } + } + + return buildIPSetMultiErr(&ips, errs) +} + +// Group is a special string which is always prefixed with `group:` +type Group string + +func (g Group) Validate() error { + if isGroup(string(g)) { + return nil + } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) +} + +func (g *Group) UnmarshalJSON(b []byte) error { + *g = Group(strings.Trim(string(b), `"`)) + if err := g.Validate(); err != nil { + return err + } + return nil +} + +func (g Group) CanBeTagOwner() bool { + return true +} + +func (g Group) CanBeAutoApprover() bool { + return true +} + +func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + for _, user := range p.Groups[g] { + uips, err := user.Resolve(nil, users, nodes) + if err != nil { + errs = append(errs, err) + } + + ips.AddSet(uips) + } + + return buildIPSetMultiErr(&ips, errs) +} + +// Tag is a special string 
which is always prefixed with `tag:` +type Tag string + +func (t Tag) Validate() error { + if isTag(string(t)) { + return nil + } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) +} + +func (t *Tag) UnmarshalJSON(b []byte) error { + *t = Tag(strings.Trim(string(b), `"`)) + if err := t.Validate(); err != nil { + return err + } + return nil +} + +func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + // TODO(kradalby): This is currently resolved twice, and should be resolved once. + // It is added temporary until we sort out the story on how and when we resolve tags + // from the three places they can be "approved": + // - As part of a PreAuthKey (handled in HasTag) + // - As part of ForcedTags (set via CLI) (handled in HasTag) + // - As part of HostInfo.RequestTags and approved by policy (this is happening here) + // Part of #2417 + tagMap, err := resolveTagOwners(p, users, nodes) + if err != nil { + return nil, err + } + + for _, node := range nodes { + if node.HasTag(string(t)) { + node.AppendToIPSet(&ips) + } + + // TODO(kradalby): remove as part of #2417, see comment above + if tagMap != nil { + if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo != nil { + for _, tag := range node.Hostinfo.RequestTags { + if tag == string(t) { + node.AppendToIPSet(&ips) + } + } + } + } + } + + return ips.IPSet() +} + +func (t Tag) CanBeAutoApprover() bool { + return true +} + +// Host is a string that represents a hostname. 
+type Host string + +func (h Host) Validate() error { + if isHost(string(h)) { + fmt.Errorf("Hostname %q is invalid", h) + } + return nil +} + +func (h *Host) UnmarshalJSON(b []byte) error { + *h = Host(strings.Trim(string(b), `"`)) + if err := h.Validate(); err != nil { + return err + } + return nil +} + +func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + pref, ok := p.Hosts[h] + if !ok { + return nil, fmt.Errorf("unable to resolve host: %q", h) + } + err := pref.Validate() + if err != nil { + errs = append(errs, err) + } + + ips.AddPrefix(netip.Prefix(pref)) + + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + // appendIfNodeHasIP(nodes, &ips, pref) + + // TODO(kradalby): I am a bit unsure what is the correct way to do this, + // should a host with a non single IP be able to resolve the full host (inc all IPs). + ipsTemp, err := ips.IPSet() + if err != nil { + errs = append(errs, err) + } + for _, node := range nodes { + if node.InIPSet(ipsTemp) { + node.AppendToIPSet(&ips) + } + } + + return buildIPSetMultiErr(&ips, errs) +} + +type Prefix netip.Prefix + +func (p Prefix) Validate() error { + if !netip.Prefix(p).IsValid() { + return fmt.Errorf("Prefix %q is invalid", p) + } + + return nil +} + +func (p Prefix) String() string { + return netip.Prefix(p).String() +} + +func (p *Prefix) parseString(addr string) error { + if !strings.Contains(addr, "/") { + addr, err := netip.ParseAddr(addr) + if err != nil { + return err + } + addrPref, err := addr.Prefix(addr.BitLen()) + if err != nil { + return err + } + + *p = Prefix(addrPref) + return nil + } + + pref, err := netip.ParsePrefix(addr) + if err != nil { + return err + } + *p = Prefix(pref) + return nil +} + +func (p *Prefix) UnmarshalJSON(b []byte) error { + err := p.parseString(strings.Trim(string(b), `"`)) + if err != nil { + return err + } + if err := 
p.Validate(); err != nil { + return err + } + return nil +} + +// Resolve resolves the Prefix to an IPSet. The IPSet will contain all the IP +// addresses that the Prefix represents within Headscale. It is the product +// of the Prefix and the Policy, Users, and Nodes. +// +// See [Policy], [types.Users], and [types.Nodes] for more details. +func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + ips.AddPrefix(netip.Prefix(p)) + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + // appendIfNodeHasIP(nodes, &ips, pref) + + // TODO(kradalby): I am a bit unsure what is the correct way to do this, + // should a host with a non single IP be able to resolve the full host (inc all IPs). + // Currently this is done because the old implementation did this, we might want to + // drop it before releasing. + // For example: + // If a src or dst includes "64.0.0.0/2:*", it will include 100.64/16 range, which + // means that it will need to fetch the IPv6 addrs of the node to include the full range. + // Clearly, if a user sets the dst to be "64.0.0.0/2:*", it is likely more of a exit node + // and this would be strange behaviour. 
+	ipsTemp, err := ips.IPSet()
+	if err != nil {
+		errs = append(errs, err)
+	}
+	for _, node := range nodes {
+		if node.InIPSet(ipsTemp) {
+			node.AppendToIPSet(&ips)
+		}
+	}
+
+	return buildIPSetMultiErr(&ips, errs)
+}
+
+// AutoGroup is a special string which is always prefixed with `autogroup:`
+type AutoGroup string
+
+const (
+	AutoGroupInternet = "autogroup:internet"
+)
+
+// autogroups is the set of autogroup names this implementation understands.
+var autogroups = []string{AutoGroupInternet}
+
+// Validate returns an error unless the AutoGroup is one of the known
+// autogroup names in autogroups.
+func (ag AutoGroup) Validate() error {
+	for _, valid := range autogroups {
+		if valid == string(ag) {
+			return nil
+		}
+	}
+
+	return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
+}
+
+func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
+	// NOTE(review): strings.Trim only strips surrounding quotes and does not
+	// undo JSON escape sequences. Harmless for the known autogroup names
+	// (Validate rejects anything else), but json.Unmarshal into a string
+	// would be stricter — confirm this matches the sibling types.
+	*ag = AutoGroup(strings.Trim(string(b), `"`))
+	if err := ag.Validate(); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Resolve returns the IPSet the autogroup represents. Unknown (but otherwise
+// validated) autogroups fall through to (nil, nil); callers must tolerate a
+// nil set.
+func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
+	switch ag {
+	case AutoGroupInternet:
+		return util.TheInternet(), nil
+	}
+
+	return nil, nil
+}
+
+type Alias interface {
+	Validate() error
+	UnmarshalJSON([]byte) error
+
+	// Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP
+	// addresses that the Alias represents within Headscale. It is the product
+	// of the Alias and the Policy, Users and Nodes.
+	// This is an interface definition and the implementation is independent of
+	// the Alias type.
+	Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
+}
+
+// AliasWithPorts is an Alias destination paired with the port ranges it
+// applies to, as written in an ACL "dst" entry ("alias:port").
+type AliasWithPorts struct {
+	Alias
+	Ports []tailcfg.PortRange
+}
+
+// UnmarshalJSON parses a destination string of the form "alias" or
+// "alias:ports" into the Alias and its port ranges.
+func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
+	// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
+	dec := json.NewDecoder(bytes.NewReader(b))
+	var v any
+	if err := dec.Decode(&v); err != nil {
+		return err
+	}
+
+	switch vs := v.(type) {
+	case string:
+		var portsPart string
+		var err error
+
+		if strings.Contains(vs, ":") {
+			vs, portsPart, err = splitDestinationAndPort(vs)
+			if err != nil {
+				return err
+			}
+
+			ports, err := parsePortRange(portsPart)
+			if err != nil {
+				return err
+			}
+			ve.Ports = ports
+		}
+
+		ve.Alias, err = parseAlias(vs)
+		if err != nil {
+			return err
+		}
+		if err := ve.Alias.Validate(); err != nil {
+			return err
+		}
+
+	default:
+		return fmt.Errorf("type %T not supported", vs)
+	}
+	return nil
+}
+
+// Predicates classifying the syntactic shape of a policy string.
+func isWildcard(str string) bool {
+	return str == "*"
+}
+
+func isUser(str string) bool {
+	return strings.Contains(str, "@")
+}
+
+func isGroup(str string) bool {
+	return strings.HasPrefix(str, "group:")
+}
+
+func isTag(str string) bool {
+	return strings.HasPrefix(str, "tag:")
+}
+
+func isAutoGroup(str string) bool {
+	return strings.HasPrefix(str, "autogroup:")
+}
+
+func isHost(str string) bool {
+	return !isUser(str) && !strings.Contains(str, ":")
+}
+
+// parseAlias converts a policy string into the concrete Alias type it
+// denotes: prefix/IP first, then wildcard, user, group, tag, autogroup,
+// and finally host as the fallback.
+func parseAlias(vs string) (Alias, error) {
+	var pref Prefix
+	err := pref.parseString(vs)
+	if err == nil {
+		return &pref, nil
+	}
+
+	switch {
+	case isWildcard(vs):
+		return Wildcard, nil
+	case isUser(vs):
+		return ptr.To(Username(vs)), nil
+	case isGroup(vs):
+		return ptr.To(Group(vs)), nil
+	case isTag(vs):
+		return ptr.To(Tag(vs)), nil
+	case isAutoGroup(vs):
+		return ptr.To(AutoGroup(vs)), nil
+	}
+
+	if isHost(vs) {
+		return ptr.To(Host(vs)), nil
+	}
+
+	return nil, fmt.Errorf(`Invalid alias %q. 
+An alias must be one of the following types:
+- wildcard (*)
+- user (containing an "@")
+- group (starting with "group:")
+- tag (starting with "tag:")
+- autogroup (starting with "autogroup:")
+- host
+
+Please check the format and try again.`, vs)
+}
+
+// AliasEnc is used to deserialize an Alias.
+type AliasEnc struct{ Alias }
+
+func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
+	ptr, err := unmarshalPointer[Alias](
+		b,
+		parseAlias,
+	)
+	if err != nil {
+		return err
+	}
+	ve.Alias = ptr
+	return nil
+}
+
+// Aliases is a list of Alias entries, e.g. the "src" field of an ACL.
+type Aliases []Alias
+
+func (a *Aliases) UnmarshalJSON(b []byte) error {
+	var aliases []AliasEnc
+	err := json.Unmarshal(b, &aliases)
+	if err != nil {
+		return err
+	}
+
+	*a = make([]Alias, len(aliases))
+	for i, alias := range aliases {
+		(*a)[i] = alias.Alias
+	}
+	return nil
+}
+
+// Resolve resolves each Alias and returns the union of their IPSets.
+// Individual resolution failures are collected, not fatal.
+func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
+	var ips netipx.IPSetBuilder
+	var errs []error
+
+	for _, alias := range a {
+		aips, err := alias.Resolve(p, users, nodes)
+		if err != nil {
+			errs = append(errs, err)
+		}
+
+		ips.AddSet(aips)
+	}
+
+	return buildIPSetMultiErr(&ips, errs)
+}
+
+// buildIPSetMultiErr finalizes the builder and folds any collected errors
+// (plus the build error, if any) into a single multierr.
+func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.IPSet, error) {
+	ips, err := ipBuilder.IPSet()
+	return ips, multierr.New(append(errs, err)...)
+}
+
+// unmarshalPointer unmarshals a JSON string and hands it to parseFunc,
+// producing either an AutoApprover or Owner pointer.
+func unmarshalPointer[T any](
+	b []byte,
+	parseFunc func(string) (T, error),
+) (T, error) {
+	var s string
+	err := json.Unmarshal(b, &s)
+	if err != nil {
+		var t T
+		return t, err
+	}
+
+	return parseFunc(s)
+}
+
+// AutoApprover is an entity that may self-approve routes or exit nodes.
+type AutoApprover interface {
+	CanBeAutoApprover() bool
+	UnmarshalJSON([]byte) error
+}
+
+// AutoApprovers is a list of AutoApprover entries.
+type AutoApprovers []AutoApprover
+
+func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
+	var autoApprovers []AutoApproverEnc
+	err := json.Unmarshal(b, &autoApprovers)
+	if err != nil {
+		return err
+	}
+
+	*aa = make([]AutoApprover, len(autoApprovers))
+	for i, autoApprover := range autoApprovers {
+		(*aa)[i] = autoApprover.AutoApprover
+	}
+	return nil
+}
+
+// parseAutoApprover parses an auto-approver entry: user, group or tag.
+func parseAutoApprover(s string) (AutoApprover, error) {
+	switch {
+	case isUser(s):
+		return ptr.To(Username(s)), nil
+	case isGroup(s):
+		return ptr.To(Group(s)), nil
+	case isTag(s):
+		return ptr.To(Tag(s)), nil
+	}
+
+	return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types:
+- user (containing an "@")
+- group (starting with "group:")
+- tag (starting with "tag:")
+
+Please check the format and try again.`, s)
+}
+
+// AutoApproverEnc is used to deserialize an AutoApprover.
+type AutoApproverEnc struct{ AutoApprover }
+
+func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
+	ptr, err := unmarshalPointer[AutoApprover](
+		b,
+		parseAutoApprover,
+	)
+	if err != nil {
+		return err
+	}
+	ve.AutoApprover = ptr
+	return nil
+}
+
+// Owner is an entity that may own a tag (see TagOwners).
+type Owner interface {
+	CanBeTagOwner() bool
+	UnmarshalJSON([]byte) error
+}
+
+// OwnerEnc is used to deserialize an Owner.
+type OwnerEnc struct{ Owner } + +func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { + ptr, err := unmarshalPointer[Owner]( + b, + parseOwner, + ) + if err != nil { + return err + } + ve.Owner = ptr + return nil +} + +type Owners []Owner + +func (o *Owners) UnmarshalJSON(b []byte) error { + var owners []OwnerEnc + err := json.Unmarshal(b, &owners) + if err != nil { + return err + } + + *o = make([]Owner, len(owners)) + for i, owner := range owners { + (*o)[i] = owner.Owner + } + return nil +} + +func parseOwner(s string) (Owner, error) { + switch { + case isUser(s): + return ptr.To(Username(s)), nil + case isGroup(s): + return ptr.To(Group(s)), nil + } + return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: +- user (containing an "@") +- group (starting with "group:") +- tag (starting with "tag:") + +Please check the format and try again.`, s) +} + +type Usernames []Username + +// Groups are a map of Group to a list of Username. +type Groups map[Group]Usernames + +// UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure +// that each group name is validated using the isGroup function. This ensures +// that all group names conform to the expected format, which is always prefixed +// with "group:". If any group name is invalid, an error is returned. 
+func (g *Groups) UnmarshalJSON(b []byte) error { + var rawGroups map[string][]string + if err := json.Unmarshal(b, &rawGroups); err != nil { + return err + } + + *g = make(Groups) + for key, value := range rawGroups { + group := Group(key) + if err := group.Validate(); err != nil { + return err + } + + var usernames Usernames + + for _, u := range value { + username := Username(u) + if err := username.Validate(); err != nil { + if isGroup(u) { + return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group) + } + + return err + } + usernames = append(usernames, username) + } + + (*g)[group] = usernames + } + return nil +} + +// Hosts are alias for IP addresses or subnets. +type Hosts map[Host]Prefix + +func (h *Hosts) UnmarshalJSON(b []byte) error { + var rawHosts map[string]string + if err := json.Unmarshal(b, &rawHosts); err != nil { + return err + } + + *h = make(Hosts) + for key, value := range rawHosts { + host := Host(key) + if err := host.Validate(); err != nil { + return err + } + + var pref Prefix + err := pref.parseString(value) + if err != nil { + return fmt.Errorf("Hostname %q contains an invalid IP address: %q", key, value) + } + + (*h)[host] = pref + } + return nil +} + +// TagOwners are a map of Tag to a list of the UserEntities that own the tag. +type TagOwners map[Tag]Owners + +// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet. +// The resulting map can be used to quickly look up the IPSet for a given Tag. +// It is intended for internal use in a PolicyManager. 
+func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag]*netipx.IPSet, error) {
+	// A nil policy resolves to no tag owners at all.
+	if p == nil {
+		return nil, nil
+	}
+
+	ret := make(map[Tag]*netipx.IPSet)
+
+	for tag, owners := range p.TagOwners {
+		var ips netipx.IPSetBuilder
+
+		for _, owner := range owners {
+			o, ok := owner.(Alias)
+			if !ok {
+				// Should never happen
+				return nil, fmt.Errorf("owner %v is not an Alias", owner)
+			}
+			// If it does not resolve, that means the tag is not associated with any IP addresses.
+			resolved, _ := o.Resolve(p, users, nodes)
+			ips.AddSet(resolved)
+		}
+
+		ipSet, err := ips.IPSet()
+		if err != nil {
+			return nil, err
+		}
+
+		ret[tag] = ipSet
+	}
+
+	return ret, nil
+}
+
+// AutoApproverPolicy describes which entities may self-approve advertised
+// subnet routes and exit-node routes.
+type AutoApproverPolicy struct {
+	Routes   map[netip.Prefix]AutoApprovers `json:"routes"`
+	ExitNode AutoApprovers                  `json:"exitNode"`
+}
+
+// resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet.
+// The resulting map can be used to quickly look up if a node can self-approve a route.
+// It is intended for internal use in a PolicyManager.
+func resolveAutoApprovers(p *Policy, users types.Users, nodes types.Nodes) (map[netip.Prefix]*netipx.IPSet, error) {
+	if p == nil {
+		return nil, nil
+	}
+
+	routes := make(map[netip.Prefix]*netipx.IPSetBuilder)
+
+	for prefix, autoApprovers := range p.AutoApprovers.Routes {
+		// NOTE(review): map keys are unique per range iteration, so this
+		// presence check can never be true twice for the same prefix;
+		// harmless, but likely redundant — confirm.
+		if _, ok := routes[prefix]; !ok {
+			routes[prefix] = new(netipx.IPSetBuilder)
+		}
+		for _, autoApprover := range autoApprovers {
+			aa, ok := autoApprover.(Alias)
+			if !ok {
+				// Should never happen
+				return nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover)
+			}
+			// If it does not resolve, that means the autoApprover is not associated with any IP addresses.
+			ips, _ := aa.Resolve(p, users, nodes)
+			routes[prefix].AddSet(ips)
+		}
+	}
+
+	var exitNodeSetBuilder netipx.IPSetBuilder
+	if len(p.AutoApprovers.ExitNode) > 0 {
+		for _, autoApprover := range p.AutoApprovers.ExitNode {
+			aa, ok := autoApprover.(Alias)
+			if !ok {
+				// Should never happen
+				return nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover)
+			}
+			// If it does not resolve, that means the autoApprover is not associated with any IP addresses.
+			ips, _ := aa.Resolve(p, users, nodes)
+			exitNodeSetBuilder.AddSet(ips)
+		}
+	}
+
+	// Finalize the per-prefix builders into immutable IPSets.
+	ret := make(map[netip.Prefix]*netipx.IPSet)
+	for prefix, builder := range routes {
+		ipSet, err := builder.IPSet()
+		if err != nil {
+			return nil, err
+		}
+		ret[prefix] = ipSet
+	}
+
+	// Exit-node approvers are registered under both the IPv4 and IPv6
+	// catch-all prefixes.
+	if len(p.AutoApprovers.ExitNode) > 0 {
+		exitNodeSet, err := exitNodeSetBuilder.IPSet()
+		if err != nil {
+			return nil, err
+		}
+
+		ret[tsaddr.AllIPv4()] = exitNodeSet
+		ret[tsaddr.AllIPv6()] = exitNodeSet
+	}
+
+	return ret, nil
+}
+
+// ACL is a single access rule: traffic from Sources to Destinations is
+// subject to Action for the given Protocol.
+type ACL struct {
+	Action       string           `json:"action"` // TODO(kradalby): add strict type
+	Protocol     string           `json:"proto"`  // TODO(kradalby): add strict type
+	Sources      Aliases          `json:"src"`
+	Destinations []AliasWithPorts `json:"dst"`
+}
+
+// Policy represents a Tailscale Network Policy.
+// TODO(kradalby):
+// Add validation method checking:
+// All users exists
+// All groups and users are valid tag TagOwners
+// Everything referred to in ACLs exists in other
+// entities.
+type Policy struct {
+	// validated is set if the policy has been validated.
+	// It is not safe to use before it is validated, and
+	// callers using it should panic if not
+	// NOTE(review): encoding/json ignores tags on unexported fields, so the
+	// `json:"-"` here has no effect; kept as documentation of intent only.
+	validated bool `json:"-"`
+
+	Groups        Groups             `json:"groups"`
+	Hosts         Hosts              `json:"hosts"`
+	TagOwners     TagOwners          `json:"tagOwners"`
+	ACLs          []ACL              `json:"acls"`
+	AutoApprovers AutoApproverPolicy `json:"autoApprovers"`
+	SSHs          []SSH              `json:"ssh"`
+}
+
+// SSH controls who can ssh into which machines.
+type SSH struct {
+	Action       string        `json:"action"` // TODO(kradalby): add strict type
+	Sources      SSHSrcAliases `json:"src"`
+	Destinations SSHDstAliases `json:"dst"`
+	Users        []SSHUser     `json:"users"`
+	// NOTE(review): time.Duration unmarshals from a JSON number of
+	// nanoseconds; a human-readable string such as "30m" or "24h" will fail
+	// to decode — confirm the intended wire format for checkPeriod.
+	CheckPeriod  time.Duration `json:"checkPeriod,omitempty"`
+}
+
+// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
+// It can be a list of usernames, groups, tags or autogroups.
+type SSHSrcAliases []Alias
+
+// UnmarshalJSON decodes the aliases and rejects any alias type that is not
+// valid as an SSH source.
+func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
+	var aliases []AliasEnc
+	err := json.Unmarshal(b, &aliases)
+	if err != nil {
+		return err
+	}
+
+	*a = make([]Alias, len(aliases))
+	for i, alias := range aliases {
+		switch alias.Alias.(type) {
+		case *Username, *Group, *Tag, *AutoGroup:
+			(*a)[i] = alias.Alias
+		default:
+			return fmt.Errorf("type %T not supported", alias.Alias)
+		}
+	}
+	return nil
+}
+
+// Resolve resolves each source alias and returns the union of their IPSets,
+// mirroring Aliases.Resolve.
+func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
+	var ips netipx.IPSetBuilder
+	var errs []error
+
+	for _, alias := range a {
+		aips, err := alias.Resolve(p, users, nodes)
+		if err != nil {
+			errs = append(errs, err)
+		}
+
+		ips.AddSet(aips)
+	}
+
+	return buildIPSetMultiErr(&ips, errs)
+}
+
+// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
+// It can be a list of usernames, tags or autogroups.
+type SSHDstAliases []Alias + +func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Tag, *AutoGroup, + // Asterix and Group is actually not supposed to be supported, + // however we do not support autogroups at the moment + // so we will leave it in as there is no other option + // to dynamically give all access + // https://tailscale.com/kb/1193/tailscale-ssh#dst + Asterix, + *Group: + (*a)[i] = alias.Alias + default: + return fmt.Errorf("type %T not supported", alias.Alias) + } + } + return nil +} + +type SSHUser string + +func (u SSHUser) String() string { + return string(u) +} + +func policyFromBytes(b []byte) (*Policy, error) { + if b == nil || len(b) == 0 { + return nil, nil + } + + var policy Policy + ast, err := hujson.Parse(b) + if err != nil { + return nil, fmt.Errorf("parsing HuJSON: %w", err) + } + + ast.Standardize() + acl := ast.Pack() + + err = json.Unmarshal(acl, &policy) + if err != nil { + return nil, fmt.Errorf("parsing policy from bytes: %w", err) + } + + return &policy, nil +} + +const ( + expectedTokenItems = 2 +) diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go new file mode 100644 index 00000000..2218685e --- /dev/null +++ b/hscontrol/policy/v2/types_test.go @@ -0,0 +1,1162 @@ +package v2 + +import ( + "encoding/json" + "net/netip" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/require" + "go4.org/netipx" + xmaps "golang.org/x/exp/maps" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +func TestUnmarshalPolicy(t *testing.T) { + tests := []struct { + name 
 string
+		input   string
+		want    *Policy
+		wantErr string
+	}{
+		{
+			name:  "empty",
+			input: "{}",
+			want:  &Policy{},
+		},
+		{
+			name: "groups",
+			input: `
+{
+	"groups": {
+		"group:example": [
+			"derp@headscale.net",
+		],
+	},
+}
+`,
+			want: &Policy{
+				Groups: Groups{
+					Group("group:example"): []Username{Username("derp@headscale.net")},
+				},
+			},
+		},
+		{
+			name: "basic-types",
+			input: `
+{
+	"groups": {
+		"group:example": [
+			"testuser@headscale.net",
+		],
+		"group:other": [
+			"otheruser@headscale.net",
+		],
+		"group:noat": [
+			"noat@",
+		],
+	},
+
+	"tagOwners": {
+		"tag:user": ["testuser@headscale.net"],
+		"tag:group": ["group:other"],
+		"tag:userandgroup": ["testuser@headscale.net", "group:other"],
+	},
+
+	"hosts": {
+		"host-1": "100.100.100.100",
+		"subnet-1": "100.100.101.100/24",
+		"outside": "192.168.0.0/16",
+	},
+
+	"acls": [
+		// All
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["*"],
+			"dst": ["*:*"],
+		},
+		// Users
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["testuser@headscale.net"],
+			"dst": ["otheruser@headscale.net:80"],
+		},
+		// Groups
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["group:example"],
+			"dst": ["group:other:80"],
+		},
+		// Tailscale IP
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["100.101.102.103"],
+			"dst": ["100.101.102.104:80"],
+		},
+		// Subnet
+		{
+			"action": "accept",
+			"proto": "udp",
+			"src": ["10.0.0.0/8"],
+			"dst": ["172.16.0.0/16:80"],
+		},
+		// Hosts
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["subnet-1"],
+			"dst": ["host-1:80-88"],
+		},
+		// Tags
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["tag:group"],
+			"dst": ["tag:user:80,443"],
+		},
+		// Autogroup
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["tag:group"],
+			"dst": ["autogroup:internet:80"],
+		},
+	],
+}
+`,
+			want: &Policy{
+				Groups: Groups{
+					Group("group:example"): []Username{Username("testuser@headscale.net")},
+					Group("group:other"):   []Username{Username("otheruser@headscale.net")},
+					Group("group:noat"):    []Username{Username("noat@")},
+				},
+				TagOwners: TagOwners{
+					Tag("tag:user"):         Owners{up("testuser@headscale.net")},
+					Tag("tag:group"):        Owners{gp("group:other")},
+					Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")},
+				},
+				Hosts: Hosts{
+					"host-1":   Prefix(mp("100.100.100.100/32")),
+					"subnet-1": Prefix(mp("100.100.101.100/24")),
+					"outside":  Prefix(mp("192.168.0.0/16")),
+				},
+				ACLs: []ACL{
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							Wildcard,
+						},
+						Destinations: []AliasWithPorts{
+							{
+								// TODO(kradalby): Should this be host?
+								// It is:
+								// Includes any destination (no restrictions).
+								Alias: Wildcard,
+								Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							ptr.To(Username("testuser@headscale.net")),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: ptr.To(Username("otheruser@headscale.net")),
+								Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							gp("group:example"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: gp("group:other"),
+								Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							pp("100.101.102.103/32"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: pp("100.101.102.104/32"),
+								Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "udp",
+						Sources: Aliases{
+							pp("10.0.0.0/8"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: pp("172.16.0.0/16"),
+								Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							hp("subnet-1"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: hp("host-1"),
+								Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							tp("tag:group"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: tp("tag:user"),
+								Ports: []tailcfg.PortRange{
+									{First: 80, Last: 80},
+									{First: 443, Last: 443},
+								},
+							},
+						},
+					},
+					{
+						Action:   "accept",
+						Protocol: "tcp",
+						Sources: Aliases{
+							tp("tag:group"),
+						},
+						Destinations: []AliasWithPorts{
+							{
+								Alias: agp("autogroup:internet"),
+								Ports: []tailcfg.PortRange{
+									{First: 80, Last: 80},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "invalid-username",
+			input: `
+{
+	"groups": {
+		"group:example": [
+			"valid@",
+			"invalid",
+		],
+	},
+}
+`,
+			wantErr: `Username has to contain @, got: "invalid"`,
+		},
+		{
+			name: "invalid-group",
+			input: `
+{
+	"groups": {
+		"grou:example": [
+			"valid@",
+		],
+	},
+}
+`,
+			wantErr: `Group has to start with "group:", got: "grou:example"`,
+		},
+		{
+			name: "group-in-group",
+			input: `
+{
+	"groups": {
+		"group:inner": [],
+		"group:example": [
+			"group:inner",
+		],
+	},
+}
+`,
+			// wantErr: `Username has to contain @, got: "group:inner"`,
+			wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`,
+		},
+		{
+			name: "invalid-addr",
+			input: `
+{
+	"hosts": {
+		"derp": "10.0",
+	},
+}
+`,
+			wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`,
+		},
+		{
+			name: "invalid-prefix",
+			input: `
+{
+	"hosts": {
+		"derp": "10.0/42",
+	},
+}
+`,
+			wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`,
+		},
+		// TODO(kradalby): Figure out why this doesn't work.
+		// {
+		// 	name: "invalid-hostname",
+		// 	input: `
+		// {
+		// 	"hosts": {
+		// 		"derp:merp": "10.0.0.0/31",
+		// 	},
+		// }
+		// `,
+		// 	wantErr: `Hostname "derp:merp" is invalid`,
+		// },
+		{
+			name: "invalid-auto-group",
+			input: `
+{
+	"acls": [
+		// Autogroup
+		{
+			"action": "accept",
+			"proto": "tcp",
+			"src": ["tag:group"],
+			"dst": ["autogroup:invalid:80"],
+		},
+	],
+}
+`,
+			wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`,
+		},
+	}
+
+	cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
+		return x == y
+	}))
+	cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{}))
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			policy, err := policyFromBytes([]byte(tt.input))
+			if tt.wantErr == "" {
+				if err != nil {
+					t.Fatalf("got %v; want no error", err)
+				}
+			} else {
+				if err == nil {
+					t.Fatalf("got nil; want error %q", tt.wantErr)
+				} else if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Fatalf("got err %v; want error %q", err, tt.wantErr)
+				}
+			}
+
+			if diff := cmp.Diff(tt.want, policy, cmps...); diff != "" {
+				t.Fatalf("unexpected policy (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// Short constructors used to build expected values in the tests above.
+func gp(s string) *Group          { return ptr.To(Group(s)) }
+func up(s string) *Username       { return ptr.To(Username(s)) }
+func hp(s string) *Host           { return ptr.To(Host(s)) }
+func tp(s string) *Tag            { return ptr.To(Tag(s)) }
+func agp(s string) *AutoGroup     { return ptr.To(AutoGroup(s)) }
+func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) }
+func ap(addr string) *netip.Addr  { return ptr.To(netip.MustParseAddr(addr)) }
+func pp(pref string) *Prefix      { return ptr.To(Prefix(mp(pref))) }
+func p(pref string) Prefix        { return Prefix(mp(pref)) }
+
+func TestResolvePolicy(t *testing.T) {
+	users := map[string]types.User{
+		"testuser":   {Model: gorm.Model{ID: 1}, Name: "testuser"},
+		"groupuser":  {Model: gorm.Model{ID: 2}, Name: "groupuser"},
+		"groupuser1": {Model: gorm.Model{ID: 3}, Name: 
"groupuser1"},
+		"groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"},
+		"notme":      {Model: gorm.Model{ID: 5}, Name: "notme"},
+	}
+	tests := []struct {
+		name      string
+		nodes     types.Nodes
+		pol       *Policy
+		toResolve Alias
+		want      []netip.Prefix
+		wantErr   string
+	}{
+		{
+			name:      "prefix",
+			toResolve: pp("100.100.101.101/32"),
+			want:      []netip.Prefix{mp("100.100.101.101/32")},
+		},
+		{
+			name: "host",
+			pol: &Policy{
+				Hosts: Hosts{
+					"testhost": p("100.100.101.102/32"),
+				},
+			},
+			toResolve: hp("testhost"),
+			want:      []netip.Prefix{mp("100.100.101.102/32")},
+		},
+		{
+			name:      "username",
+			toResolve: ptr.To(Username("testuser@")),
+			nodes: types.Nodes{
+				// Not matching other user
+				{
+					User: users["notme"],
+					IPv4: ap("100.100.101.1"),
+				},
+				// Not matching forced tags
+				{
+					User:       users["testuser"],
+					ForcedTags: []string{"tag:anything"},
+					IPv4:       ap("100.100.101.2"),
+				},
+				// Not matching PAK tag
+				{
+					User: users["testuser"],
+					AuthKey: &types.PreAuthKey{
+						Tags: []string{"alsotagged"},
+					},
+					IPv4: ap("100.100.101.3"),
+				},
+				{
+					User: users["testuser"],
+					IPv4: ap("100.100.101.103"),
+				},
+				{
+					User: users["testuser"],
+					IPv4: ap("100.100.101.104"),
+				},
+			},
+			want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")},
+		},
+		{
+			name:      "group",
+			toResolve: ptr.To(Group("group:testgroup")),
+			nodes: types.Nodes{
+				// Not matching other user
+				{
+					User: users["notme"],
+					IPv4: ap("100.100.101.4"),
+				},
+				// Not matching forced tags
+				{
+					User:       users["groupuser"],
+					ForcedTags: []string{"tag:anything"},
+					IPv4:       ap("100.100.101.5"),
+				},
+				// Not matching PAK tag
+				{
+					User: users["groupuser"],
+					AuthKey: &types.PreAuthKey{
+						Tags: []string{"tag:alsotagged"},
+					},
+					IPv4: ap("100.100.101.6"),
+				},
+				{
+					User: users["groupuser"],
+					IPv4: ap("100.100.101.203"),
+				},
+				{
+					User: users["groupuser"],
+					IPv4: ap("100.100.101.204"),
+				},
+			},
+			pol: &Policy{
+				Groups: Groups{
+					"group:testgroup":  Usernames{"groupuser"},
+					"group:othergroup": Usernames{"notmetoo"},
+				},
+			},
+			want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")},
+		},
+		{
+			name:      "tag",
+			toResolve: tp("tag:test"),
+			nodes: types.Nodes{
+				// Not matching other user
+				{
+					User: users["notme"],
+					IPv4: ap("100.100.101.9"),
+				},
+				// Not matching forced tags
+				{
+					ForcedTags: []string{"tag:anything"},
+					IPv4:       ap("100.100.101.10"),
+				},
+				// Not matching PAK tag
+				{
+					AuthKey: &types.PreAuthKey{
+						Tags: []string{"tag:alsotagged"},
+					},
+					IPv4: ap("100.100.101.11"),
+				},
+				// Matching forced tag
+				{
+					ForcedTags: []string{"tag:test"},
+					IPv4:       ap("100.100.101.234"),
+				},
+				// Matching PAK tag
+				{
+					AuthKey: &types.PreAuthKey{
+						Tags: []string{"tag:test"},
+					},
+					IPv4: ap("100.100.101.239"),
+				},
+			},
+			// TODO(kradalby): tests handling TagOwners + hostinfo
+			pol:  &Policy{},
+			want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")},
+		},
+		{
+			name:      "empty-policy",
+			toResolve: pp("100.100.101.101/32"),
+			pol:       &Policy{},
+			want:      []netip.Prefix{mp("100.100.101.101/32")},
+		},
+		{
+			name:      "invalid-host",
+			toResolve: hp("invalidhost"),
+			pol: &Policy{
+				Hosts: Hosts{
+					"testhost": p("100.100.101.102/32"),
+				},
+			},
+			wantErr: `unable to resolve host: "invalidhost"`,
+		},
+		{
+			name:      "multiple-groups",
+			toResolve: ptr.To(Group("group:testgroup")),
+			nodes: types.Nodes{
+				{
+					User: users["groupuser1"],
+					IPv4: ap("100.100.101.203"),
+				},
+				{
+					User: users["groupuser2"],
+					IPv4: ap("100.100.101.204"),
+				},
+			},
+			pol: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"groupuser1@", "groupuser2@"},
+				},
+			},
+			want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")},
+		},
+		{
+			name:      "autogroup-internet",
+			toResolve: agp("autogroup:internet"),
+			want:      util.TheInternet().Prefixes(),
+		},
+		{
+			name:      "invalid-username",
+			toResolve: ptr.To(Username("invaliduser@")),
+			nodes: types.Nodes{
+				{
+					User: users["testuser"],
+					IPv4: ap("100.100.101.103"),
+				},
+			},
+			wantErr: `user with 
token "invaliduser@" not found`,
+		},
+		{
+			name:      "invalid-tag",
+			toResolve: tp("tag:invalid"),
+			nodes: types.Nodes{
+				{
+					ForcedTags: []string{"tag:test"},
+					IPv4:       ap("100.100.101.234"),
+				},
+			},
+		},
+		{
+			name:      "ipv6-address",
+			toResolve: pp("fd7a:115c:a1e0::1/128"),
+			want:      []netip.Prefix{mp("fd7a:115c:a1e0::1/128")},
+		},
+		{
+			name:      "wildcard-alias",
+			toResolve: Wildcard,
+			want:      []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ips, err := tt.toResolve.Resolve(tt.pol,
+				xmaps.Values(users),
+				tt.nodes)
+			if tt.wantErr == "" {
+				if err != nil {
+					t.Fatalf("got %v; want no error", err)
+				}
+			} else {
+				if err == nil {
+					t.Fatalf("got nil; want error %q", tt.wantErr)
+				} else if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Fatalf("got err %v; want error %q", err, tt.wantErr)
+				}
+			}
+
+			// Normalize an empty/nil IPSet to a nil prefix slice before diffing.
+			var prefs []netip.Prefix
+			if ips != nil {
+				if p := ips.Prefixes(); len(p) > 0 {
+					prefs = p
+				}
+			}
+
+			if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" {
+				t.Fatalf("unexpected prefs (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestResolveAutoApprovers(t *testing.T) {
+	users := types.Users{
+		{Model: gorm.Model{ID: 1}, Name: "user1"},
+		{Model: gorm.Model{ID: 2}, Name: "user2"},
+		{Model: gorm.Model{ID: 3}, Name: "user3"},
+	}
+
+	nodes := types.Nodes{
+		{
+			IPv4: ap("100.64.0.1"),
+			User: users[0],
+		},
+		{
+			IPv4: ap("100.64.0.2"),
+			User: users[1],
+		},
+		{
+			IPv4: ap("100.64.0.3"),
+			User: users[2],
+		},
+		{
+			IPv4:       ap("100.64.0.4"),
+			ForcedTags: []string{"tag:testtag"},
+		},
+		{
+			IPv4:       ap("100.64.0.5"),
+			ForcedTags: []string{"tag:exittest"},
+		},
+	}
+
+	tests := []struct {
+		name    string
+		policy  *Policy
+		want    map[netip.Prefix]*netipx.IPSet
+		wantErr bool
+	}{
+		{
+			name: "single-route",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Username("user1@"))},
+					},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "multiple-routes",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Username("user1@"))},
+						mp("10.0.1.0/24"): {ptr.To(Username("user2@"))},
+					},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32"),
+				mp("10.0.1.0/24"): mustIPSet("100.64.0.2/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "exit-node",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					ExitNode: AutoApprovers{ptr.To(Username("user1@"))},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				tsaddr.AllIPv4(): mustIPSet("100.64.0.1/32"),
+				tsaddr.AllIPv6(): mustIPSet("100.64.0.1/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "group-route",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1@", "user2@"},
+				},
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))},
+					},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "tag-route-and-exit",
+			policy: &Policy{
+				TagOwners: TagOwners{
+					"tag:testtag": Owners{
+						ptr.To(Username("user1@")),
+						ptr.To(Username("user2@")),
+					},
+					"tag:exittest": Owners{
+						ptr.To(Group("group:exitgroup")),
+					},
+				},
+				Groups: Groups{
+					"group:exitgroup": Usernames{"user2@"},
+				},
+				AutoApprovers: AutoApproverPolicy{
+					ExitNode: AutoApprovers{ptr.To(Tag("tag:exittest"))},
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.1.0/24"): {ptr.To(Tag("tag:testtag"))},
+					},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				mp("10.0.1.0/24"): mustIPSet("100.64.0.4/32"),
+				tsaddr.AllIPv4():  mustIPSet("100.64.0.5/32"),
+				tsaddr.AllIPv6():  mustIPSet("100.64.0.5/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "mixed-routes-and-exit-nodes",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1", "user2"},
+				},
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))},
+						mp("10.0.1.0/24"): {ptr.To(Username("user3@"))},
+					},
+					ExitNode: AutoApprovers{ptr.To(Username("user1@"))},
+				},
+			},
+			want: map[netip.Prefix]*netipx.IPSet{
+				mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"),
+				mp("10.0.1.0/24"): mustIPSet("100.64.0.3/32"),
+				tsaddr.AllIPv4():  mustIPSet("100.64.0.1/32"),
+				tsaddr.AllIPv6():  mustIPSet("100.64.0.1/32"),
+			},
+			wantErr: false,
+		},
+	}
+
+	cmps := append(util.Comparers, cmp.Comparer(ipSetComparer))
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := resolveAutoApprovers(tt.policy, users, nodes)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
+				t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// mustIPSet builds an IPSet from prefix strings. The build error is
+// deliberately discarded; the inputs are fixed test constants.
+func mustIPSet(prefixes ...string) *netipx.IPSet {
+	var builder netipx.IPSetBuilder
+	for _, p := range prefixes {
+		builder.AddPrefix(mp(p))
+	}
+	ipSet, _ := builder.IPSet()
+	return ipSet
+}
+
+// ipSetComparer reports whether two IPSets cover the same prefixes (nil-safe).
+func ipSetComparer(x, y *netipx.IPSet) bool {
+	if x == nil || y == nil {
+		return x == y
+	}
+	return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...)
+}
+
+// TestNodeCanApproveRoute exercises the full PolicyManager path: each policy
+// is marshalled to JSON and loaded through NewPolicyManager before the check.
+func TestNodeCanApproveRoute(t *testing.T) {
+	users := types.Users{
+		{Model: gorm.Model{ID: 1}, Name: "user1"},
+		{Model: gorm.Model{ID: 2}, Name: "user2"},
+		{Model: gorm.Model{ID: 3}, Name: "user3"},
+	}
+
+	nodes := types.Nodes{
+		{
+			IPv4: ap("100.64.0.1"),
+			User: users[0],
+		},
+		{
+			IPv4: ap("100.64.0.2"),
+			User: users[1],
+		},
+		{
+			IPv4: ap("100.64.0.3"),
+			User: users[2],
+		},
+	}
+
+	tests := []struct {
+		name    string
+		policy  *Policy
+		node    *types.Node
+		route   netip.Prefix
+		want    bool
+		wantErr bool
+	}{
+		{
+			name: "single-route-approval",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Username("user1@"))},
+					},
+				},
+			},
+			node:  nodes[0],
+			route: mp("10.0.0.0/24"),
+			want:  true,
+		},
+		{
+			name: "multiple-routes-approval",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Username("user1@"))},
+						mp("10.0.1.0/24"): {ptr.To(Username("user2@"))},
+					},
+				},
+			},
+			node:  nodes[1],
+			route: mp("10.0.1.0/24"),
+			want:  true,
+		},
+		{
+			name: "exit-node-approval",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					ExitNode: AutoApprovers{ptr.To(Username("user1@"))},
+				},
+			},
+			node:  nodes[0],
+			route: tsaddr.AllIPv4(),
+			want:  true,
+		},
+		{
+			name: "group-route-approval",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1@", "user2@"},
+				},
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))},
+					},
+				},
+			},
+			node:  nodes[1],
+			route: mp("10.0.0.0/24"),
+			want:  true,
+		},
+		{
+			name: "mixed-routes-and-exit-nodes-approval",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1@", "user2@"},
+				},
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))},
+						mp("10.0.1.0/24"): {ptr.To(Username("user3@"))},
+					},
+					ExitNode: AutoApprovers{ptr.To(Username("user1@"))},
+				},
+			},
+			node:  nodes[0],
+			route: tsaddr.AllIPv4(),
+			want:  true,
+		},
+		{
+			name: "no-approval",
+			policy: &Policy{
+				AutoApprovers: AutoApproverPolicy{
+					Routes: map[netip.Prefix]AutoApprovers{
+						mp("10.0.0.0/24"): {ptr.To(Username("user2@"))},
+					},
+				},
+			},
+			node:  nodes[0],
+			route: mp("10.0.0.0/24"),
+			want:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			b, err := json.Marshal(tt.policy)
+			require.NoError(t, err)
+
+			pm, err := NewPolicyManager(b, users, nodes)
+			require.NoErrorf(t, err, "NewPolicyManager() error = %v", err)
+
+			got := pm.NodeCanApproveRoute(tt.node, tt.route)
+			if got != tt.want {
+				t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestResolveTagOwners(t *testing.T) {
+	users := types.Users{
+		{Model: gorm.Model{ID: 1}, Name: "user1"},
+		{Model: gorm.Model{ID: 2}, Name: "user2"},
+		{Model: gorm.Model{ID: 3}, Name: "user3"},
+	}
+
+	nodes := types.Nodes{
+		{
+			IPv4: ap("100.64.0.1"),
+			User: users[0],
+		},
+		{
+			IPv4: ap("100.64.0.2"),
+			User: users[1],
+		},
+		{
+			IPv4: ap("100.64.0.3"),
+			User: users[2],
+		},
+	}
+
+	tests := []struct {
+		name    string
+		policy  *Policy
+		want    map[Tag]*netipx.IPSet
+		wantErr bool
+	}{
+		{
+			name: "single-tag-owner",
+			policy: &Policy{
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Username("user1@"))},
+				},
+			},
+			want: map[Tag]*netipx.IPSet{
+				Tag("tag:test"): mustIPSet("100.64.0.1/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "multiple-tag-owners",
+			policy: &Policy{
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))},
+				},
+			},
+			want: map[Tag]*netipx.IPSet{
+				Tag("tag:test"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"),
+			},
+			wantErr: false,
+		},
+		{
+			name: "group-tag-owner",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1@", "user2@"},
+				},
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))},
+				},
+			},
+			want: map[Tag]*netipx.IPSet{
+				Tag("tag:test"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"),
+			},
+			wantErr: false,
+		},
+	}
+
+	cmps := append(util.Comparers, cmp.Comparer(ipSetComparer))
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := resolveTagOwners(tt.policy, users, nodes)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
+				t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestNodeCanHaveTag(t *testing.T) {
+	users := types.Users{
+		{Model: gorm.Model{ID: 1}, Name: "user1"},
+		{Model: gorm.Model{ID: 2}, Name: "user2"},
+		{Model: gorm.Model{ID: 3}, Name: "user3"},
+	}
+
+	nodes := types.Nodes{
+		{
+			IPv4: ap("100.64.0.1"),
+			User: users[0],
+		},
+		{
+			IPv4: ap("100.64.0.2"),
+			User: users[1],
+		},
+		{
+			IPv4: ap("100.64.0.3"),
+			User: users[2],
+		},
+	}
+
+	tests := []struct {
+		name    string
+		policy  *Policy
+		node    *types.Node
+		tag     string
+		want    bool
+		wantErr string
+	}{
+		{
+			name: "single-tag-owner",
+			policy: &Policy{
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Username("user1@"))},
+				},
+			},
+			node: nodes[0],
+			tag:  "tag:test",
+			want: true,
+		},
+		{
+			name: "multiple-tag-owners",
+			policy: &Policy{
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))},
+				},
+			},
+			node: nodes[1],
+			tag:  "tag:test",
+			want: true,
+		},
+		{
+			name: "group-tag-owner",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"user1@", "user2@"},
+				},
+				TagOwners: TagOwners{
+					Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))},
+				},
+			},
+			node: nodes[1],
+			tag:  "tag:test",
+			want: true,
+		},
+		{
+			name: "invalid-group",
+			policy: &Policy{
+				Groups: Groups{
+					"group:testgroup": Usernames{"invalid"},
+				},
+ TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + }, + }, + node: nodes[0], + tag: "tag:test", + want: false, + wantErr: "Username has to contain @", + }, + { + name: "node-cannot-have-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: nodes[0], + tag: "tag:test", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := json.Marshal(tt.policy) + require.NoError(t, err) + + pm, err := NewPolicyManager(b, users, nodes) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + + got := pm.NodeCanHaveTag(tt.node, tt.tag) + if got != tt.want { + t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go new file mode 100644 index 00000000..9c962af8 --- /dev/null +++ b/hscontrol/policy/v2/utils.go @@ -0,0 +1,164 @@ +package v2 + +import ( + "errors" + "fmt" + "slices" + "strconv" + "strings" + + "tailscale.com/tailcfg" +) + +// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. 
+func splitDestinationAndPort(input string) (string, string, error) { + // Find the last occurrence of the colon character + lastColonIndex := strings.LastIndex(input, ":") + + // Check if the colon character is present and not at the beginning or end of the string + if lastColonIndex == -1 { + return "", "", errors.New("input must contain a colon character separating destination and port") + } + if lastColonIndex == 0 { + return "", "", errors.New("input cannot start with a colon character") + } + if lastColonIndex == len(input)-1 { + return "", "", errors.New("input cannot end with a colon character") + } + + // Split the string into destination and port based on the last colon + destination := input[:lastColonIndex] + port := input[lastColonIndex+1:] + + return destination, port, nil +} + +// parsePortRange parses a port definition string and returns a slice of PortRange structs. +func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { + if portDef == "*" { + return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil + } + + var portRanges []tailcfg.PortRange + parts := strings.Split(portDef, ",") + + for _, part := range parts { + if strings.Contains(part, "-") { + rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { + return e == "" + }) + if len(rangeParts) != 2 { + return nil, errors.New("invalid port range format") + } + + first, err := parsePort(rangeParts[0]) + if err != nil { + return nil, err + } + + last, err := parsePort(rangeParts[1]) + if err != nil { + return nil, err + } + + if first > last { + return nil, errors.New("invalid port range: first port is greater than last port") + } + + portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) + } else { + port, err := parsePort(part) + if err != nil { + return nil, err + } + + portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) + } + } + + return portRanges, nil +} + +// parsePort parses a single 
port number from a string. +func parsePort(portStr string) (uint16, error) { + port, err := strconv.Atoi(portStr) + if err != nil { + return 0, errors.New("invalid port number") + } + + if port < 0 || port > 65535 { + return 0, errors.New("port number out of range") + } + + return uint16(port), nil +} + +// For some reason golang.org/x/net/internal/iana is an internal package. +const ( + protocolICMP = 1 // Internet Control Message + protocolIGMP = 2 // Internet Group Management + protocolIPv4 = 4 // IPv4 encapsulation + protocolTCP = 6 // Transmission Control + protocolEGP = 8 // Exterior Gateway Protocol + protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) + protocolUDP = 17 // User Datagram + protocolGRE = 47 // Generic Routing Encapsulation + protocolESP = 50 // Encap Security Payload + protocolAH = 51 // Authentication Header + protocolIPv6ICMP = 58 // ICMP for IPv6 + protocolSCTP = 132 // Stream Control Transmission Protocol + ProtocolFC = 133 // Fibre Channel +) + +// parseProtocol reads the proto field of the ACL and generates a list of +// protocols that will be allowed, following the IANA IP protocol number +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// +// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, +// as per Tailscale behaviour (see tailcfg.FilterRule). +// +// Also returns a boolean indicating if the protocol +// requires all the destinations to use wildcard as port number (only TCP, +// UDP and SCTP support specifying ports). 
+func parseProtocol(protocol string) ([]int, bool, error) { + switch protocol { + case "": + return nil, false, nil + case "igmp": + return []int{protocolIGMP}, true, nil + case "ipv4", "ip-in-ip": + return []int{protocolIPv4}, true, nil + case "tcp": + return []int{protocolTCP}, false, nil + case "egp": + return []int{protocolEGP}, true, nil + case "igp": + return []int{protocolIGP}, true, nil + case "udp": + return []int{protocolUDP}, false, nil + case "gre": + return []int{protocolGRE}, true, nil + case "esp": + return []int{protocolESP}, true, nil + case "ah": + return []int{protocolAH}, true, nil + case "sctp": + return []int{protocolSCTP}, false, nil + case "icmp": + return []int{protocolICMP, protocolIPv6ICMP}, true, nil + + default: + protocolNumber, err := strconv.Atoi(protocol) + if err != nil { + return nil, false, fmt.Errorf("parsing protocol number: %w", err) + } + + // TODO(kradalby): What is this? + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard, nil + } +} diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go new file mode 100644 index 00000000..d1645071 --- /dev/null +++ b/hscontrol/policy/v2/utils_test.go @@ -0,0 +1,102 @@ +package v2 + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" +) + +// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests. 
+func TestParseDestinationAndPort(t *testing.T) { + testCases := []struct { + input string + expectedDst string + expectedPort string + expectedErr error + }{ + {"git-server:*", "git-server", "*", nil}, + {"192.168.1.0/24:22", "192.168.1.0/24", "22", nil}, + {"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil}, + {"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil}, + {"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil}, + {"tag:api-server:443", "tag:api-server", "443", nil}, + {"example-host-1:*", "example-host-1", "*", nil}, + {"hostname:80-90", "hostname", "80-90", nil}, + {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, + {":invalid", "", "", errors.New("input cannot start with a colon character")}, + {"invalid:", "", "", errors.New("input cannot end with a colon character")}, + } + + for _, testCase := range testCases { + dst, port, err := splitDestinationAndPort(testCase.input) + if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) { + t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)", + testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr) + } + } +} + +func TestParsePort(t *testing.T) { + tests := []struct { + input string + expected uint16 + err string + }{ + {"80", 80, ""}, + {"0", 0, ""}, + {"65535", 65535, ""}, + {"-1", 0, "port number out of range"}, + {"65536", 0, "port number out of range"}, + {"abc", 0, "invalid port number"}, + {"", 0, "invalid port number"}, + } + + for _, test := range tests { + result, err := parsePort(test.input) + if err != nil && err.Error() != test.err { + t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) + } + if err == nil && test.err != "" { + t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) + } + if result != test.expected 
{ + t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) + } + } +} + +func TestParsePortRange(t *testing.T) { + tests := []struct { + input string + expected []tailcfg.PortRange + err string + }{ + {"80", []tailcfg.PortRange{{80, 80}}, ""}, + {"80-90", []tailcfg.PortRange{{80, 90}}, ""}, + {"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""}, + {"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""}, + {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, + {"80-", nil, "invalid port range format"}, + {"-90", nil, "invalid port range format"}, + {"80-90,", nil, "invalid port number"}, + {"80,90-", nil, "invalid port range format"}, + {"80-90,abc", nil, "invalid port number"}, + {"80-90,65536", nil, "port number out of range"}, + {"80-90,90-80", nil, "invalid port range: first port is greater than last port"}, + } + + for _, test := range tests { + result, err := parsePortRange(test.input) + if err != nil && err.Error() != test.err { + t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) + } + if err == nil && test.err != "" { + t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err) + } + if diff := cmp.Diff(result, test.expected); diff != "" { + t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) + } + } +}