package cli

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/juanfont/headscale"
	apiV1 "github.com/juanfont/headscale/gen/go/headscale/v1"
	"github.com/rs/zerolog/log"
	"github.com/spf13/viper"
	"google.golang.org/grpc"
	"inet.af/netaddr"
	"tailscale.com/tailcfg"
	"tailscale.com/types/dnstype"
)

// ErrorOutput is the JSON document emitted by JsonOutput when a command
// fails; Error carries the error's message.
type ErrorOutput struct {
	Error string
}

// LoadConfig configures viper (search paths, environment overrides and
// defaults), reads the file named "config" and validates the TLS and
// server_url settings.
//
// When path is empty, /etc/headscale/, $HOME/.headscale and the current
// directory are searched; otherwise only path is searched (used by tests).
// A read failure is returned wrapped. Validation failures are collected and
// returned together as a single error, one message per line.
func LoadConfig(path string) error {
	viper.SetConfigName("config")
	if path == "" {
		viper.AddConfigPath("/etc/headscale/")
		viper.AddConfigPath("$HOME/.headscale")
		viper.AddConfigPath(".")
	} else {
		// For testing
		viper.AddConfigPath(path)
	}
	viper.AutomaticEnv()

	viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
	viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")

	viper.SetDefault("ip_prefix", "100.64.0.0/10")

	viper.SetDefault("log_level", "info")

	viper.SetDefault("dns_config", nil)

	viper.SetDefault("unix_socket", "/var/run/headscale.sock")

	if err := viper.ReadInConfig(); err != nil {
		// The message text is kept as-is for existing callers; %w keeps the
		// underlying error reachable through errors.Is / errors.As.
		return fmt.Errorf("Fatal error reading config file: %w \n", err)
	}

	letsEncryptHostname := viper.GetString("tls_letsencrypt_hostname")
	challengeType := viper.GetString("tls_letsencrypt_challenge_type")
	serverURL := viper.GetString("server_url")

	// Collect any validation errors and return them all at once.
	var problems []string

	if letsEncryptHostname != "" &&
		(viper.GetString("tls_cert_path") != "" || viper.GetString("tls_key_path") != "") {
		problems = append(
			problems,
			"Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both",
		)
	}

	if letsEncryptHostname != "" &&
		challengeType == "TLS-ALPN-01" &&
		!strings.HasSuffix(viper.GetString("listen_addr"), ":443") {
		// This is only a warning because there could be something sitting in
		// front of headscale that redirects the traffic (e.g. an iptables rule).
		log.Warn().
			Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
	}

	if challengeType != "HTTP-01" && challengeType != "TLS-ALPN-01" {
		problems = append(
			problems,
			"Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01",
		)
	}

	if !strings.HasPrefix(serverURL, "http://") && !strings.HasPrefix(serverURL, "https://") {
		problems = append(
			problems,
			"Fatal config error: server_url must start with https:// or http://",
		)
	}

	if len(problems) > 0 {
		return errors.New(strings.Join(problems, "\n"))
	}

	return nil
}

// GetDERPConfig builds the DERP configuration from the derp.* settings in
// viper. Entries of derp.urls that fail to parse are logged and left out of
// the returned configuration.
func GetDERPConfig() headscale.DERPConfig {
	urlStrs := viper.GetStringSlice("derp.urls")

	urls := make([]url.URL, 0, len(urlStrs))
	for _, urlStr := range urlStrs {
		urlAddr, err := url.Parse(urlStr)
		if err != nil {
			log.Error().
				Str("url", urlStr).
				Err(err).
				Msg("Failed to parse url, ignoring...")

			// url.Parse returns a nil *url.URL on error; dereferencing it
			// would panic, so the entry really has to be skipped.
			continue
		}

		urls = append(urls, *urlAddr)
	}

	return headscale.DERPConfig{
		URLs:            urls,
		Paths:           viper.GetStringSlice("derp.paths"),
		AutoUpdate:      viper.GetBool("derp.auto_update_enabled"),
		UpdateFrequency: viper.GetDuration("derp.update_frequency"),
	}
}

