diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 0e5b6ad4..bc128999 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -449,7 +449,6 @@ func TestBackfillIPAddresses(t *testing.T) { "UserID", "Endpoints", "Hostinfo", - "Routes", "CreatedAt", "UpdatedAt", )) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 4437b30b..705596cd 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -101,15 +101,22 @@ func generateUserProfiles( node *types.Node, peers types.Nodes, ) []tailcfg.UserProfile { - userMap := make(map[uint]types.User) - userMap[node.User.ID] = node.User + userMap := make(map[uint]*types.User) + ids := make([]uint, 0, len(userMap)) + userMap[node.User.ID] = &node.User + ids = append(ids, node.User.ID) for _, peer := range peers { - userMap[peer.User.ID] = peer.User // not worth checking if already is there + userMap[peer.User.ID] = &peer.User + ids = append(ids, peer.User.ID) } + slices.Sort(ids) + slices.Compact(ids) var profiles []tailcfg.UserProfile - for _, user := range userMap { - profiles = append(profiles, user.TailscaleUserProfile()) + for _, id := range ids { + if userMap[id] != nil { + profiles = append(profiles, userMap[id].TailscaleUserProfile()) + } } return profiles diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 07ed79a2..bc8d011d 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -11,7 +11,6 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" - "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -24,51 +23,6 @@ var iap = func(ipStr string) *netip.Addr { return &ip } -func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - mach := func(hostname, username string, userid uint) *types.Node { - return &types.Node{ - Hostname: hostname, - UserID: userid, - User: types.User{ - Model: gorm.Model{ - ID: userid, - }, - Name: username, - }, - } - } - - nodeInShared1 := mach("test_get_shared_nodes_1", "user1", 1) - nodeInShared2 := mach("test_get_shared_nodes_2", "user2", 2) - nodeInShared3 := mach("test_get_shared_nodes_3", "user3", 3) - node2InShared1 := mach("test_get_shared_nodes_4", "user1", 1) - - userProfiles := generateUserProfiles( - nodeInShared1, - types.Nodes{ - nodeInShared2, nodeInShared3, node2InShared1, - }, - ) - - c.Assert(len(userProfiles), check.Equals, 3) - - users := []string{ - "user1", "user2", "user3", - } - - for _, user := range users { - found := false - for _, userProfile := range userProfiles { - if userProfile.DisplayName == user { - found = true - - break - } - } - c.Assert(found, check.Equals, true) - } -} - func TestDNSConfigMapResponse(t *testing.T) { tests := []struct { magicDNS bool diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 56704c24..a7b12b1d 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2165,6 +2165,9 @@ func TestReduceFilterRules(t *testing.T) { netip.MustParsePrefix("10.33.0.0/16"), }, }, + ApprovedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, }, peers: types.Nodes{ &types.Node{ @@ -2292,6 +2295,7 @@ func TestReduceFilterRules(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, + ApprovedRoutes: tsaddr.ExitRoutes(), }, peers: types.Nodes{ &types.Node{ @@ -2398,6 +2402,7 @@ func TestReduceFilterRules(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, + ApprovedRoutes: tsaddr.ExitRoutes(), }, peers: types.Nodes{ &types.Node{ @@ -2513,6 +2518,10 @@ func TestReduceFilterRules(t *testing.T) { netip.MustParsePrefix("16.0.0.0/16"), }, }, + ApprovedRoutes: []netip.Prefix{ + netip.MustParsePrefix("8.0.0.0/16"), + netip.MustParsePrefix("16.0.0.0/16"), + }, }, peers: types.Nodes{ &types.Node{ @@ -2603,6 +2612,10 @@ func TestReduceFilterRules(t *testing.T) { netip.MustParsePrefix("16.0.0.0/8"), }, }, + ApprovedRoutes: []netip.Prefix{ + netip.MustParsePrefix("8.0.0.0/8"), + netip.MustParsePrefix("16.0.0.0/8"), + }, }, peers: types.Nodes{ &types.Node{ @@ -2683,7 +2696,8 @@ func TestReduceFilterRules(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, - ForcedTags: []string{"tag:access-servers"}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, + ForcedTags: []string{"tag:access-servers"}, }, peers: types.Nodes{ &types.Node{ diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index 1905dad2..2b86416e 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -9,8 +9,8 @@ import ( ) type Match struct { - Srcs *netipx.IPSet - Dests *netipx.IPSet + srcs *netipx.IPSet + dests *netipx.IPSet } func MatchFromFilterRule(rule tailcfg.FilterRule) Match { @@ -42,16 +42,16 @@ func MatchFromStrings(sources, destinations []string) Match { destsSet, _ := dests.IPSet() match := Match{ - Srcs: srcsSet, - Dests: destsSet, + srcs: srcsSet, + dests: destsSet, } return match } -func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { +func (m *Match) SrcsContainsIPs(ips ...netip.Addr) bool { for _, ip := range ips { - if m.Srcs.Contains(ip) { + if m.srcs.Contains(ip) { return true } } @@ -59,9 +59,29 @@ func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { return false } -func (m *Match) DestsContainsIP(ips []netip.Addr) bool { +func (m *Match) DestsContainsIP(ips ...netip.Addr) bool { for _, ip := range ips { - if m.Dests.Contains(ip) { + if m.dests.Contains(ip) { + return true + } + } + + return false +} + +func (m *Match) SrcsOverlapsPrefixes(prefixes ...netip.Prefix) bool { + for _, prefix := range prefixes { + if m.srcs.ContainsPrefix(prefix) { + return true + } + } + + return false +} + +func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { + for _, prefix := range prefixes { + if m.dests.ContainsPrefix(prefix) { return true } } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 581c4eb6..6654cefd 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -202,11 +202,15 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { } for _, matcher := range matchers { - if !matcher.SrcsContainsIPs(src) { + if !matcher.SrcsContainsIPs(src...) { continue } - if matcher.DestsContainsIP(allowedIPs) { + if matcher.DestsContainsIP(allowedIPs...) { + return true + } + + if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) { return true } } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index cd6a4780..2eba5f0f 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -55,6 +55,13 @@ type User struct { ProfilePicURL string } +func (u *User) StringID() string { + if u == nil { + return "" + } + return strconv.FormatUint(uint64(u.ID), 10) +} + // Username is the main way to get the username of a user, // it will return the email if it exists, the name if it exists, // the OIDCIdentifier if it exists, and the ID if nothing else exists. @@ -63,7 +70,11 @@ type User struct { // should be used throughout headscale, in information returned to the // user and the Policy engine. func (u *User) Username() string { - return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) + return cmp.Or( + u.Email, + u.Name, + u.ProviderIdentifier.String, + u.StringID()) } // DisplayNameOrUsername returns the DisplayName if it exists, otherwise