mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	make parse destination string into a func
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									68f040a89c
								
							
						
					
					
						commit
						1c9c472d2c
					
				@ -375,9 +375,39 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
 | 
			
		||||
	machines types.Machines,
 | 
			
		||||
	needsWildcard bool,
 | 
			
		||||
) ([]tailcfg.NetPortRange, error) {
 | 
			
		||||
	var tokens []string
 | 
			
		||||
	alias, port, err := parseDestination(dest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Trace().Str("destination", dest).Msg("generating policy destination")
 | 
			
		||||
	expanded, err := pol.ExpandAlias(
 | 
			
		||||
		machines,
 | 
			
		||||
		alias,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ports, err := expandPorts(port, needsWildcard)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dests := []tailcfg.NetPortRange{}
 | 
			
		||||
	for _, dest := range expanded.Prefixes() {
 | 
			
		||||
		for _, port := range *ports {
 | 
			
		||||
			pr := tailcfg.NetPortRange{
 | 
			
		||||
				IP:    dest.String(),
 | 
			
		||||
				Ports: port,
 | 
			
		||||
			}
 | 
			
		||||
			dests = append(dests, pr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dests, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseDestination(dest string) (string, string, error) {
 | 
			
		||||
	var tokens []string
 | 
			
		||||
 | 
			
		||||
	// Check if there is a IPv4/6:Port combination, IPv6 has more than
 | 
			
		||||
	// three ":".
 | 
			
		||||
@ -397,7 +427,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
 | 
			
		||||
		if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
 | 
			
		||||
			log.Trace().Err(err).Msg("trying to parse as IPv6")
 | 
			
		||||
 | 
			
		||||
			return nil, fmt.Errorf(
 | 
			
		||||
			return "", "", fmt.Errorf(
 | 
			
		||||
				"failed to parse destination, tokens %v: %w",
 | 
			
		||||
				tokens,
 | 
			
		||||
				ErrInvalidPortFormat,
 | 
			
		||||
@ -407,8 +437,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Trace().Strs("tokens", tokens).Msg("generating policy destination")
 | 
			
		||||
 | 
			
		||||
	var alias string
 | 
			
		||||
	// We can have here stuff like:
 | 
			
		||||
	// git-server:*
 | 
			
		||||
@ -424,30 +452,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
 | 
			
		||||
		alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expanded, err := pol.ExpandAlias(
 | 
			
		||||
		machines,
 | 
			
		||||
		alias,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dests := []tailcfg.NetPortRange{}
 | 
			
		||||
	for _, dest := range expanded.Prefixes() {
 | 
			
		||||
		for _, port := range *ports {
 | 
			
		||||
			pr := tailcfg.NetPortRange{
 | 
			
		||||
				IP:    dest.String(),
 | 
			
		||||
				Ports: port,
 | 
			
		||||
			}
 | 
			
		||||
			dests = append(dests, pr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dests, nil
 | 
			
		||||
	return alias, tokens[len(tokens)-1], nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// parseProtocol reads the proto field of the ACL and generates a list of
 | 
			
		||||
 | 
			
		||||
@ -2557,3 +2557,66 @@ func TestSSHRules(t *testing.T) {
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestParseDestination(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		dest      string
 | 
			
		||||
		wantAlias string
 | 
			
		||||
		wantPort  string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "git-server:*",
 | 
			
		||||
			wantAlias: "git-server",
 | 
			
		||||
			wantPort:  "*",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "192.168.1.0/24:22",
 | 
			
		||||
			wantAlias: "192.168.1.0/24",
 | 
			
		||||
			wantPort:  "22",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "192.168.1.1:22",
 | 
			
		||||
			wantAlias: "192.168.1.1",
 | 
			
		||||
			wantPort:  "22",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "fd7a:115c:a1e0::2:22",
 | 
			
		||||
			wantAlias: "fd7a:115c:a1e0::2",
 | 
			
		||||
			wantPort:  "22",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "fd7a:115c:a1e0::2/128:22",
 | 
			
		||||
			wantAlias: "fd7a:115c:a1e0::2/128",
 | 
			
		||||
			wantPort:  "22",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "tag:montreal-webserver:80,443",
 | 
			
		||||
			wantAlias: "tag:montreal-webserver",
 | 
			
		||||
			wantPort:  "80,443",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "tag:api-server:443",
 | 
			
		||||
			wantAlias: "tag:api-server",
 | 
			
		||||
			wantPort:  "443",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dest:      "example-host-1:*",
 | 
			
		||||
			wantAlias: "example-host-1",
 | 
			
		||||
			wantPort:  "*",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.dest, func(t *testing.T) {
 | 
			
		||||
			alias, port, _ := parseDestination(tt.dest)
 | 
			
		||||
 | 
			
		||||
			if alias != tt.wantAlias {
 | 
			
		||||
				t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if port != tt.wantPort {
 | 
			
		||||
				t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user