mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			241 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			241 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package cli
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/juanfont/headscale"
 | |
| 	"github.com/rs/zerolog/log"
 | |
| 	"github.com/spf13/viper"
 | |
| 	"gopkg.in/yaml.v2"
 | |
| 	"inet.af/netaddr"
 | |
| 	"tailscale.com/tailcfg"
 | |
| )
 | |
| 
 | |
| type ErrorOutput struct {
 | |
| 	Error string
 | |
| }
 | |
| 
 | |
| func LoadConfig(path string) error {
 | |
| 	viper.SetConfigName("config")
 | |
| 	if path == "" {
 | |
| 		viper.AddConfigPath("/etc/headscale/")
 | |
| 		viper.AddConfigPath("$HOME/.headscale")
 | |
| 		viper.AddConfigPath(".")
 | |
| 	} else {
 | |
| 		// For testing
 | |
| 		viper.AddConfigPath(path)
 | |
| 	}
 | |
| 	viper.AutomaticEnv()
 | |
| 
 | |
| 	viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
 | |
| 	viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
 | |
| 
 | |
| 	viper.SetDefault("ip_prefix", "100.64.0.0/10")
 | |
| 
 | |
| 	viper.SetDefault("log_level", "info")
 | |
| 
 | |
| 	viper.SetDefault("dns_config", nil)
 | |
| 
 | |
| 	err := viper.ReadInConfig()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("Fatal error reading config file: %s \n", err)
 | |
| 	}
 | |
| 
 | |
| 	// Collect any validation errors and return them all at once
 | |
| 	var errorText string
 | |
| 	if (viper.GetString("tls_letsencrypt_hostname") != "") && ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
 | |
| 		errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
 | |
| 	}
 | |
| 
 | |
| 	if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
 | |
| 		// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
 | |
| 		log.Warn().
 | |
| 			Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
 | |
| 	}
 | |
| 
 | |
| 	if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
 | |
| 		errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
 | |
| 	}
 | |
| 
 | |
| 	if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") {
 | |
| 		errorText += "Fatal config error: server_url must start with https:// or http://\n"
 | |
| 	}
 | |
| 	if errorText != "" {
 | |
| 		return errors.New(strings.TrimSuffix(errorText, "\n"))
 | |
| 	} else {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| }
 | |
| 
 | |
