1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-10-17 20:05:55 +02:00
juanfont.headscale/cmd/headscale/cli/utils.go

478 lines
13 KiB
Go
Raw Normal View History

package cli
import (
"context"
"encoding/json"
2021-06-05 11:13:28 +02:00
"errors"
"fmt"
"io/fs"
2021-10-22 18:55:14 +02:00
"net/url"
"os"
"path/filepath"
2021-10-18 21:27:52 +02:00
"regexp"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
2021-08-05 19:26:49 +02:00
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"google.golang.org/grpc"
"gopkg.in/yaml.v2"
"inet.af/netaddr"
"tailscale.com/tailcfg"
2021-09-14 23:46:16 +02:00
"tailscale.com/types/dnstype"
)
const (
	// PermissionFallback is the file mode GetFileMode returns when the
	// configured value cannot be parsed.
	PermissionFallback = 0o700
)
2021-06-05 11:13:28 +02:00
func LoadConfig(path string) error {
viper.SetConfigName("config")
if path == "" {
viper.AddConfigPath("/etc/headscale/")
viper.AddConfigPath("$HOME/.headscale")
viper.AddConfigPath(".")
} else {
// For testing
viper.AddConfigPath(path)
}
viper.SetEnvPrefix("headscale")
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
2021-06-05 11:13:28 +02:00
viper.AutomaticEnv()
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
2022-01-29 20:15:33 +01:00
viper.SetDefault("tls_client_auth_mode", "disabled")
2021-06-05 11:13:28 +02:00
viper.SetDefault("ip_prefix", "100.64.0.0/10")
2021-08-20 18:15:07 +02:00
viper.SetDefault("log_level", "info")
2021-08-05 20:19:25 +02:00
2021-08-24 08:09:47 +02:00
viper.SetDefault("dns_config", nil)
viper.SetDefault("unix_socket", "/var/run/headscale.sock")
viper.SetDefault("unix_socket_permission", "0o770")
viper.SetDefault("cli.insecure", false)
viper.SetDefault("cli.timeout", "5s")
2021-11-14 18:09:22 +01:00
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
2021-06-05 11:13:28 +02:00
}
// Collect any validation errors and return them all at once
var errorText string
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
}
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
2021-08-05 19:26:49 +02:00
log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
2021-06-05 11:13:28 +02:00
}
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") &&
(viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
}
2021-10-22 18:55:14 +02:00
if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
!strings.HasPrefix(viper.GetString("server_url"), "https://") {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: server_url must start with https:// or http://\n"
}
2022-01-29 20:15:33 +01:00
clientAuthMode := viper.GetString("tls_client_auth_mode")
if clientAuthMode != "disabled" && clientAuthMode != "relaxed" && clientAuthMode != "enforced" {
errorText += "Invalid tls_client_auth_mode supplied. Accepted values: disabled, relaxed, enforced."
}
2021-06-05 11:13:28 +02:00
if errorText != "" {
2021-11-15 20:18:14 +01:00
//nolint
2021-06-05 11:13:28 +02:00
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
2021-10-22 18:55:14 +02:00
}
func GetDERPConfig() headscale.DERPConfig {
urlStrs := viper.GetStringSlice("derp.urls")
urls := make([]url.URL, len(urlStrs))
for index, urlStr := range urlStrs {
urlAddr, err := url.Parse(urlStr)
if err != nil {
log.Error().
Str("url", urlStr).
Err(err).
Msg("Failed to parse url, ignoring...")
}
urls[index] = *urlAddr
}
2021-08-24 08:09:47 +02:00
2021-10-22 18:55:14 +02:00
paths := viper.GetStringSlice("derp.paths")
autoUpdate := viper.GetBool("derp.auto_update_enabled")
updateFrequency := viper.GetDuration("derp.update_frequency")
return headscale.DERPConfig{
URLs: urls,
Paths: paths,
AutoUpdate: autoUpdate,
UpdateFrequency: updateFrequency,
}
2021-08-24 08:09:47 +02:00
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
2021-08-24 08:09:47 +02:00
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := make([]netaddr.IP, len(nameserversStr))
2021-09-14 23:46:16 +02:00
resolvers := make([]dnstype.Resolver, len(nameserversStr))
2021-08-24 08:09:47 +02:00
for index, nameserverStr := range nameserversStr {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers[index] = nameserver
2021-09-14 23:46:16 +02:00
resolvers[index] = dnstype.Resolver{
2021-08-25 19:43:13 +02:00
Addr: nameserver.String(),
2021-08-25 08:04:48 +02:00
}
2021-08-24 08:09:47 +02:00
}
dnsConfig.Nameservers = nameservers
2021-08-25 08:04:48 +02:00
dnsConfig.Resolvers = resolvers
2021-08-24 08:09:47 +02:00
}
if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
2021-11-13 09:36:45 +01:00
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
2021-11-13 09:36:45 +01:00
restrictedResolvers := make(
[]dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
}
restrictedResolvers[index] = dnstype.Resolver{
Addr: nameserver.String(),
}
}
dnsConfig.Routes[domain] = restrictedResolvers
}
} else {
log.Warn().
Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
}
}
2021-08-24 08:09:47 +02:00
if viper.IsSet("dns_config.domains") {
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
}
if viper.IsSet("dns_config.magic_dns") {
magicDNS := viper.GetBool("dns_config.magic_dns")
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Proxied = magicDNS
} else if magicDNS {
log.Warn().
Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
}
2021-09-28 00:22:29 +02:00
}
var baseDomain string
if viper.IsSet("dns_config.base_domain") {
baseDomain = viper.GetString("dns_config.base_domain")
} else {
baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
}
2021-10-02 13:03:08 +02:00
return dnsConfig, baseDomain
2021-08-24 08:09:47 +02:00
}
return nil, ""
2021-06-05 11:13:28 +02:00
}
func absPath(path string) string {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}
2021-11-14 16:46:09 +01:00
return path
}
func getHeadscaleConfig() headscale.Config {
dnsConfig, baseDomain := GetDNSConfig()
2021-10-22 18:55:14 +02:00
derpConfig := GetDERPConfig()
return headscale.Config{
ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"),
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain,
2021-10-22 18:55:14 +02:00
DERP: derpConfig,
2021-11-13 09:36:45 +01:00
EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
DBtype: viper.GetString("db_type"),
2021-05-19 01:28:47 +02:00
DBpath: absPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
2021-11-13 09:36:45 +01:00
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
2022-01-29 20:15:33 +01:00
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
TLSClientAuthMode: viper.GetString("tls_client_auth_mode"),
2021-08-24 08:09:47 +02:00
DNSConfig: dnsConfig,
2021-10-03 22:02:44 +02:00
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
UnixSocket: viper.GetString("unix_socket"),
UnixSocketPermission: GetFileMode("unix_socket_permission"),
2021-10-31 10:40:43 +01:00
2021-10-18 21:27:52 +02:00
OIDC: headscale.OIDCConfig{
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"),
},
2021-10-08 11:43:52 +02:00
CLI: headscale.CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Insecure: viper.GetBool("cli.insecure"),
Timeout: viper.GetDuration("cli.timeout"),
},
}
}
func getHeadscaleApp() (*headscale.Headscale, error) {
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
2021-11-15 20:18:14 +01:00
// TODO: Find a better way to return this text
//nolint
err := fmt.Errorf(
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout,
)
2021-11-14 16:46:09 +01:00
return nil, err
}
cfg := getHeadscaleConfig()
2021-10-18 21:27:52 +02:00
cfg.OIDC.MatchMap = loadOIDCMatchMap()
app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
2021-07-04 13:24:05 +02:00
// We are doing this here, as in the future could be cool to have it also hot-reload
2021-07-11 15:10:11 +02:00
if viper.GetString("acl_policy_path") != "" {
2021-08-05 19:26:49 +02:00
aclPath := absPath(viper.GetString("acl_policy_path"))
err = app.LoadACLPolicy(aclPath)
2021-07-11 15:10:11 +02:00
if err != nil {
2021-08-05 19:26:49 +02:00
log.Error().
2021-08-05 21:57:47 +02:00
Str("path", aclPath).
2021-08-05 19:26:49 +02:00
Err(err).
Msg("Could not load the ACL policy")
2021-07-11 15:10:11 +02:00
}
2021-07-04 13:24:05 +02:00
}
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg := getHeadscaleConfig()
log.Debug().
Dur("timeout", cfg.CLI.Timeout).
Msgf("Setting timeout")
ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
grpcOptions := []grpc.DialOption{
grpc.WithBlock(),
}
address := cfg.CLI.Address
// If the address is not set, we assume that we are on the server hosting headscale.
if address == "" {
log.Debug().
Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
address = cfg.UnixSocket
grpcOptions = append(
grpcOptions,
grpc.WithInsecure(),
2021-10-30 16:29:03 +02:00
grpc.WithContextDialer(headscale.GrpcSocketDialer),
)
} else {
// If we are not connecting to a local server, require an API key for authentication
apiKey := cfg.CLI.APIKey
if apiKey == "" {
log.Fatal().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
}
grpcOptions = append(grpcOptions,
grpc.WithPerRPCCredentials(tokenAuth{
token: apiKey,
}),
)
if cfg.CLI.Insecure {
grpcOptions = append(grpcOptions, grpc.WithInsecure())
}
}
log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
2021-10-29 19:36:11 +02:00
conn, err := grpc.DialContext(ctx, address, grpcOptions...)
if err != nil {
log.Fatal().Err(err).Msgf("Could not connect: %v", err)
}
client := v1.NewHeadscaleServiceClient(conn)
return ctx, client, conn, cancel
}
func SuccessOutput(result interface{}, override string, outputFormat string) {
var j []byte
var err error
switch outputFormat {
case "json":
j, err = json.MarshalIndent(result, "", "\t")
if err != nil {
log.Fatal().Err(err)
}
case "json-line":
j, err = json.Marshal(result)
if err != nil {
log.Fatal().Err(err)
}
case "yaml":
j, err = yaml.Marshal(result)
if err != nil {
log.Fatal().Err(err)
}
default:
2021-11-15 19:36:02 +01:00
//nolint
fmt.Println(override)
2021-11-14 16:46:09 +01:00
return
}
2021-11-15 19:36:02 +01:00
//nolint
fmt.Println(string(j))
}
// ErrorOutput reports errResult through SuccessOutput, wrapping its message in
// an object with a single "error" field for the machine-readable formats.
// override is the text printed for any other output format.
func ErrorOutput(errResult error, override string, outputFormat string) {
	payload := struct {
		Error string `json:"error"`
	}{
		Error: errResult.Error(),
	}

	SuccessOutput(payload, override, outputFormat)
}
// HasMachineOutputFlag reports whether any command-line argument is exactly
// one of the machine-readable output format names: "json", "json-line" or
// "yaml".
func HasMachineOutputFlag() bool {
	for _, arg := range os.Args {
		switch arg {
		case "json", "json-line", "yaml":
			return true
		}
	}

	return false
}
2021-10-29 19:08:21 +02:00
// tokenAuth carries an API key and supplies it as per-RPC credentials; it is
// passed to grpc.WithPerRPCCredentials when connecting to a remote server.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the metadata attached to each request: a single
// "authorization" entry holding the token as a bearer credential. It never
// returns an error; ctx and in are not used.
func (t tokenAuth) GetRequestMetadata(
	ctx context.Context,
	in ...string,
) (map[string]string, error) {
	metadata := make(map[string]string, 1)
	metadata["authorization"] = "Bearer " + t.token

	return metadata, nil
}

