mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Add tests to verify "Hosts" aliases in ACL (#1304)
This commit is contained in:
		
							parent
							
								
									681c86cc95
								
							
						
					
					
						commit
						ceeef40cdf
					
				
							
								
								
									
										60
									
								
								acls.go
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								acls.go
									
									
									
									
									
								
							| @ -14,6 +14,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"github.com/tailscale/hujson" | 	"github.com/tailscale/hujson" | ||||||
|  | 	"go4.org/netipx" | ||||||
| 	"gopkg.in/yaml.v3" | 	"gopkg.in/yaml.v3" | ||||||
| 	"tailscale.com/envknob" | 	"tailscale.com/envknob" | ||||||
| 	"tailscale.com/tailcfg" | 	"tailscale.com/tailcfg" | ||||||
| @ -165,16 +166,22 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s | |||||||
| 	aclCachePeerMap := make(map[string]map[string]struct{}) | 	aclCachePeerMap := make(map[string]map[string]struct{}) | ||||||
| 	for _, rule := range rules { | 	for _, rule := range rules { | ||||||
| 		for _, srcIP := range rule.SrcIPs { | 		for _, srcIP := range rule.SrcIPs { | ||||||
| 			if data, ok := aclCachePeerMap[srcIP]; ok { | 			for _, ip := range expandACLPeerAddr(srcIP) { | ||||||
| 				for _, dstPort := range rule.DstPorts { | 				if data, ok := aclCachePeerMap[ip]; ok { | ||||||
| 					data[dstPort.IP] = struct{}{} | 					for _, dstPort := range rule.DstPorts { | ||||||
|  | 						for _, dstIP := range expandACLPeerAddr(dstPort.IP) { | ||||||
|  | 							data[dstIP] = struct{}{} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) | ||||||
|  | 					for _, dstPort := range rule.DstPorts { | ||||||
|  | 						for _, dstIP := range expandACLPeerAddr(dstPort.IP) { | ||||||
|  | 							dstPortsMap[dstIP] = struct{}{} | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					aclCachePeerMap[ip] = dstPortsMap | ||||||
| 				} | 				} | ||||||
| 			} else { |  | ||||||
| 				dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) |  | ||||||
| 				for _, dstPort := range rule.DstPorts { |  | ||||||
| 					dstPortsMap[dstPort.IP] = struct{}{} |  | ||||||
| 				} |  | ||||||
| 				aclCachePeerMap[srcIP] = dstPortsMap |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @ -184,6 +191,41 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s | |||||||
| 	return aclCachePeerMap | 	return aclCachePeerMap | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into
 | ||||||
|  | // something our cache logic can look up, which is "*" or single IP addresses.
 | ||||||
|  | // This is probably quite inefficient, but it is a result of
 | ||||||
|  | // "make it work, then make it fast", and a lot of the ACL stuff does not
 | ||||||
|  | // work, but people have tried to make it fast.
 | ||||||
|  | func expandACLPeerAddr(srcIP string) []string { | ||||||
|  | 	if ip, err := netip.ParseAddr(srcIP); err == nil { | ||||||
|  | 		return []string{ip.String()} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if cidr, err := netip.ParsePrefix(srcIP); err == nil { | ||||||
|  | 		addrs := []string{} | ||||||
|  | 
 | ||||||
|  | 		ipRange := netipx.RangeOfPrefix(cidr) | ||||||
|  | 
 | ||||||
|  | 		from := ipRange.From() | ||||||
|  | 		too := ipRange.To() | ||||||
|  | 
 | ||||||
|  | 		if from == too { | ||||||
|  | 			return []string{from.String()} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for from != too { | ||||||
|  | 			addrs = append(addrs, from.String()) | ||||||
|  | 
 | ||||||
|  | 			from = from.Next() | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return addrs | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// probably "*" or other string based "IP"
 | ||||||
|  | 	return []string{srcIP} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func generateACLRules( | func generateACLRules( | ||||||
| 	machines []Machine, | 	machines []Machine, | ||||||
| 	aclPolicy ACLPolicy, | 	aclPolicy ACLPolicy, | ||||||
|  | |||||||
							
								
								
									
										64
									
								
								acls_test.go
									
									
									
									
									
								
							
							
						
						
									
										64
									
								
								acls_test.go
									
									
									
									
									
								
							| @ -1556,3 +1556,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func Test_expandACLPeerAddr(t *testing.T) { | ||||||
|  | 	type args struct { | ||||||
|  | 		srcIP string | ||||||
|  | 	} | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name string | ||||||
|  | 		args args | ||||||
|  | 		want []string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "asterix", | ||||||
|  | 			args: args{ | ||||||
|  | 				srcIP: "*", | ||||||
|  | 			}, | ||||||
|  | 			want: []string{"*"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "ip", | ||||||
|  | 			args: args{ | ||||||
|  | 				srcIP: "10.0.0.1", | ||||||
|  | 			}, | ||||||
|  | 			want: []string{"10.0.0.1"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "ip/32", | ||||||
|  | 			args: args{ | ||||||
|  | 				srcIP: "10.0.0.1/32", | ||||||
|  | 			}, | ||||||
|  | 			want: []string{"10.0.0.1"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "ip/30", | ||||||
|  | 			args: args{ | ||||||
|  | 				srcIP: "10.0.0.1/30", | ||||||
|  | 			}, | ||||||
|  | 			want: []string{ | ||||||
|  | 				"10.0.0.0", | ||||||
|  | 				"10.0.0.1", | ||||||
|  | 				"10.0.0.2", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "ip/28", | ||||||
|  | 			args: args{ | ||||||
|  | 				srcIP: "192.168.0.128/28", | ||||||
|  | 			}, | ||||||
|  | 			want: []string{ | ||||||
|  | 				"192.168.0.128", "192.168.0.129", "192.168.0.130", | ||||||
|  | 				"192.168.0.131", "192.168.0.132", "192.168.0.133", | ||||||
|  | 				"192.168.0.134", "192.168.0.135", "192.168.0.136", | ||||||
|  | 				"192.168.0.137", "192.168.0.138", "192.168.0.139", | ||||||
|  | 				"192.168.0.140", "192.168.0.141", "192.168.0.142", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) { | ||||||
|  | 				t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ package integration | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"net/netip" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| @ -439,3 +440,214 @@ func TestACLAllowStarDst(t *testing.T) { | |||||||
| 	err = scenario.Shutdown() | 	err = scenario.Shutdown() | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // This test aims to cover cases where individual hosts are allowed and denied
 | ||||||
|  | // access based on their assigned hostname
 | ||||||
|  | // https://github.com/juanfont/headscale/issues/941
 | ||||||
|  | 
 | ||||||
|  | //	ACL = [{
 | ||||||
|  | //			"DstPorts": [{
 | ||||||
|  | //				"Bits": null,
 | ||||||
|  | //				"IP": "100.64.0.3/32",
 | ||||||
|  | //				"Ports": {
 | ||||||
|  | //					"First": 0,
 | ||||||
|  | //					"Last": 65535
 | ||||||
|  | //				}
 | ||||||
|  | //			}],
 | ||||||
|  | //			"SrcIPs": ["*"]
 | ||||||
|  | //		}, {
 | ||||||
|  | //
 | ||||||
|  | //			"DstPorts": [{
 | ||||||
|  | //				"Bits": null,
 | ||||||
|  | //				"IP": "100.64.0.2/32",
 | ||||||
|  | //				"Ports": {
 | ||||||
|  | //					"First": 0,
 | ||||||
|  | //					"Last": 65535
 | ||||||
|  | //				}
 | ||||||
|  | //			}],
 | ||||||
|  | //			"SrcIPs": ["100.64.0.1/32"]
 | ||||||
|  | //		}]
 | ||||||
|  | //
 | ||||||
|  | //	ACL Cache Map= {
 | ||||||
|  | //		"*": {
 | ||||||
|  | //			"100.64.0.3/32": {}
 | ||||||
|  | //		},
 | ||||||
|  | //		"100.64.0.1/32": {
 | ||||||
|  | //			"100.64.0.2/32": {}
 | ||||||
|  | //		}
 | ||||||
|  | //	}
 | ||||||
|  | func TestACLNamedHostsCanReach(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 
 | ||||||
|  | 	scenario := aclScenario(t, | ||||||
|  | 		headscale.ACLPolicy{ | ||||||
|  | 			Hosts: headscale.Hosts{ | ||||||
|  | 				"test1": netip.MustParsePrefix("100.64.0.1/32"), | ||||||
|  | 				"test2": netip.MustParsePrefix("100.64.0.2/32"), | ||||||
|  | 				"test3": netip.MustParsePrefix("100.64.0.3/32"), | ||||||
|  | 			}, | ||||||
|  | 			ACLs: []headscale.ACL{ | ||||||
|  | 				// Everyone can curl test3
 | ||||||
|  | 				{ | ||||||
|  | 					Action:       "accept", | ||||||
|  | 					Sources:      []string{"*"}, | ||||||
|  | 					Destinations: []string{"test3:*"}, | ||||||
|  | 				}, | ||||||
|  | 				// test1 can curl test2
 | ||||||
|  | 				{ | ||||||
|  | 					Action:       "accept", | ||||||
|  | 					Sources:      []string{"test1"}, | ||||||
|  | 					Destinations: []string{"test2:*"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	// Since user/users dont matter here, we basically expect that some clients
 | ||||||
|  | 	// will be assigned these ips and that we can pick them up for our own use.
 | ||||||
|  | 	test1ip := netip.MustParseAddr("100.64.0.1") | ||||||
|  | 	test1, err := scenario.FindTailscaleClientByIP(test1ip) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	test1fqdn, err := test1.FQDN() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) | ||||||
|  | 	test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) | ||||||
|  | 
 | ||||||
|  | 	test2ip := netip.MustParseAddr("100.64.0.2") | ||||||
|  | 	test2, err := scenario.FindTailscaleClientByIP(test2ip) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	test2fqdn, err := test2.FQDN() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) | ||||||
|  | 	test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) | ||||||
|  | 
 | ||||||
|  | 	test3ip := netip.MustParseAddr("100.64.0.3") | ||||||
|  | 	test3, err := scenario.FindTailscaleClientByIP(test3ip) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	test3fqdn, err := test3.FQDN() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	test3ipURL := fmt.Sprintf("http://%s/etc/hostname", test3ip.String()) | ||||||
|  | 	test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) | ||||||
|  | 
 | ||||||
|  | 	// test1 can query test3
 | ||||||
|  | 	result, err := test1.Curl(test3ipURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test1.Curl(test3fqdnURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	// test2 can query test3
 | ||||||
|  | 	result, err = test2.Curl(test3ipURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test2.Curl(test3fqdnURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	// test3 cannot query test1
 | ||||||
|  | 	result, err = test3.Curl(test1ipURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test3.Curl(test1fqdnURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	// test3 cannot query test2
 | ||||||
|  | 	result, err = test3.Curl(test2ipURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test3.Curl(test2fqdnURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	// test1 can query test2
 | ||||||
|  | 	result, err = test1.Curl(test2ipURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test1.Curl(test2fqdnURL) | ||||||
|  | 	assert.Len(t, result, 13) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	// test2 cannot query test1
 | ||||||
|  | 	result, err = test2.Curl(test1ipURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	result, err = test2.Curl(test1fqdnURL) | ||||||
|  | 	assert.Empty(t, result) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 
 | ||||||
|  | 	err = scenario.Shutdown() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TestACLNamedHostsCanReachBySubnet is the same as
 | ||||||
|  | // TestACLNamedHostsCanReach, but it tests if we expand a
 | ||||||
|  | // full CIDR correctly. All routes should work.
 | ||||||
|  | func TestACLNamedHostsCanReachBySubnet(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 
 | ||||||
|  | 	scenario := aclScenario(t, | ||||||
|  | 		headscale.ACLPolicy{ | ||||||
|  | 			Hosts: headscale.Hosts{ | ||||||
|  | 				"all": netip.MustParsePrefix("100.64.0.0/24"), | ||||||
|  | 			}, | ||||||
|  | 			ACLs: []headscale.ACL{ | ||||||
|  | 				// Everyone can curl test3
 | ||||||
|  | 				{ | ||||||
|  | 					Action:       "accept", | ||||||
|  | 					Sources:      []string{"*"}, | ||||||
|  | 					Destinations: []string{"all:*"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	user1Clients, err := scenario.ListTailscaleClients("user1") | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	user2Clients, err := scenario.ListTailscaleClients("user2") | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 	// Test that user1 can visit all user2
 | ||||||
|  | 	for _, client := range user1Clients { | ||||||
|  | 		for _, peer := range user2Clients { | ||||||
|  | 			fqdn, err := peer.FQDN() | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			url := fmt.Sprintf("http://%s/etc/hostname", fqdn) | ||||||
|  | 			t.Logf("url from %s to %s", client.Hostname(), url) | ||||||
|  | 
 | ||||||
|  | 			result, err := client.Curl(url) | ||||||
|  | 			assert.Len(t, result, 13) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Test that user2 can visit all user1
 | ||||||
|  | 	for _, client := range user2Clients { | ||||||
|  | 		for _, peer := range user1Clients { | ||||||
|  | 			fqdn, err := peer.FQDN() | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 
 | ||||||
|  | 			url := fmt.Sprintf("http://%s/etc/hostname", fqdn) | ||||||
|  | 			t.Logf("url from %s to %s", client.Hostname(), url) | ||||||
|  | 
 | ||||||
|  | 			result, err := client.Curl(url) | ||||||
|  | 			assert.Len(t, result, 13) | ||||||
|  | 			assert.NoError(t, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = scenario.Shutdown() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | } | ||||||
|  | |||||||
							
								
								
									
										17
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								machine.go
									
									
									
									
									
								
							| @ -170,13 +170,14 @@ func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines) | |||||||
| // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
 | // filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
 | ||||||
| func filterMachinesByACL( | func filterMachinesByACL( | ||||||
| 	machine *Machine, | 	machine *Machine, | ||||||
| 	machines []Machine, | 	machines Machines, | ||||||
| 	lock *sync.RWMutex, | 	lock *sync.RWMutex, | ||||||
| 	aclPeerCacheMap map[string]map[string]struct{}, | 	aclPeerCacheMap map[string]map[string]struct{}, | ||||||
| ) Machines { | ) Machines { | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Caller(). | 		Caller(). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("self", machine.Hostname). | ||||||
|  | 		Str("input", machines.String()). | ||||||
| 		Msg("Finding peers filtered by ACLs") | 		Msg("Finding peers filtered by ACLs") | ||||||
| 
 | 
 | ||||||
| 	peers := make(map[uint64]Machine) | 	peers := make(map[uint64]Machine) | ||||||
| @ -263,7 +264,7 @@ func filterMachinesByACL( | |||||||
| 
 | 
 | ||||||
| 	lock.RUnlock() | 	lock.RUnlock() | ||||||
| 
 | 
 | ||||||
| 	authorizedPeers := make([]Machine, 0, len(peers)) | 	authorizedPeers := make(Machines, 0, len(peers)) | ||||||
| 	for _, m := range peers { | 	for _, m := range peers { | ||||||
| 		authorizedPeers = append(authorizedPeers, m) | 		authorizedPeers = append(authorizedPeers, m) | ||||||
| 	} | 	} | ||||||
| @ -274,8 +275,9 @@ func filterMachinesByACL( | |||||||
| 
 | 
 | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Caller(). | 		Caller(). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("self", machine.Hostname). | ||||||
| 		Msgf("Found some machines: %v", machines) | 		Str("peers", authorizedPeers.String()). | ||||||
|  | 		Msg("Authorized peers") | ||||||
| 
 | 
 | ||||||
| 	return authorizedPeers | 	return authorizedPeers | ||||||
| } | } | ||||||
| @ -335,8 +337,9 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) { | |||||||
| 
 | 
 | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Caller(). | 		Caller(). | ||||||
| 		Str("machine", machine.Hostname). | 		Str("self", machine.Hostname). | ||||||
| 		Msgf("Found total peers: %s", peers.String()) | 		Str("peers", peers.String()). | ||||||
|  | 		Msg("Peers returned to caller") | ||||||
| 
 | 
 | ||||||
| 	return peers, nil | 	return peers, nil | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user