// parseNameserverIPs parses every entry of values as an IP address. Entries
// that do not parse are logged (kind is inserted into the log message to tell
// the call sites apart) and omitted from the result.
func parseNameserverIPs(values []string, kind string) []netaddr.IP {
	ips := make([]netaddr.IP, 0, len(values))
	for _, value := range values {
		ip, err := netaddr.ParseIP(value)
		if err != nil {
			log.Error().
				Str("func", "getDNSConfig").
				Err(err).
				Msgf("Could not parse %s IP: %s", kind, value)

			// Do not hand a zero-value IP on to the DNS configuration.
			continue
		}

		ips = append(ips, ip)
	}

	return ips
}

// GetDNSConfig builds the DNS configuration from the dns_config.* settings in
// viper and returns it together with the base domain.
//
// It returns (nil, "") when dns_config is not set. The base domain falls back
// to "headscale.net" when dns_config.base_domain is not set. Nameserver
// entries that are not valid IP addresses are logged and skipped.
// restricted_nameservers and magic_dns only take effect when at least one
// valid nameserver is configured.
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
	if !viper.IsSet("dns_config") {
		return nil, ""
	}

	dnsConfig := &tailcfg.DNSConfig{}

	if viper.IsSet("dns_config.nameservers") {
		nameservers := parseNameserverIPs(
			viper.GetStringSlice("dns_config.nameservers"),
			"nameserver",
		)

		resolvers := make([]dnstype.Resolver, 0, len(nameservers))
		for _, nameserver := range nameservers {
			resolvers = append(resolvers, dnstype.Resolver{
				Addr: nameserver.String(),
			})
		}

		dnsConfig.Nameservers = nameservers
		dnsConfig.Resolvers = resolvers
	}

	if viper.IsSet("dns_config.restricted_nameservers") {
		if len(dnsConfig.Nameservers) > 0 {
			dnsConfig.Routes = make(map[string][]dnstype.Resolver)
			restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
			for domain, restrictedNameservers := range restrictedDNS {
				restrictedIPs := parseNameserverIPs(restrictedNameservers, "restricted nameserver")

				restrictedResolvers := make([]dnstype.Resolver, 0, len(restrictedIPs))
				for _, nameserver := range restrictedIPs {
					restrictedResolvers = append(restrictedResolvers, dnstype.Resolver{
						Addr: nameserver.String(),
					})
				}

				dnsConfig.Routes[domain] = restrictedResolvers
			}
		} else {
			log.Warn().
				Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
		}
	}

	if viper.IsSet("dns_config.domains") {
		dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
	}

	if viper.IsSet("dns_config.magic_dns") {
		magicDNS := viper.GetBool("dns_config.magic_dns")
		if len(dnsConfig.Nameservers) > 0 {
			dnsConfig.Proxied = magicDNS
		} else if magicDNS {
			log.Warn().
				Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
		}
	}

	baseDomain := "headscale.net" // does not really matter when MagicDNS is not enabled
	if viper.IsSet("dns_config.base_domain") {
		baseDomain = viper.GetString("dns_config.base_domain")
	}

	return dnsConfig, baseDomain
}

// absPath resolves a path taken from the configuration. An empty path, or a
// path that starts with the OS path separator, is returned unchanged.
func absPath(path string) string {
	// If a relative path is provided, prefix it with the directory where
	// the config file was found.
	if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
		dir, _ := filepath.Split(viper.ConfigFileUsed())
		if dir != "" {
			path = filepath.Join(dir, path)
		}
	}

	return path
}

