diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index c03499ee..b52cb4dc 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -16,7 +16,7 @@ type Match struct { dests *netipx.IPSet } -func (m Match) DebugString() string { +func (m *Match) DebugString() string { var sb strings.Builder sb.WriteString("Match:\n") @@ -101,7 +101,7 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { // cased for exit nodes. // This checks if dests is a superset of TheInternet(), which handles // merged filter rules where TheInternet is combined with other destinations. -func (m Match) DestsIsTheInternet() bool { +func (m *Match) DestsIsTheInternet() bool { if m.dests.ContainsPrefix(tsaddr.AllIPv4()) || m.dests.ContainsPrefix(tsaddr.AllIPv6()) { return true diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 9c4ef46f..44490f29 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -1150,7 +1150,8 @@ func TestAutogroupTagged(t *testing.T) { require.NoError(t, err) // Verify autogroup:tagged includes all tagged nodes - taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice()) + ag := AutoGroupTagged + taggedIPs, err := ag.Resolve(policy, users, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, taggedIPs) diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index d161138b..c8b4f4e5 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -185,12 +185,12 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV // Username is a string that represents a username, it must contain an @. type Username string -func (u Username) Validate() error { - if isUser(string(u)) { +func (u *Username) Validate() error { + if isUser(string(*u)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidUsername, u) + return fmt.Errorf("%w, got: %q", ErrInvalidUsername, *u) } func (u *Username) String() string { @@ -198,12 +198,12 @@ func (u *Username) String() string { } // MarshalJSON marshals the Username to JSON. -func (u Username) MarshalJSON() ([]byte, error) { - return json.Marshal(string(u)) +func (u *Username) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*u)) } // MarshalJSON marshals the Prefix to JSON. -func (p Prefix) MarshalJSON() ([]byte, error) { +func (p *Prefix) MarshalJSON() ([]byte, error) { return json.Marshal(p.String()) } @@ -218,11 +218,11 @@ func (u *Username) UnmarshalJSON(b []byte) error { return nil } -func (u Username) CanBeTagOwner() bool { +func (u *Username) CanBeTagOwner() bool { return true } -func (u Username) CanBeAutoApprover() bool { +func (u *Username) CanBeAutoApprover() bool { return true } @@ -231,7 +231,7 @@ func (u Username) CanBeAutoApprover() bool { // If no matching user is found, it returns an error indicating no user matching. // If multiple matching users are found, it returns an error indicating multiple users matching. // It returns the matched types.User and a nil error if exactly one match is found. -func (u Username) resolveUser(users types.Users) (types.User, error) { +func (u *Username) resolveUser(users types.Users) (types.User, error) { var potentialUsers types.Users // At parsetime, we require all usernames to contain an "@" character, if the @@ -262,7 +262,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { return potentialUsers[0], nil } -func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (u *Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error @@ -295,12 +295,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. // Group is a special string which is always prefixed with `group:`. type Group string -func (g Group) Validate() error { - if isGroup(string(g)) { +func (g *Group) Validate() error { + if isGroup(string(*g)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidGroupFormat, g) + return fmt.Errorf("%w, got: %q", ErrInvalidGroupFormat, *g) } func (g *Group) UnmarshalJSON(b []byte) error { @@ -314,40 +314,40 @@ func (g *Group) UnmarshalJSON(b []byte) error { return nil } -func (g Group) CanBeTagOwner() bool { +func (g *Group) CanBeTagOwner() bool { return true } -func (g Group) CanBeAutoApprover() bool { +func (g *Group) CanBeAutoApprover() bool { return true } // String returns the string representation of the Group. -func (g Group) String() string { - return string(g) +func (g *Group) String() string { + return string(*g) } -func (h Host) String() string { - return string(h) +func (h *Host) String() string { + return string(*h) } // MarshalJSON marshals the Host to JSON. -func (h Host) MarshalJSON() ([]byte, error) { - return json.Marshal(string(h)) +func (h *Host) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*h)) } // MarshalJSON marshals the Group to JSON. -func (g Group) MarshalJSON() ([]byte, error) { - return json.Marshal(string(g)) +func (g *Group) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*g)) } -func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (g *Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, user := range p.Groups[g] { + for _, user := range p.Groups[*g] { uips, err := user.Resolve(nil, users, nodes) if err != nil { errs = append(errs, err) @@ -362,12 +362,12 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod // Tag is a special string which is always prefixed with `tag:`. type Tag string -func (t Tag) Validate() error { - if isTag(string(t)) { +func (t *Tag) Validate() error { + if isTag(string(*t)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, t) + return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, *t) } func (t *Tag) UnmarshalJSON(b []byte) error { @@ -381,12 +381,12 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (t *Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes.All() { // Check if node has this tag - if node.HasTag(string(t)) { + if node.HasTag(string(*t)) { node.AppendToIPSet(&ips) } } @@ -394,32 +394,32 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV return ips.IPSet() } -func (t Tag) CanBeAutoApprover() bool { +func (t *Tag) CanBeAutoApprover() bool { return true } -func (t Tag) CanBeTagOwner() bool { +func (t *Tag) CanBeTagOwner() bool { return true } -func (t Tag) String() string { - return string(t) +func (t *Tag) String() string { + return string(*t) } // MarshalJSON marshals the Tag to JSON. -func (t Tag) MarshalJSON() ([]byte, error) { - return json.Marshal(string(t)) +func (t *Tag) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*t)) } // Host is a string that represents a hostname. type Host string -func (h Host) Validate() error { - if isHost(string(h)) { +func (h *Host) Validate() error { + if isHost(string(*h)) { return nil } - return fmt.Errorf("%w: %q", ErrInvalidHostname, h) + return fmt.Errorf("%w: %q", ErrInvalidHostname, *h) } func (h *Host) UnmarshalJSON(b []byte) error { @@ -433,15 +433,15 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (h *Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - pref, ok := p.Hosts[h] + pref, ok := p.Hosts[*h] if !ok { - return nil, fmt.Errorf("%w: %q", ErrHostResolve, h) + return nil, fmt.Errorf("%w: %q", ErrHostResolve, *h) } err := pref.Validate() @@ -473,16 +473,16 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView type Prefix netip.Prefix -func (p Prefix) Validate() error { - if netip.Prefix(p).IsValid() { +func (p *Prefix) Validate() error { + if netip.Prefix(*p).IsValid() { return nil } - return fmt.Errorf("%w: %q", ErrInvalidPrefix, p) + return fmt.Errorf("%w: %s", ErrInvalidPrefix, p.String()) } -func (p Prefix) String() string { - return netip.Prefix(p).String() +func (p *Prefix) String() string { + return netip.Prefix(*p).String() } func (p *Prefix) parseString(addr string) error { @@ -530,16 +530,16 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // of the Prefix and the Policy, Users, and Nodes. // // See [Policy], [types.Users], and [types.Nodes] for more details. -func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (p *Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - ips.AddPrefix(netip.Prefix(p)) + ips.AddPrefix(netip.Prefix(*p)) // If the IP is a single host, look for a node to ensure we add all the IPs of // the node to the IPSet. - appendIfNodeHasIP(nodes, &ips, netip.Prefix(p)) + appendIfNodeHasIP(nodes, &ips, netip.Prefix(*p)) return buildIPSetMultiErr(&ips, errs) } @@ -577,12 +577,12 @@ var autogroups = []AutoGroup{ AutoGroupSelf, } -func (ag AutoGroup) Validate() error { - if slices.Contains(autogroups, ag) { +func (ag *AutoGroup) Validate() error { + if slices.Contains(autogroups, *ag) { return nil } - return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutogroup, ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutogroup, *ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { @@ -596,19 +596,19 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { return nil } -func (ag AutoGroup) String() string { - return string(ag) +func (ag *AutoGroup) String() string { + return string(*ag) } // MarshalJSON marshals the AutoGroup to JSON. -func (ag AutoGroup) MarshalJSON() ([]byte, error) { - return json.Marshal(string(ag)) +func (ag *AutoGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*ag)) } -func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (ag *AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - switch ag { + switch *ag { case AutoGroupInternet: return util.TheInternet(), nil @@ -646,10 +646,10 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type case AutoGroupNonRoot: // autogroup:nonroot represents non-root users on multi-user devices. // This is not supported in headscale and requires OS-level user detection. - return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, ag) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) default: - return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, ag) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) } } @@ -814,13 +814,13 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Aliases to JSON. -func (a Aliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *Aliases) MarshalJSON() ([]byte, error) { + if *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -844,13 +844,13 @@ func (a Aliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) @@ -1042,12 +1042,12 @@ type Usernames []Username // Groups are a map of Group to a list of Username. type Groups map[Group]Usernames -func (g Groups) Contains(group *Group) error { +func (g *Groups) Contains(group *Group) error { if group == nil { return nil } - for defined := range map[Group]Usernames(g) { + for defined := range map[Group]Usernames(*g) { if defined == *group { return nil } @@ -1167,21 +1167,21 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Hosts to JSON. -func (h Hosts) MarshalJSON() ([]byte, error) { - if h == nil { +func (h *Hosts) MarshalJSON() ([]byte, error) { + if *h == nil { return []byte("{}"), nil } rawHosts := make(map[string]string) - for host, prefix := range h { + for host, prefix := range *h { rawHosts[string(host)] = prefix.String() } return json.Marshal(rawHosts) } -func (h Hosts) exist(name Host) bool { - _, ok := h[name] +func (h *Hosts) exist(name Host) bool { + _, ok := (*h)[name] return ok } @@ -1344,8 +1344,8 @@ const ( ) // String returns the string representation of the Action. -func (a Action) String() string { - return string(a) +func (a *Action) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for Action. @@ -1362,13 +1362,13 @@ func (a *Action) UnmarshalJSON(b []byte) error { } // MarshalJSON implements JSON marshaling for Action. -func (a Action) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // String returns the string representation of the SSHAction. -func (a SSHAction) String() string { - return string(a) +func (a *SSHAction) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for SSHAction. @@ -1387,8 +1387,8 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { } // MarshalJSON implements JSON marshaling for SSHAction. -func (a SSHAction) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // Protocol represents a network protocol with its IANA number and descriptions. @@ -1413,13 +1413,13 @@ const ( ) // String returns the string representation of the Protocol. -func (p Protocol) String() string { - return string(p) +func (p *Protocol) String() string { + return string(*p) } // Description returns the human-readable description of the Protocol. -func (p Protocol) Description() string { - switch p { +func (p *Protocol) Description() string { + switch *p { case ProtocolNameICMP: return "Internet Control Message Protocol" case ProtocolNameIGMP: @@ -1457,8 +1457,8 @@ func (p Protocol) Description() string { // parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() []int { - switch p { +func (p *Protocol) parseProtocol() []int { + switch *p { case "": // Empty protocol applies to TCP, UDP, ICMP, and ICMPv6 traffic // This matches Tailscale's behavior for protocol defaults @@ -1496,7 +1496,7 @@ func (p Protocol) parseProtocol() []int { default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling - protocolNumber, _ := strconv.Atoi(string(p)) + protocolNumber, _ := strconv.Atoi(string(*p)) return []int{protocolNumber} } } @@ -1518,8 +1518,8 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { } // validate checks if the Protocol is valid. -func (p Protocol) validate() error { - switch p { +func (p *Protocol) validate() error { + switch *p { case "", ProtocolNameICMP, ProtocolNameIGMP, ProtocolNameIPv4, ProtocolNameIPInIP, ProtocolNameTCP, ProtocolNameEGP, ProtocolNameIGP, ProtocolNameUDP, ProtocolNameGRE, ProtocolNameESP, ProtocolNameAH, ProtocolNameSCTP, ProtocolNameIPv6ICMP, ProtocolNameFC: @@ -1529,7 +1529,7 @@ func (p Protocol) validate() error { return errUnknownProtocolWildcard default: // Try to parse as a numeric protocol number - str := string(p) + str := string(*p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { @@ -1538,7 +1538,7 @@ func (p Protocol) validate() error { protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocolNumber, p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocolNumber, *p) } if protocolNumber < 0 || protocolNumber > 255 { @@ -1550,8 +1550,8 @@ func (p Protocol) validate() error { } // MarshalJSON implements JSON marshaling for Protocol. -func (p Protocol) MarshalJSON() ([]byte, error) { - return json.Marshal(string(p)) +func (p *Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*p)) } // Protocol constants matching the IANA numbers. @@ -2110,13 +2110,13 @@ type SSH struct { type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. -func (g Groups) MarshalJSON() ([]byte, error) { - if g == nil { +func (g *Groups) MarshalJSON() ([]byte, error) { + if *g == nil { return []byte("{}"), nil } raw := make(map[string][]string) - for group, usernames := range g { + for group, usernames := range *g { users := make([]string, len(usernames)) for i, username := range usernames { users[i] = string(username) @@ -2204,13 +2204,13 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { } // MarshalJSON marshals the SSHSrcAliases to JSON. -func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *SSHSrcAliases) MarshalJSON() ([]byte, error) { + if a == nil || *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -2230,13 +2230,13 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index ad43b961..fc3fa6aa 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -167,7 +167,7 @@ func (node *Node) GivenNameHasBeenChanged() bool { } // IsExpired returns whether the node registration has expired. -func (node Node) IsExpired() bool { +func (node *Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that // it wants an expiry time, it is therefore considered // to mean "not expired" @@ -738,7 +738,7 @@ func (nodes Nodes) DebugString() string { return sb.String() } -func (node Node) DebugString() string { +func (node *Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)