// RequireTransportSecurity always reports true.
func (tokenAuth) RequireTransportSecurity() bool {
	return true
}
2021-10-31 10:40:43 +01:00
2021-10-18 21:27:52 +02:00
// loadOIDCMatchMap is a wrapper around viper that reads oidc.domain_map and
// verifies that every key of the map is a valid regular expression.
//
// An invalid expression is a configuration error: it is logged together with
// the offending matcher and the process exits, instead of panicking with a
// stack trace. The map is returned unchanged.
func loadOIDCMatchMap() map[string]string {
	strMap := viper.GetStringMapString("oidc.domain_map")

	for oidcMatcher := range strMap {
		if _, err := regexp.Compile(oidcMatcher); err != nil {
			log.Fatal().
				Err(err).
				Str("matcher", oidcMatcher).
				Msg("Invalid regular expression in oidc.domain_map")
		}
	}

	return strMap
}
// GetFileMode reads the configuration value stored under key and interprets
// it as an octal file mode.
//
// The value may be written with or without a leading "0o"/"0O" prefix (for
// example "0o770" or "0770"). PermissionFallback is returned when the value
// cannot be parsed.
func GetFileMode(key string) fs.FileMode {
	modeStr := viper.GetString(key)

	// strconv only accepts the "0o" prefix when asked to auto-detect the base;
	// with an explicit base it is a syntax error, so strip it first.
	if strings.HasPrefix(modeStr, "0o") || strings.HasPrefix(modeStr, "0O") {
		modeStr = modeStr[2:]
	}

	mode, err := strconv.ParseUint(modeStr, headscale.Base8, headscale.BitSize64)
	if err != nil {
		return PermissionFallback
	}

	return fs.FileMode(mode)
}