// getHeadscaleConfig assembles a headscale.Config from the values currently
// held by viper. File-system paths are resolved with absPath.
//
// Note: netaddr.MustParseIPPrefix panics if ip_prefix is not a valid prefix.
func getHeadscaleConfig() headscale.Config {
	// maxMachineRegistrationDuration is the maximum time headscale will allow a client to (optionally) request for
	// the machine key expiry time. RegisterRequests with Expiry times that are more than
	// maxMachineRegistrationDuration in the future will be clamped to (now + maxMachineRegistrationDuration).
	//
	// Use 10h here because it is the length of a standard business day plus a small amount of leeway.
	maxMachineRegistrationDuration := 10 * time.Hour
	if configured := viper.GetDuration("max_machine_registration_duration"); configured >= time.Second {
		maxMachineRegistrationDuration = configured
	}

	// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
	// specified by the tailscale client. It is the default amount of time a machine registration is valid for
	// (ie the amount of time before the user has to re-authenticate when requesting a connection).
	//
	// Use 8h here because it's the length of a standard business day.
	defaultMachineRegistrationDuration := 8 * time.Hour
	if configured := viper.GetDuration("default_machine_registration_duration"); configured >= time.Second {
		defaultMachineRegistrationDuration = configured
	}

	dnsConfig, baseDomain := GetDNSConfig()
	derpConfig := GetDERPConfig()

	return headscale.Config{
		ServerURL:      viper.GetString("server_url"),
		Addr:           viper.GetString("listen_addr"),
		PrivateKeyPath: absPath(viper.GetString("private_key_path")),
		IPPrefix:       netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
		BaseDomain:     baseDomain,

		DERP: derpConfig,

		EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),

		DBtype: viper.GetString("db_type"),
		DBpath: absPath(viper.GetString("db_path")),
		DBhost: viper.GetString("db_host"),
		DBport: viper.GetInt("db_port"),
		DBname: viper.GetString("db_name"),
		DBuser: viper.GetString("db_user"),
		DBpass: viper.GetString("db_pass"),

		TLSLetsEncryptHostname:      viper.GetString("tls_letsencrypt_hostname"),
		TLSLetsEncryptListen:        viper.GetString("tls_letsencrypt_listen"),
		TLSLetsEncryptCacheDir:      absPath(viper.GetString("tls_letsencrypt_cache_dir")),
		TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),

		TLSCertPath: absPath(viper.GetString("tls_cert_path")),
		TLSKeyPath:  absPath(viper.GetString("tls_key_path")),

		DNSConfig: dnsConfig,

		ACMEEmail: viper.GetString("acme_email"),
		ACMEURL:   viper.GetString("acme_url"),

		UnixSocket: viper.GetString("unix_socket"),

		OIDC: headscale.OIDCConfig{
			Issuer:       viper.GetString("oidc.issuer"),
			ClientID:     viper.GetString("oidc.client_id"),
			ClientSecret: viper.GetString("oidc.client_secret"),
		},

		MaxMachineRegistrationDuration:     maxMachineRegistrationDuration,
		DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration,
	}
}

// getHeadscaleApp validates ephemeral_node_inactivity_timeout, builds the
// configuration and creates the Headscale instance.
//
// It returns an error when the timeout is not above the minimum or when
// headscale.NewHeadscale fails. A failure to load the ACL policy is only
// logged; the instance is still returned.
func getHeadscaleApp() (*headscale.Headscale, error) {
	// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
	// to avoid races
	const minInactivityTimeout = 65 * time.Second
	if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
		return nil, fmt.Errorf(
			"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n",
			viper.GetString("ephemeral_node_inactivity_timeout"),
			minInactivityTimeout,
		)
	}

	cfg := getHeadscaleConfig()

	cfg.OIDC.MatchMap = loadOIDCMatchMap()

	h, err := headscale.NewHeadscale(cfg)
	if err != nil {
		return nil, err
	}

	// We are doing this here, as in the future could be cool to have it also hot-reload
	if viper.GetString("acl_policy_path") != "" {
		aclPath := absPath(viper.GetString("acl_policy_path"))
		if err := h.LoadACLPolicy(aclPath); err != nil {
			log.Error().
				Str("path", aclPath).
				Err(err).
				Msg("Could not load the ACL policy")
		}
	}

	return h, nil
}

