From 08646a39cb3d8df892305994c9507fb3057fd992 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 6 Feb 2026 11:39:07 +0000 Subject: [PATCH] all: fix recvcheck issues Standardize receiver types for types that have both pointer and value receivers. Since UnmarshalJSON must use pointer receivers to modify the value, standardize all methods to use pointer receivers. Types fixed: - Protocol: String, Description, parseProtocol, validate, MarshalJSON - SSHSrcAliases: MarshalJSON, Resolve - Match: DebugString, DestsIsTheInternet - Node: IsExpired, DebugString --- hscontrol/policy/matcher/matcher.go | 4 +- hscontrol/policy/v2/filter_test.go | 3 +- hscontrol/policy/v2/types.go | 208 ++++++++++++++-------------- hscontrol/types/node.go | 4 +- 4 files changed, 110 insertions(+), 109 deletions(-) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index c03499ee..b52cb4dc 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -16,7 +16,7 @@ type Match struct { dests *netipx.IPSet } -func (m Match) DebugString() string { +func (m *Match) DebugString() string { var sb strings.Builder sb.WriteString("Match:\n") @@ -101,7 +101,7 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { // cased for exit nodes. // This checks if dests is a superset of TheInternet(), which handles // merged filter rules where TheInternet is combined with other destinations. -func (m Match) DestsIsTheInternet() bool { +func (m *Match) DestsIsTheInternet() bool { if m.dests.ContainsPrefix(tsaddr.AllIPv4()) || m.dests.ContainsPrefix(tsaddr.AllIPv6()) { return true diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 9c4ef46f..44490f29 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -1150,7 +1150,8 @@ func TestAutogroupTagged(t *testing.T) { require.NoError(t, err) // Verify autogroup:tagged includes all tagged nodes - taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice()) + ag := AutoGroupTagged + taggedIPs, err := ag.Resolve(policy, users, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, taggedIPs) diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index d161138b..c8b4f4e5 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -185,12 +185,12 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV // Username is a string that represents a username, it must contain an @. type Username string -func (u Username) Validate() error { - if isUser(string(u)) { +func (u *Username) Validate() error { + if isUser(string(*u)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidUsername, u) + return fmt.Errorf("%w, got: %q", ErrInvalidUsername, *u) } func (u *Username) String() string { @@ -198,12 +198,12 @@ func (u *Username) String() string { } // MarshalJSON marshals the Username to JSON. -func (u Username) MarshalJSON() ([]byte, error) { - return json.Marshal(string(u)) +func (u *Username) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*u)) } // MarshalJSON marshals the Prefix to JSON. -func (p Prefix) MarshalJSON() ([]byte, error) { +func (p *Prefix) MarshalJSON() ([]byte, error) { return json.Marshal(p.String()) } @@ -218,11 +218,11 @@ func (u *Username) UnmarshalJSON(b []byte) error { return nil } -func (u Username) CanBeTagOwner() bool { +func (u *Username) CanBeTagOwner() bool { return true } -func (u Username) CanBeAutoApprover() bool { +func (u *Username) CanBeAutoApprover() bool { return true } @@ -231,7 +231,7 @@ func (u Username) CanBeAutoApprover() bool { // If no matching user is found, it returns an error indicating no user matching. // If multiple matching users are found, it returns an error indicating multiple users matching. // It returns the matched types.User and a nil error if exactly one match is found. -func (u Username) resolveUser(users types.Users) (types.User, error) { +func (u *Username) resolveUser(users types.Users) (types.User, error) { var potentialUsers types.Users // At parsetime, we require all usernames to contain an "@" character, if the @@ -262,7 +262,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { return potentialUsers[0], nil } -func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (u *Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error @@ -295,12 +295,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. // Group is a special string which is always prefixed with `group:`. type Group string -func (g Group) Validate() error { - if isGroup(string(g)) { +func (g *Group) Validate() error { + if isGroup(string(*g)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidGroupFormat, g) + return fmt.Errorf("%w, got: %q", ErrInvalidGroupFormat, *g) } func (g *Group) UnmarshalJSON(b []byte) error { @@ -314,40 +314,40 @@ func (g *Group) UnmarshalJSON(b []byte) error { return nil } -func (g Group) CanBeTagOwner() bool { +func (g *Group) CanBeTagOwner() bool { return true } -func (g Group) CanBeAutoApprover() bool { +func (g *Group) CanBeAutoApprover() bool { return true } // String returns the string representation of the Group. -func (g Group) String() string { - return string(g) +func (g *Group) String() string { + return string(*g) } -func (h Host) String() string { - return string(h) +func (h *Host) String() string { + return string(*h) } // MarshalJSON marshals the Host to JSON. -func (h Host) MarshalJSON() ([]byte, error) { - return json.Marshal(string(h)) +func (h *Host) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*h)) } // MarshalJSON marshals the Group to JSON. -func (g Group) MarshalJSON() ([]byte, error) { - return json.Marshal(string(g)) +func (g *Group) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*g)) } -func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (g *Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, user := range p.Groups[g] { + for _, user := range p.Groups[*g] { uips, err := user.Resolve(nil, users, nodes) if err != nil { errs = append(errs, err) @@ -362,12 +362,12 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod // Tag is a special string which is always prefixed with `tag:`. type Tag string -func (t Tag) Validate() error { - if isTag(string(t)) { +func (t *Tag) Validate() error { + if isTag(string(*t)) { return nil } - return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, t) + return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, *t) } func (t *Tag) UnmarshalJSON(b []byte) error { @@ -381,12 +381,12 @@ func (t *Tag) UnmarshalJSON(b []byte) error { return nil } -func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (t *Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes.All() { // Check if node has this tag - if node.HasTag(string(t)) { + if node.HasTag(string(*t)) { node.AppendToIPSet(&ips) } } @@ -394,32 +394,32 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV return ips.IPSet() } -func (t Tag) CanBeAutoApprover() bool { +func (t *Tag) CanBeAutoApprover() bool { return true } -func (t Tag) CanBeTagOwner() bool { +func (t *Tag) CanBeTagOwner() bool { return true } -func (t Tag) String() string { - return string(t) +func (t *Tag) String() string { + return string(*t) } // MarshalJSON marshals the Tag to JSON. -func (t Tag) MarshalJSON() ([]byte, error) { - return json.Marshal(string(t)) +func (t *Tag) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*t)) } // Host is a string that represents a hostname. type Host string -func (h Host) Validate() error { - if isHost(string(h)) { +func (h *Host) Validate() error { + if isHost(string(*h)) { return nil } - return fmt.Errorf("%w: %q", ErrInvalidHostname, h) + return fmt.Errorf("%w: %q", ErrInvalidHostname, *h) } func (h *Host) UnmarshalJSON(b []byte) error { @@ -433,15 +433,15 @@ func (h *Host) UnmarshalJSON(b []byte) error { return nil } -func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (h *Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - pref, ok := p.Hosts[h] + pref, ok := p.Hosts[*h] if !ok { - return nil, fmt.Errorf("%w: %q", ErrHostResolve, h) + return nil, fmt.Errorf("%w: %q", ErrHostResolve, *h) } err := pref.Validate() @@ -473,16 +473,16 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView type Prefix netip.Prefix -func (p Prefix) Validate() error { - if netip.Prefix(p).IsValid() { +func (p *Prefix) Validate() error { + if netip.Prefix(*p).IsValid() { return nil } - return fmt.Errorf("%w: %q", ErrInvalidPrefix, p) + return fmt.Errorf("%w: %s", ErrInvalidPrefix, p.String()) } -func (p Prefix) String() string { - return netip.Prefix(p).String() +func (p *Prefix) String() string { + return netip.Prefix(*p).String() } func (p *Prefix) parseString(addr string) error { @@ -530,16 +530,16 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // of the Prefix and the Policy, Users, and Nodes. // // See [Policy], [types.Users], and [types.Nodes] for more details. -func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (p *Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - ips.AddPrefix(netip.Prefix(p)) + ips.AddPrefix(netip.Prefix(*p)) // If the IP is a single host, look for a node to ensure we add all the IPs of // the node to the IPSet. - appendIfNodeHasIP(nodes, &ips, netip.Prefix(p)) + appendIfNodeHasIP(nodes, &ips, netip.Prefix(*p)) return buildIPSetMultiErr(&ips, errs) } @@ -577,12 +577,12 @@ var autogroups = []AutoGroup{ AutoGroupSelf, } -func (ag AutoGroup) Validate() error { - if slices.Contains(autogroups, ag) { +func (ag *AutoGroup) Validate() error { + if slices.Contains(autogroups, *ag) { return nil } - return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutogroup, ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutogroup, *ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { @@ -596,19 +596,19 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error { return nil } -func (ag AutoGroup) String() string { - return string(ag) +func (ag *AutoGroup) String() string { + return string(*ag) } // MarshalJSON marshals the AutoGroup to JSON. -func (ag AutoGroup) MarshalJSON() ([]byte, error) { - return json.Marshal(string(ag)) +func (ag *AutoGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*ag)) } -func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (ag *AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - switch ag { + switch *ag { case AutoGroupInternet: return util.TheInternet(), nil @@ -646,10 +646,10 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type case AutoGroupNonRoot: // autogroup:nonroot represents non-root users on multi-user devices. // This is not supported in headscale and requires OS-level user detection. - return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, ag) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) default: - return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, ag) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) } } @@ -814,13 +814,13 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Aliases to JSON. -func (a Aliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *Aliases) MarshalJSON() ([]byte, error) { + if *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -844,13 +844,13 @@ func (a Aliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) @@ -1042,12 +1042,12 @@ type Usernames []Username // Groups are a map of Group to a list of Username. type Groups map[Group]Usernames -func (g Groups) Contains(group *Group) error { +func (g *Groups) Contains(group *Group) error { if group == nil { return nil } - for defined := range map[Group]Usernames(g) { + for defined := range map[Group]Usernames(*g) { if defined == *group { return nil } @@ -1167,21 +1167,21 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Hosts to JSON. -func (h Hosts) MarshalJSON() ([]byte, error) { - if h == nil { +func (h *Hosts) MarshalJSON() ([]byte, error) { + if *h == nil { return []byte("{}"), nil } rawHosts := make(map[string]string) - for host, prefix := range h { + for host, prefix := range *h { rawHosts[string(host)] = prefix.String() } return json.Marshal(rawHosts) } -func (h Hosts) exist(name Host) bool { - _, ok := h[name] +func (h *Hosts) exist(name Host) bool { + _, ok := (*h)[name] return ok } @@ -1344,8 +1344,8 @@ const ( ) // String returns the string representation of the Action. -func (a Action) String() string { - return string(a) +func (a *Action) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for Action. @@ -1362,13 +1362,13 @@ func (a *Action) UnmarshalJSON(b []byte) error { } // MarshalJSON implements JSON marshaling for Action. -func (a Action) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // String returns the string representation of the SSHAction. -func (a SSHAction) String() string { - return string(a) +func (a *SSHAction) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for SSHAction. @@ -1387,8 +1387,8 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { } // MarshalJSON implements JSON marshaling for SSHAction. -func (a SSHAction) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // Protocol represents a network protocol with its IANA number and descriptions. @@ -1413,13 +1413,13 @@ const ( ) // String returns the string representation of the Protocol. -func (p Protocol) String() string { - return string(p) +func (p *Protocol) String() string { + return string(*p) } // Description returns the human-readable description of the Protocol. -func (p Protocol) Description() string { - switch p { +func (p *Protocol) Description() string { + switch *p { case ProtocolNameICMP: return "Internet Control Message Protocol" case ProtocolNameIGMP: @@ -1457,8 +1457,8 @@ func (p Protocol) Description() string { // parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() []int { - switch p { +func (p *Protocol) parseProtocol() []int { + switch *p { case "": // Empty protocol applies to TCP, UDP, ICMP, and ICMPv6 traffic // This matches Tailscale's behavior for protocol defaults @@ -1496,7 +1496,7 @@ func (p Protocol) parseProtocol() []int { default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling - protocolNumber, _ := strconv.Atoi(string(p)) + protocolNumber, _ := strconv.Atoi(string(*p)) return []int{protocolNumber} } } @@ -1518,8 +1518,8 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { } // validate checks if the Protocol is valid. -func (p Protocol) validate() error { - switch p { +func (p *Protocol) validate() error { + switch *p { case "", ProtocolNameICMP, ProtocolNameIGMP, ProtocolNameIPv4, ProtocolNameIPInIP, ProtocolNameTCP, ProtocolNameEGP, ProtocolNameIGP, ProtocolNameUDP, ProtocolNameGRE, ProtocolNameESP, ProtocolNameAH, ProtocolNameSCTP, ProtocolNameIPv6ICMP, ProtocolNameFC: @@ -1529,7 +1529,7 @@ func (p Protocol) validate() error { return errUnknownProtocolWildcard default: // Try to parse as a numeric protocol number - str := string(p) + str := string(*p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { @@ -1538,7 +1538,7 @@ func (p Protocol) validate() error { protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocolNumber, p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocolNumber, *p) } if protocolNumber < 0 || protocolNumber > 255 { @@ -1550,8 +1550,8 @@ func (p Protocol) validate() error { } // MarshalJSON implements JSON marshaling for Protocol. -func (p Protocol) MarshalJSON() ([]byte, error) { - return json.Marshal(string(p)) +func (p *Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*p)) } // Protocol constants matching the IANA numbers. @@ -2110,13 +2110,13 @@ type SSH struct { type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. -func (g Groups) MarshalJSON() ([]byte, error) { - if g == nil { +func (g *Groups) MarshalJSON() ([]byte, error) { + if *g == nil { return []byte("{}"), nil } raw := make(map[string][]string) - for group, usernames := range g { + for group, usernames := range *g { users := make([]string, len(usernames)) for i, username := range usernames { users[i] = string(username) @@ -2204,13 +2204,13 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { } // MarshalJSON marshals the SSHSrcAliases to JSON. -func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *SSHSrcAliases) MarshalJSON() ([]byte, error) { + if a == nil || *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -2230,13 +2230,13 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { return json.Marshal(aliases) } -func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ( ips netipx.IPSetBuilder errs []error ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index ad43b961..fc3fa6aa 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -167,7 +167,7 @@ func (node *Node) GivenNameHasBeenChanged() bool { } // IsExpired returns whether the node registration has expired. -func (node Node) IsExpired() bool { +func (node *Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that // it wants an expiry time, it is therefore considered // to mean "not expired" @@ -738,7 +738,7 @@ func (nodes Nodes) DebugString() string { return sb.String() } -func (node Node) DebugString() string { +func (node *Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)