mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			470 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			470 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package cli
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io/fs"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/juanfont/headscale"
 | |
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | |
| 	"github.com/rs/zerolog/log"
 | |
| 	"github.com/spf13/viper"
 | |
| 	"google.golang.org/grpc"
 | |
| 	"gopkg.in/yaml.v2"
 | |
| 	"inet.af/netaddr"
 | |
| 	"tailscale.com/tailcfg"
 | |
| 	"tailscale.com/types/dnstype"
 | |
| )
 | |
| 
 | |
// PermissionFallback is the file mode GetFileMode returns when the configured
// permission string cannot be parsed as an octal number.
const PermissionFallback = 0o700
 | |
| 
 | |
| func LoadConfig(path string) error {
 | |
| 	viper.SetConfigName("config")
 | |
| 	if path == "" {
 | |
| 		viper.AddConfigPath("/etc/headscale/")
 | |
| 		viper.AddConfigPath("$HOME/.headscale")
 | |
| 		viper.AddConfigPath(".")
 | |
| 	} else {
 | |
| 		// For testing
 | |
| 		viper.AddConfigPath(path)
 | |
| 	}
 | |
| 
 | |
| 	viper.SetEnvPrefix("headscale")
 | |
| 	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
 | |
| 	viper.AutomaticEnv()
 | |
| 
 | |
| 	viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
 | |
| 	viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
 | |
| 
 | |
| 	viper.SetDefault("ip_prefix", "100.64.0.0/10")
 | |
| 
 | |
| 	viper.SetDefault("log_level", "info")
 | |
| 
 | |
| 	viper.SetDefault("dns_config", nil)
 | |
| 
 | |
| 	viper.SetDefault("unix_socket", "/var/run/headscale.sock")
 | |
| 	viper.SetDefault("unix_socket_permission", "0o770")
 | |
| 
 | |
| 	viper.SetDefault("cli.insecure", false)
 | |
| 	viper.SetDefault("cli.timeout", "5s")
 | |
| 
 | |
| 	if err := viper.ReadInConfig(); err != nil {
 | |
| 		return fmt.Errorf("fatal error reading config file: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Collect any validation errors and return them all at once
 | |
| 	var errorText string
 | |
| 	if (viper.GetString("tls_letsencrypt_hostname") != "") &&
 | |
| 		((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
 | |
| 		errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
 | |
| 	}
 | |
| 
 | |
| 	if (viper.GetString("tls_letsencrypt_hostname") != "") &&
 | |
| 		(viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") &&
 | |
| 		(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
 | |
| 		// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
 | |
| 		log.Warn().
 | |
| 			Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
 | |
| 	}
 | |
| 
 | |
| 	if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") &&
 | |
| 		(viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
 | |
| 		errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
 | |
| 	}
 | |
| 
 | |
| 	if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
 | |
| 		!strings.HasPrefix(viper.GetString("server_url"), "https://") {
 | |
| 		errorText += "Fatal config error: server_url must start with https:// or http://\n"
 | |
| 	}
 | |
| 	if errorText != "" {
 | |
| 		//nolint
 | |
| 		return errors.New(strings.TrimSuffix(errorText, "\n"))
 | |
| 	} else {
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func GetDERPConfig() headscale.DERPConfig {
 | |
| 	urlStrs := viper.GetStringSlice("derp.urls")
 | |
| 
 | |
| 	urls := make([]url.URL, len(urlStrs))
 | |
| 	for index, urlStr := range urlStrs {
 | |
| 		urlAddr, err := url.Parse(urlStr)
 | |
| 		if err != nil {
 | |
| 			log.Error().
 | |
| 				Str("url", urlStr).
 | |
| 				Err(err).
 | |
| 				Msg("Failed to parse url, ignoring...")
 | |
| 		}
 | |
| 
 | |
| 		urls[index] = *urlAddr
 | |
| 	}
 | |
| 
 | |
| 	paths := viper.GetStringSlice("derp.paths")
 | |
| 
 | |
| 	autoUpdate := viper.GetBool("derp.auto_update_enabled")
 | |
| 	updateFrequency := viper.GetDuration("derp.update_frequency")
 | |
| 
 | |
| 	return headscale.DERPConfig{
 | |
| 		URLs:            urls,
 | |
| 		Paths:           paths,
 | |
| 		AutoUpdate:      autoUpdate,
 | |
| 		UpdateFrequency: updateFrequency,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func GetDNSConfig() (*tailcfg.DNSConfig, string) {
 | |
| 	if viper.IsSet("dns_config") {
 | |
| 		dnsConfig := &tailcfg.DNSConfig{}
 | |
| 
 | |
| 		if viper.IsSet("dns_config.nameservers") {
 | |
| 			nameserversStr := viper.GetStringSlice("dns_config.nameservers")
 | |
| 
 | |
| 			nameservers := make([]netaddr.IP, len(nameserversStr))
 | |
| 			resolvers := make([]dnstype.Resolver, len(nameserversStr))
 | |
| 
 | |
| 			for index, nameserverStr := range nameserversStr {
 | |
| 				nameserver, err := netaddr.ParseIP(nameserverStr)
 | |
| 				if err != nil {
 | |
| 					log.Error().
 | |
| 						Str("func", "getDNSConfig").
 | |
| 						Err(err).
 | |
| 						Msgf("Could not parse nameserver IP: %s", nameserverStr)
 | |
| 				}
 | |
| 
 | |
| 				nameservers[index] = nameserver
 | |
| 				resolvers[index] = dnstype.Resolver{
 | |
| 					Addr: nameserver.String(),
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			dnsConfig.Nameservers = nameservers
 | |
| 			dnsConfig.Resolvers = resolvers
 | |
| 		}
 | |
| 
 | |
| 		if viper.IsSet("dns_config.restricted_nameservers") {
 | |
| 			if len(dnsConfig.Nameservers) > 0 {
 | |
| 				dnsConfig.Routes = make(map[string][]dnstype.Resolver)
 | |
| 				restrictedDNS := viper.GetStringMapStringSlice(
 | |
| 					"dns_config.restricted_nameservers",
 | |
| 				)
 | |
| 				for domain, restrictedNameservers := range restrictedDNS {
 | |
| 					restrictedResolvers := make(
 | |
| 						[]dnstype.Resolver,
 | |
| 						len(restrictedNameservers),
 | |
| 					)
 | |
| 					for index, nameserverStr := range restrictedNameservers {
 | |
| 						nameserver, err := netaddr.ParseIP(nameserverStr)
 | |
| 						if err != nil {
 | |
| 							log.Error().
 | |
| 								Str("func", "getDNSConfig").
 | |
| 								Err(err).
 | |
| 								Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
 | |
| 						}
 | |
| 						restrictedResolvers[index] = dnstype.Resolver{
 | |
| 							Addr: nameserver.String(),
 | |
| 						}
 | |
| 					}
 | |
| 					dnsConfig.Routes[domain] = restrictedResolvers
 | |
| 				}
 | |
| 			} else {
 | |
| 				log.Warn().
 | |
| 					Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if viper.IsSet("dns_config.domains") {
 | |
| 			dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
 | |
| 		}
 | |
| 
 | |
| 		if viper.IsSet("dns_config.magic_dns") {
 | |
| 			magicDNS := viper.GetBool("dns_config.magic_dns")
 | |
| 			if len(dnsConfig.Nameservers) > 0 {
 | |
| 				dnsConfig.Proxied = magicDNS
 | |
| 			} else if magicDNS {
 | |
| 				log.Warn().
 | |
| 					Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		var baseDomain string
 | |
| 		if viper.IsSet("dns_config.base_domain") {
 | |
| 			baseDomain = viper.GetString("dns_config.base_domain")
 | |
| 		} else {
 | |
| 			baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
 | |
| 		}
 | |
| 
 | |
| 		return dnsConfig, baseDomain
 | |
| 	}
 | |
| 
 | |
| 	return nil, ""
 | |
| }
 | |
| 
 | |
| func absPath(path string) string {
 | |
| 	// If a relative path is provided, prefix it with the the directory where
 | |
| 	// the config file was found.
 | |
| 	if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
 | |
| 		dir, _ := filepath.Split(viper.ConfigFileUsed())
 | |
| 		if dir != "" {
 | |
| 			path = filepath.Join(dir, path)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return path
 | |
| }
 | |
| 
 | |
| func getHeadscaleConfig() headscale.Config {
 | |
| 	dnsConfig, baseDomain := GetDNSConfig()
 | |
| 	derpConfig := GetDERPConfig()
 | |
| 
 | |
| 	return headscale.Config{
 | |
| 		ServerURL:      viper.GetString("server_url"),
 | |
| 		Addr:           viper.GetString("listen_addr"),
 | |
| 		IPPrefix:       netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
 | |
| 		PrivateKeyPath: absPath(viper.GetString("private_key_path")),
 | |
| 		BaseDomain:     baseDomain,
 | |
| 
 | |
| 		DERP: derpConfig,
 | |
| 
 | |
| 		EphemeralNodeInactivityTimeout: viper.GetDuration(
 | |
| 			"ephemeral_node_inactivity_timeout",
 | |
| 		),
 | |
| 
 | |
| 		DBtype: viper.GetString("db_type"),
 | |
| 		DBpath: absPath(viper.GetString("db_path")),
 | |
| 		DBhost: viper.GetString("db_host"),
 | |
| 		DBport: viper.GetInt("db_port"),
 | |
| 		DBname: viper.GetString("db_name"),
 | |
| 		DBuser: viper.GetString("db_user"),
 | |
| 		DBpass: viper.GetString("db_pass"),
 | |
| 
 | |
| 		TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
 | |
| 		TLSLetsEncryptListen:   viper.GetString("tls_letsencrypt_listen"),
 | |
| 		TLSLetsEncryptCacheDir: absPath(
 | |
| 			viper.GetString("tls_letsencrypt_cache_dir"),
 | |
| 		),
 | |
| 		TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
 | |
| 
 | |
| 		TLSCertPath: absPath(viper.GetString("tls_cert_path")),
 | |
| 		TLSKeyPath:  absPath(viper.GetString("tls_key_path")),
 | |
| 
 | |
| 		DNSConfig: dnsConfig,
 | |
| 
 | |
| 		ACMEEmail: viper.GetString("acme_email"),
 | |
| 		ACMEURL:   viper.GetString("acme_url"),
 | |
| 
 | |
| 		UnixSocket:           viper.GetString("unix_socket"),
 | |
| 		UnixSocketPermission: GetFileMode("unix_socket_permission"),
 | |
| 
 | |
| 		OIDC: headscale.OIDCConfig{
 | |
| 			Issuer:       viper.GetString("oidc.issuer"),
 | |
| 			ClientID:     viper.GetString("oidc.client_id"),
 | |
| 			ClientSecret: viper.GetString("oidc.client_secret"),
 | |
| 		},
 | |
| 
 | |
| 		CLI: headscale.CLIConfig{
 | |
| 			Address:  viper.GetString("cli.address"),
 | |
| 			APIKey:   viper.GetString("cli.api_key"),
 | |
| 			Insecure: viper.GetBool("cli.insecure"),
 | |
| 			Timeout:  viper.GetDuration("cli.timeout"),
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func getHeadscaleApp() (*headscale.Headscale, error) {
 | |
| 	// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
 | |
| 	// to avoid races
 | |
| 	minInactivityTimeout, _ := time.ParseDuration("65s")
 | |
| 	if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
 | |
| 		// TODO: Find a better way to return this text
 | |
| 		//nolint
 | |
| 		err := fmt.Errorf(
 | |
| 			"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
 | |
| 			viper.GetString("ephemeral_node_inactivity_timeout"),
 | |
| 			minInactivityTimeout,
 | |
| 		)
 | |
| 
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	cfg := getHeadscaleConfig()
 | |
| 
 | |
| 	cfg.OIDC.MatchMap = loadOIDCMatchMap()
 | |
| 
 | |
| 	app, err := headscale.NewHeadscale(cfg)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// We are doing this here, as in the future could be cool to have it also hot-reload
 | |
| 
 | |
| 	if viper.GetString("acl_policy_path") != "" {
 | |
| 		aclPath := absPath(viper.GetString("acl_policy_path"))
 | |
| 		err = app.LoadACLPolicy(aclPath)
 | |
| 		if err != nil {
 | |
| 			log.Error().
 | |
| 				Str("path", aclPath).
 | |
| 				Err(err).
 | |
| 				Msg("Could not load the ACL policy")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return app, nil
 | |
| }
 | |
| 
 | |
| func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
 | |
| 	cfg := getHeadscaleConfig()
 | |
| 
 | |
| 	log.Debug().
 | |
| 		Dur("timeout", cfg.CLI.Timeout).
 | |
| 		Msgf("Setting timeout")
 | |
| 
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
 | |
| 
 | |
| 	grpcOptions := []grpc.DialOption{
 | |
| 		grpc.WithBlock(),
 | |
| 	}
 | |
| 
 | |
| 	address := cfg.CLI.Address
 | |
| 
 | |
| 	// If the address is not set, we assume that we are on the server hosting headscale.
 | |
| 	if address == "" {
 | |
| 		log.Debug().
 | |
| 			Str("socket", cfg.UnixSocket).
 | |
| 			Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
 | |
| 
 | |
| 		address = cfg.UnixSocket
 | |
| 
 | |
| 		grpcOptions = append(
 | |
| 			grpcOptions,
 | |
| 			grpc.WithInsecure(),
 | |
| 			grpc.WithContextDialer(headscale.GrpcSocketDialer),
 | |
| 		)
 | |
| 	} else {
 | |
| 		// If we are not connecting to a local server, require an API key for authentication
 | |
| 		apiKey := cfg.CLI.APIKey
 | |
| 		if apiKey == "" {
 | |
| 			log.Fatal().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
 | |
| 		}
 | |
| 		grpcOptions = append(grpcOptions,
 | |
| 			grpc.WithPerRPCCredentials(tokenAuth{
 | |
| 				token: apiKey,
 | |
| 			}),
 | |
| 		)
 | |
| 
 | |
| 		if cfg.CLI.Insecure {
 | |
| 			grpcOptions = append(grpcOptions, grpc.WithInsecure())
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
 | |
| 	conn, err := grpc.DialContext(ctx, address, grpcOptions...)
 | |
| 	if err != nil {
 | |
| 		log.Fatal().Err(err).Msgf("Could not connect: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	client := v1.NewHeadscaleServiceClient(conn)
 | |
| 
 | |
| 	return ctx, client, conn, cancel
 | |
| }
 | |
| 
 | |
| func SuccessOutput(result interface{}, override string, outputFormat string) {
 | |
| 	var j []byte
 | |
| 	var err error
 | |
| 	switch outputFormat {
 | |
| 	case "json":
 | |
| 		j, err = json.MarshalIndent(result, "", "\t")
 | |
| 		if err != nil {
 | |
| 			log.Fatal().Err(err)
 | |
| 		}
 | |
| 	case "json-line":
 | |
| 		j, err = json.Marshal(result)
 | |
| 		if err != nil {
 | |
| 			log.Fatal().Err(err)
 | |
| 		}
 | |
| 	case "yaml":
 | |
| 		j, err = yaml.Marshal(result)
 | |
| 		if err != nil {
 | |
| 			log.Fatal().Err(err)
 | |
| 		}
 | |
| 	default:
 | |
| 		//nolint
 | |
| 		fmt.Println(override)
 | |
| 
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	//nolint
 | |
| 	fmt.Println(string(j))
 | |
| }
 | |
| 
 | |
| func ErrorOutput(errResult error, override string, outputFormat string) {
 | |
| 	type errOutput struct {
 | |
| 		Error string `json:"error"`
 | |
| 	}
 | |
| 
 | |
| 	SuccessOutput(errOutput{errResult.Error()}, override, outputFormat)
 | |
| }
 | |
| 
 | |
// HasMachineOutputFlag reports whether any command-line argument is exactly
// "json", "json-line" or "yaml". It compares whole arguments in os.Args only,
// so a combined form such as "--output=json" does not match.
func HasMachineOutputFlag() bool {
	for i := range os.Args {
		switch os.Args[i] {
		case "json", "json-line", "yaml":
			return true
		}
	}

	return false
}
 | |
| 
 | |
// tokenAuth carries the API key that is sent with every RPC. It is handed to
// grpc.WithPerRPCCredentials when the CLI talks to a remote server.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the metadata that is mapped to request headers:
// a single "authorization" entry holding the token as a bearer credential.
// Neither ctx nor in is consulted, and the returned error is always nil.
func (t tokenAuth) GetRequestMetadata(
	ctx context.Context,
	in ...string,
) (map[string]string, error) {
	metadata := make(map[string]string, 1)
	metadata["authorization"] = "Bearer " + t.token

	return metadata, nil
}

// RequireTransportSecurity always reports true.
// NOTE(review): gRPC presumably rejects such credentials on a connection
// dialed with grpc.WithInsecure — confirm against the cli.insecure path.
func (tokenAuth) RequireTransportSecurity() bool {
	const required = true

	return required
}
 | |
| 
 | |
| // loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
 | |
| // the match map is valid regex strings.
 | |
| func loadOIDCMatchMap() map[string]string {
 | |
| 	strMap := viper.GetStringMapString("oidc.domain_map")
 | |
| 
 | |
| 	for oidcMatcher := range strMap {
 | |
| 		_ = regexp.MustCompile(oidcMatcher)
 | |
| 	}
 | |
| 
 | |
| 	return strMap
 | |
| }
 | |
| 
 | |
| func GetFileMode(key string) fs.FileMode {
 | |
| 	modeStr := viper.GetString(key)
 | |
| 
 | |
| 	mode, err := strconv.ParseUint(modeStr, headscale.Base8, headscale.BitSize64)
 | |
| 	if err != nil {
 | |
| 		return PermissionFallback
 | |
| 	}
 | |
| 
 | |
| 	return fs.FileMode(mode)
 | |
| }
 |