// getHeadscaleGRPCClient dials the headscale gRPC API and returns the service
// client together with the underlying connection.
//
// When HEADSCALE_ADDRESS is unset, the configured unix socket is used without
// transport security. Otherwise HEADSCALE_API_KEY must be set and is attached
// to every RPC as a bearer token; HEADSCALE_INSECURE (parsed as a bool)
// disables transport security. Any failure terminates the process through
// log.Fatal. The dial blocks (grpc.WithBlock) until it succeeds or fails.
func getHeadscaleGRPCClient(ctx context.Context) (apiV1.HeadscaleServiceClient, *grpc.ClientConn) {
	grpcOptions := []grpc.DialOption{
		grpc.WithBlock(),
	}

	address := os.Getenv("HEADSCALE_ADDRESS")

	// If the address is not set, we assume that we are on the server hosting headscale.
	if address == "" {
		cfg := getHeadscaleConfig()

		log.Debug().
			Str("socket", cfg.UnixSocket).
			Msg("HEADSCALE_ADDRESS environment is not set, connecting to unix socket.")

		address = cfg.UnixSocket

		grpcOptions = append(
			grpcOptions,
			grpc.WithInsecure(),
			grpc.WithContextDialer(headscale.GrpcSocketDialer),
		)
	} else {
		// If we are not connecting to a local server, require an API key for authentication
		apiKey := os.Getenv("HEADSCALE_API_KEY")
		if apiKey == "" {
			log.Fatal().Msg("HEADSCALE_API_KEY environment variable needs to be set.")
		}

		grpcOptions = append(grpcOptions,
			grpc.WithPerRPCCredentials(tokenAuth{
				token: apiKey,
			}),
		)

		// NOTE(review): no transport credentials option is added here unless
		// HEADSCALE_INSECURE is true, while tokenAuth requires transport
		// security — confirm how TLS is expected to be set up for remote
		// connections.
		if insecureStr := os.Getenv("HEADSCALE_INSECURE"); insecureStr != "" {
			insecure, err := strconv.ParseBool(insecureStr)
			if err != nil {
				log.Fatal().Err(err).Msgf("Failed to parse HEADSCALE_INSECURE: %v", err)
			}

			if insecure {
				grpcOptions = append(grpcOptions, grpc.WithInsecure())
			}
		}
	}

	log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")

	conn, err := grpc.DialContext(ctx, address, grpcOptions...)
	if err != nil {
		log.Fatal().Err(err).Msgf("Could not connect: %v", err)
	}

	client := apiV1.NewHeadscaleServiceClient(conn)

	return client, conn
}

// JsonOutput prints result, or errResult wrapped in ErrorOutput when it is
// non-nil, to stdout as JSON.
//
// outputFormat "json" produces tab-indented output and "json-line" compact
// output; for any other value nothing is marshalled and an empty line is
// printed. A marshalling failure terminates the process through log.Fatal.
func JsonOutput(result interface{}, errResult error, outputFormat string) {
	payload := result
	if errResult != nil {
		payload = ErrorOutput{Error: errResult.Error()}
	}

	var (
		encoded []byte
		err     error
	)

	switch outputFormat {
	case "json":
		encoded, err = json.MarshalIndent(payload, "", "\t")
	case "json-line":
		encoded, err = json.Marshal(payload)
	}

	if err != nil {
		// A zerolog event is only emitted (and Fatal only exits) once it is
		// finished with Msg/Send.
		log.Fatal().Err(err).Msg("Failed to marshal output to JSON")
	}

	fmt.Println(string(encoded))
}

// HasJsonOutputFlag reports whether any element of os.Args is exactly "json"
// or "json-line".
func HasJsonOutputFlag() bool {
	for _, arg := range os.Args {
		if arg == "json" || arg == "json-line" {
			return true
		}
	}

	return false
}

// tokenAuth is a per-RPC credential that sends token as a bearer token.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the authorization header for a request.
// Return value is mapped to request headers.
func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
	return map[string]string{
		"authorization": "Bearer " + t.token,
	}, nil
}

// RequireTransportSecurity reports that this credential requires transport
// security.
func (tokenAuth) RequireTransportSecurity() bool {
	return true
}

// loadOIDCMatchMap is a wrapper around viper that verifies that the keys of
// the oidc.domain_map setting are valid regular expressions and returns the
// map.
//
// Note: an invalid expression makes regexp.MustCompile panic.
func loadOIDCMatchMap() map[string]string {
	strMap := viper.GetStringMapString("oidc.domain_map")

	for oidcMatcher := range strMap {
		_ = regexp.MustCompile(oidcMatcher)
	}

	return strMap
}