mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Add tests to verify "Hosts" aliases in ACL (#1304)
This commit is contained in:
		
							parent
							
								
									681c86cc95
								
							
						
					
					
						commit
						ceeef40cdf
					
				
							
								
								
									
										50
									
								
								acls.go
									
									
									
									
									
								
							
							
						
						
									
										50
									
								
								acls.go
									
									
									
									
									
								
							| @ -14,6 +14,7 @@ import ( | ||||
| 
 | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"github.com/tailscale/hujson" | ||||
| 	"go4.org/netipx" | ||||
| 	"gopkg.in/yaml.v3" | ||||
| 	"tailscale.com/envknob" | ||||
| 	"tailscale.com/tailcfg" | ||||
| @ -165,16 +166,22 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s | ||||
| 	aclCachePeerMap := make(map[string]map[string]struct{}) | ||||
| 	for _, rule := range rules { | ||||
| 		for _, srcIP := range rule.SrcIPs { | ||||
| 			if data, ok := aclCachePeerMap[srcIP]; ok { | ||||
| 			for _, ip := range expandACLPeerAddr(srcIP) { | ||||
| 				if data, ok := aclCachePeerMap[ip]; ok { | ||||
| 					for _, dstPort := range rule.DstPorts { | ||||
| 					data[dstPort.IP] = struct{}{} | ||||
| 						for _, dstIP := range expandACLPeerAddr(dstPort.IP) { | ||||
| 							data[dstIP] = struct{}{} | ||||
| 						} | ||||
| 					} | ||||
| 				} else { | ||||
| 					dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) | ||||
| 					for _, dstPort := range rule.DstPorts { | ||||
| 					dstPortsMap[dstPort.IP] = struct{}{} | ||||
| 						for _, dstIP := range expandACLPeerAddr(dstPort.IP) { | ||||
| 							dstPortsMap[dstIP] = struct{}{} | ||||
| 						} | ||||
| 					} | ||||
| 					aclCachePeerMap[ip] = dstPortsMap | ||||
| 				} | ||||
| 				aclCachePeerMap[srcIP] = dstPortsMap | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @ -184,6 +191,41 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s | ||||
| 	return aclCachePeerMap | ||||
| } | ||||
| 
 | ||||
| // expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into
 | ||||
| // something our cache logic can look up, which is "*" or single IP addresses.
 | ||||
| // This is probably quite inefficient, but it is a result of
 | ||||
| // "make it work, then make it fast", and a lot of the ACL stuff does not
 | ||||
| // work, but people have tried to make it fast.
 | ||||
| func expandACLPeerAddr(srcIP string) []string { | ||||
| 	if ip, err := netip.ParseAddr(srcIP); err == nil { | ||||
| 		return []string{ip.String()} | ||||
| 	} | ||||
| 
 | ||||
| 	if cidr, err := netip.ParsePrefix(srcIP); err == nil { | ||||
| 		addrs := []string{} | ||||
| 
 | ||||
| 		ipRange := netipx.RangeOfPrefix(cidr) | ||||
| 
 | ||||
| 		from := ipRange.From() | ||||
| 		too := ipRange.To() | ||||
| 
 | ||||
| 		if from == too { | ||||
| 			return []string{from.String()} | ||||
| 		} | ||||
| 
 | ||||
| 		for from != too { | ||||
| 			addrs = append(addrs, from.String()) | ||||
| 
 | ||||
| 			from = from.Next() | ||||
| 		} | ||||
| 
 | ||||
| 		return addrs | ||||
| 	} | ||||
| 
 | ||||
| 	// probably "*" or other string based "IP"
 | ||||
| 	return []string{srcIP} | ||||
| } | ||||
| 
 | ||||
| func generateACLRules( | ||||
| 	machines []Machine, | ||||
| 	aclPolicy ACLPolicy, | ||||
|  | ||||
							
								
								
									
										64
									
								
								acls_test.go
									
									
									
									
									
								
							
							
						
						
									
										64
									
								
								acls_test.go
									
									
									
									
									
								
							| @ -1556,3 +1556,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_expandACLPeerAddr(t *testing.T) { | ||||
| 	type args struct { | ||||
| 		srcIP string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name string | ||||
| 		args args | ||||
| 		want []string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "asterix", | ||||
| 			args: args{ | ||||
| 				srcIP: "*", | ||||
| 			}, | ||||
| 			want: []string{"*"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ip", | ||||
| 			args: args{ | ||||
| 				srcIP: "10.0.0.1", | ||||
| 			}, | ||||
| 			want: []string{"10.0.0.1"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ip/32", | ||||
| 			args: args{ | ||||
| 				srcIP: "10.0.0.1/32", | ||||
| 			}, | ||||
| 			want: []string{"10.0.0.1"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ip/30", | ||||
| 			args: args{ | ||||
| 				srcIP: "10.0.0.1/30", | ||||
| 			}, | ||||
| 			want: []string{ | ||||
| 				"10.0.0.0", | ||||
| 				"10.0.0.1", | ||||
| 				"10.0.0.2", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "ip/28", | ||||
| 			args: args{ | ||||
| 				srcIP: "192.168.0.128/28", | ||||
| 			}, | ||||
| 			want: []string{ | ||||
| 				"192.168.0.128", "192.168.0.129", "192.168.0.130", | ||||
| 				"192.168.0.131", "192.168.0.132", "192.168.0.133", | ||||
| 				"192.168.0.134", "192.168.0.135", "192.168.0.136", | ||||
| 				"192.168.0.137", "192.168.0.138", "192.168.0.139", | ||||
| 				"192.168.0.140", "192.168.0.141", "192.168.0.142", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { | ||||
| 				t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -2,6 +2,7 @@ package integration | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| @ -439,3 +440,214 @@ func TestACLAllowStarDst(t *testing.T) { | ||||
| 	err = scenario.Shutdown() | ||||
| 	assert.NoError(t, err) | ||||
| } | ||||
| 
 | ||||
| // This test aims to cover cases where individual hosts are allowed and denied
 | ||||
| // access based on their assigned hostname
 | ||||
| // https://github.com/juanfont/headscale/issues/941
 | ||||
| 
 | ||||
| //	ACL = [{
 | ||||
| //			"DstPorts": [{
 | ||||
| //				"Bits": null,
 | ||||
| //				"IP": "100.64.0.3/32",
 | ||||
| //				"Ports": {
 | ||||
| //					"First": 0,
 | ||||
| //					"Last": 65535
 | ||||
| //				}
 | ||||
| //			}],
 | ||||
| //			"SrcIPs": ["*"]
 | ||||
| //		}, {
 | ||||
| //
 | ||||
| //			"DstPorts": [{
 | ||||
| //				"Bits": null,
 | ||||
| //				"IP": "100.64.0.2/32",
 | ||||
| //				"Ports": {
 | ||||
| //					"First": 0,
 | ||||
| //					"Last": 65535
 | ||||
| //				}
 | ||||
| //			}],
 | ||||
| //			"SrcIPs": ["100.64.0.1/32"]
 | ||||
| //		}]
 | ||||
| //
 | ||||
| //	ACL Cache Map= {
 | ||||
| //		"*": {
 | ||||
| //			"100.64.0.3/32": {}
 | ||||
| //		},
 | ||||
| //		"100.64.0.1/32": {
 | ||||
| //			"100.64.0.2/32": {}
 | ||||
| //		}
 | ||||
| //	}
 | ||||
| func TestACLNamedHostsCanReach(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		headscale.ACLPolicy{ | ||||
| 			Hosts: headscale.Hosts{ | ||||
| 				"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 				"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||
| 				"test3": netip.MustParsePrefix("100.64.0.3/32"), | ||||
| 			}, | ||||
| 			ACLs: []headscale.ACL{ | ||||
| 				// Everyone can curl test3
 | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"*"}, | ||||
| 					Destinations: []string{"test3:*"}, | ||||
| 				}, | ||||
| 				// test1 can curl test2
 | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"test1"}, | ||||
| 					Destinations: []string{"test2:*"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	) | ||||
| 
 | ||||
| 	// Since user/users dont matter here, we basically expect that some clients
 | ||||
| 	// will be assigned these ips and that we can pick them up for our own use.
 | ||||
| 	test1ip := netip.MustParseAddr("100.64.0.1") | ||||
| 	test1, err := scenario.FindTailscaleClientByIP(test1ip) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	test1fqdn, err := test1.FQDN() | ||||
| 	assert.NoError(t, err) | ||||
| 	test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) | ||||
| 	test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) | ||||
| 
 | ||||
| 	test2ip := netip.MustParseAddr("100.64.0.2") | ||||
| 	test2, err := scenario.FindTailscaleClientByIP(test2ip) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	test2fqdn, err := test2.FQDN() | ||||
| 	assert.NoError(t, err) | ||||
| 	test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) | ||||
| 	test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) | ||||
| 
 | ||||
| 	test3ip := netip.MustParseAddr("100.64.0.3") | ||||
| 	test3, err := scenario.FindTailscaleClientByIP(test3ip) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	test3fqdn, err := test3.FQDN() | ||||
| 	assert.NoError(t, err) | ||||
| 	test3ipURL := fmt.Sprintf("http://%s/etc/hostname", test3ip.String()) | ||||
| 	test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) | ||||
| 
 | ||||
| 	// test1 can query test3
 | ||||
| 	result, err := test1.Curl(test3ipURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	result, err = test1.Curl(test3fqdnURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	// test2 can query test3
 | ||||
| 	result, err = test2.Curl(test3ipURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	result, err = test2.Curl(test3fqdnURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	// test3 cannot query test1
 | ||||
| 	result, err = test3.Curl(test1ipURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	result, err = test3.Curl(test1fqdnURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	// test3 cannot query test2
 | ||||
| 	result, err = test3.Curl(test2ipURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	result, err = test3.Curl(test2fqdnURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	// test1 can query test2
 | ||||
| 	result, err = test1.Curl(test2ipURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	result, err = test1.Curl(test2fqdnURL) | ||||
| 	assert.Len(t, result, 13) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	// test2 cannot query test1
 | ||||
| 	result, err = test2.Curl(test1ipURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	result, err = test2.Curl(test1fqdnURL) | ||||
| 	assert.Empty(t, result) | ||||
| 	assert.Error(t, err) | ||||
| 
 | ||||
| 	err = scenario.Shutdown() | ||||
| 	assert.NoError(t, err) | ||||
| } | ||||
| 
 | ||||
| // TestACLNamedHostsCanReachBySubnet is the same as
 | ||||
| // TestACLNamedHostsCanReach, but it tests if we expand a
 | ||||
| // full CIDR correctly. All routes should work.
 | ||||
| func TestACLNamedHostsCanReachBySubnet(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 
 | ||||
| 	scenario := aclScenario(t, | ||||
| 		headscale.ACLPolicy{ | ||||
| 			Hosts: headscale.Hosts{ | ||||
| 				"all": netip.MustParsePrefix("100.64.0.0/24"), | ||||
| 			}, | ||||
| 			ACLs: []headscale.ACL{ | ||||
| 				// Everyone can curl test3
 | ||||
| 				{ | ||||
| 					Action:       "accept", | ||||
| 					Sources:      []string{"*"}, | ||||
| 					Destinations: []string{"all:*"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	) | ||||
| 
 | ||||
| 	user1Clients, err := scenario.ListTailscaleClients("user1") | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	user2Clients, err := scenario.ListTailscaleClients("user2") | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	// Test that user1 can visit all user2
 | ||||
| 	for _, client := range user1Clients { | ||||
| 		for _, peer := range user2Clients { | ||||
| 			fqdn, err := peer.FQDN() | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			url := fmt.Sprintf("http://%s/etc/hostname", fqdn) | ||||
| 			t.Logf("url from %s to %s", client.Hostname(), url) | ||||
| 
 | ||||
| 			result, err := client.Curl(url) | ||||
| 			assert.Len(t, result, 13) | ||||
| 			assert.NoError(t, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Test that user2 can visit all user1
 | ||||
| 	for _, client := range user2Clients { | ||||
| 		for _, peer := range user1Clients { | ||||
| 			fqdn, err := peer.FQDN() | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			url := fmt.Sprintf("http://%s/etc/hostname", fqdn) | ||||
| 			t.Logf("url from %s to %s", client.Hostname(), url) | ||||
| 
 | ||||
| 			result, err := client.Curl(url) | ||||
| 			assert.Len(t, result, 13) | ||||
| 			assert.NoError(t, err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	err = scenario.Shutdown() | ||||
| 	assert.NoError(t, err) | ||||
| } | ||||
|  | ||||
							
								
								
									
										17
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								machine.go
									
									
									
									
									
								
							| @ -170,13 +170,14 @@ func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) | ||||
| // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
 | ||||
| func filterMachinesByACL( | ||||
| 	machine *Machine, | ||||
| 	machines []Machine, | ||||
| 	machines Machines, | ||||
| 	lock *sync.RWMutex, | ||||
| 	aclPeerCacheMap map[string]map[string]struct{}, | ||||
| ) Machines { | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Str("self", machine.Hostname). | ||||
| 		Str("input", machines.String()). | ||||
| 		Msg("Finding peers filtered by ACLs") | ||||
| 
 | ||||
| 	peers := make(map[uint64]Machine) | ||||
| @ -263,7 +264,7 @@ func filterMachinesByACL( | ||||
| 
 | ||||
| 	lock.RUnlock() | ||||
| 
 | ||||
| 	authorizedPeers := make([]Machine, 0, len(peers)) | ||||
| 	authorizedPeers := make(Machines, 0, len(peers)) | ||||
| 	for _, m := range peers { | ||||
| 		authorizedPeers = append(authorizedPeers, m) | ||||
| 	} | ||||
| @ -274,8 +275,9 @@ func filterMachinesByACL( | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Msgf("Found some machines: %v", machines) | ||||
| 		Str("self", machine.Hostname). | ||||
| 		Str("peers", authorizedPeers.String()). | ||||
| 		Msg("Authorized peers") | ||||
| 
 | ||||
| 	return authorizedPeers | ||||
| } | ||||
| @ -335,8 +337,9 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine", machine.Hostname). | ||||
| 		Msgf("Found total peers: %s", peers.String()) | ||||
| 		Str("self", machine.Hostname). | ||||
| 		Str("peers", peers.String()). | ||||
| 		Msg("Peers returned to caller") | ||||
| 
 | ||||
| 	return peers, nil | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user