| func GetDNSConfig() *tailcfg.DNSConfig {
 | |
| 	if viper.IsSet("dns_config") {
 | |
| 		dnsConfig := &tailcfg.DNSConfig{}
 | |
| 
 | |
| 		if viper.IsSet("dns_config.nameservers") {
 | |
| 			nameserversStr := viper.GetStringSlice("dns_config.nameservers")
 | |
| 
 | |
| 			nameservers := make([]netaddr.IP, len(nameserversStr))
 | |
| 			resolvers := make([]tailcfg.DNSResolver, len(nameserversStr))
 | |
| 
 | |
| 			for index, nameserverStr := range nameserversStr {
 | |
| 				nameserver, err := netaddr.ParseIP(nameserverStr)
 | |
| 				if err != nil {
 | |
| 					log.Error().
 | |
| 						Str("func", "getDNSConfig").
 | |
| 						Err(err).
 | |
| 						Msgf("Could not parse nameserver IP: %s", nameserverStr)
 | |
| 				}
 | |
| 
 | |
| 				nameservers[index] = nameserver
 | |
| 				resolvers[index] = tailcfg.DNSResolver{
 | |
| 					Addr: nameserver.String(),
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			dnsConfig.Nameservers = nameservers
 | |
| 			dnsConfig.Resolvers = resolvers
 | |
| 		}
 | |
| 		if viper.IsSet("dns_config.domains") {
 | |
| 			dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
 | |
| 		}
 | |
| 
 | |
| 		return dnsConfig
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func absPath(path string) string {
 | |
| 	// If a relative path is provided, prefix it with the the directory where
 | |
| 	// the config file was found.
 | |
| 	if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
 | |
| 		dir, _ := filepath.Split(viper.ConfigFileUsed())
 | |
| 		if dir != "" {
 | |
| 			path = filepath.Join(dir, path)
 | |
| 		}
 | |
| 	}
 | |
| 	return path
 | |
| }
 | |
| 
 | |
| func getHeadscaleApp() (*headscale.Headscale, error) {
 | |
| 	derpPath := absPath(viper.GetString("derp_map_path"))
 | |
| 	derpMap, err := loadDerpMap(derpPath)
 | |
| 	if err != nil {
 | |
| 		log.Error().
 | |
| 			Str("path", derpPath).
 | |
| 			Err(err).
 | |
| 			Msg("Could not load DERP servers map file")
 | |
| 	}
 | |
| 
 | |
| 	// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
 | |
| 	// to avoid races
 | |
| 	minInactivityTimeout, _ := time.ParseDuration("65s")
 | |
| 	if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
 | |
| 		err = fmt.Errorf("ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n", viper.GetString("ephemeral_node_inactivity_timeout"), minInactivityTimeout)
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	cfg := headscale.Config{
 | |
| 		ServerURL:      viper.GetString("server_url"),
 | |
| 		Addr:           viper.GetString("listen_addr"),
 | |
| 		PrivateKeyPath: absPath(viper.GetString("private_key_path")),
 | |
| 		DerpMap:        derpMap,
 | |
| 		IPPrefix:       netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
 | |
| 
 | |
| 		EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
 | |
| 
 | |
| 		DBtype: viper.GetString("db_type"),
 | |
| 		DBpath: absPath(viper.GetString("db_path")),
 | |
| 		DBhost: viper.GetString("db_host"),
 | |
| 		DBport: viper.GetInt("db_port"),
 | |
| 		DBname: viper.GetString("db_name"),
 | |
| 		DBuser: viper.GetString("db_user"),
 | |
| 		DBpass: viper.GetString("db_pass"),
 | |
| 
 | |
| 		TLSLetsEncryptHostname:      viper.GetString("tls_letsencrypt_hostname"),
 | |
| 		TLSLetsEncryptListen:        viper.GetString("tls_letsencrypt_listen"),
 | |
| 		TLSLetsEncryptCacheDir:      absPath(viper.GetString("tls_letsencrypt_cache_dir")),
 | |
| 		TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
 | |
| 
 | |
| 		TLSCertPath: absPath(viper.GetString("tls_cert_path")),
 | |
| 		TLSKeyPath:  absPath(viper.GetString("tls_key_path")),
 | |
| 
 | |
| 		DNSConfig: GetDNSConfig(),
 | |
| 	}
 | |
| 
 | |
| 	h, err := headscale.NewHeadscale(cfg)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// We are doing this here, as in the future could be cool to have it also hot-reload
 | |
| 
 | |
| 	if viper.GetString("acl_policy_path") != "" {
 | |
| 		aclPath := absPath(viper.GetString("acl_policy_path"))
 | |
| 		err = h.LoadACLPolicy(aclPath)
 | |
| 		if err != nil {
 | |
| 			log.Error().
 | |
| 				Str("path", aclPath).
 | |
| 				Err(err).
 | |
| 				Msg("Could not load the ACL policy")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return h, nil
 | |
| }
 | |
| 
 | |
| func loadDerpMap(path string) (*tailcfg.DERPMap, error) {
 | |
| 	derpFile, err := os.Open(path)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	defer derpFile.Close()
 | |
| 	var derpMap tailcfg.DERPMap
 | |
| 	b, err := io.ReadAll(derpFile)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	err = yaml.Unmarshal(b, &derpMap)
 | |
| 	return &derpMap, err
 | |
| }
 | |
| 
 | |
| func JsonOutput(result interface{}, errResult error, outputFormat string) {
 | |
| 	var j []byte
 | |
| 	var err error
 | |
| 	switch outputFormat {
 | |
| 	case "json":
 | |
| 		if errResult != nil {
 | |
| 			j, err = json.MarshalIndent(ErrorOutput{errResult.Error()}, "", "\t")
 | |
| 			if err != nil {
 | |
| 				log.Fatal().Err(err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			j, err = json.MarshalIndent(result, "", "\t")
 | |
| 			if err != nil {
 | |
| 				log.Fatal().Err(err)
 | |
| 			}
 | |
| 		}
 | |
| 	case "json-line":
 | |
| 		if errResult != nil {
 | |
| 			j, err = json.Marshal(ErrorOutput{errResult.Error()})
 | |
| 			if err != nil {
 | |
| 				log.Fatal().Err(err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			j, err = json.Marshal(result)
 | |
| 			if err != nil {
 | |
| 				log.Fatal().Err(err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	fmt.Println(string(j))
 | |
| }
 |