1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-04 00:09:34 +01:00
juanfont.headscale/cmd/headscale/cli/utils.go

566 lines
16 KiB
Go
Raw Normal View History

package cli
import (
"context"
"crypto/tls"
"encoding/json"
2021-06-05 11:13:28 +02:00
"errors"
"fmt"
"io/fs"
2021-10-22 18:55:14 +02:00
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
2021-08-05 19:26:49 +02:00
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"google.golang.org/grpc"
2022-02-12 20:08:33 +01:00
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"gopkg.in/yaml.v2"
"inet.af/netaddr"
"tailscale.com/tailcfg"
2021-09-14 23:46:16 +02:00
"tailscale.com/types/dnstype"
)
const (
	// PermissionFallback is the file mode GetFileMode returns when the
	// configured permission string cannot be parsed.
	PermissionFallback = 0o700

	// HeadscaleDateTimeFormat is a time.Format layout (Go reference-time
	// syntax) producing "YYYY-MM-DD hh:mm:ss" timestamps.
	HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
)
2021-06-05 11:13:28 +02:00
func LoadConfig(path string) error {
viper.SetConfigName("config")
if path == "" {
viper.AddConfigPath("/etc/headscale/")
viper.AddConfigPath("$HOME/.headscale")
viper.AddConfigPath(".")
} else {
// For testing
viper.AddConfigPath(path)
}
viper.SetEnvPrefix("headscale")
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
2021-06-05 11:13:28 +02:00
viper.AutomaticEnv()
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
2022-01-31 13:22:17 +01:00
viper.SetDefault("tls_client_auth_mode", "relaxed")
2021-06-05 11:13:28 +02:00
2021-08-20 18:15:07 +02:00
viper.SetDefault("log_level", "info")
2021-08-05 20:19:25 +02:00
2021-08-24 08:09:47 +02:00
viper.SetDefault("dns_config", nil)
viper.SetDefault("derp.server.enabled", false)
viper.SetDefault("derp.server.stun.enabled", true)
viper.SetDefault("unix_socket", "/var/run/headscale.sock")
viper.SetDefault("unix_socket_permission", "0o770")
2022-02-12 17:14:33 +01:00
viper.SetDefault("grpc_listen_addr", ":50443")
2022-02-13 10:08:46 +01:00
viper.SetDefault("grpc_allow_insecure", false)
2022-02-12 17:14:33 +01:00
viper.SetDefault("cli.timeout", "5s")
viper.SetDefault("cli.insecure", false)
viper.SetDefault("oidc.strip_email_domain", true)
2021-11-14 18:09:22 +01:00
if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err)
2021-06-05 11:13:28 +02:00
}
// Collect any validation errors and return them all at once
var errorText string
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
}
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
2021-08-05 19:26:49 +02:00
log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
2021-06-05 11:13:28 +02:00
}
2021-10-22 18:55:14 +02:00
if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") &&
(viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
}
2021-10-22 18:55:14 +02:00
if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
!strings.HasPrefix(viper.GetString("server_url"), "https://") {
2021-06-05 11:13:28 +02:00
errorText += "Fatal config error: server_url must start with https:// or http://\n"
}
2022-02-24 12:10:40 +01:00
_, authModeValid := headscale.LookupTLSClientAuthMode(
viper.GetString("tls_client_auth_mode"),
)
2022-02-20 15:06:14 +01:00
if !authModeValid {
2022-02-21 16:09:23 +01:00
errorText += fmt.Sprintf(
"Invalid tls_client_auth_mode supplied: %s. Accepted values: %s, %s, %s.",
viper.GetString("tls_client_auth_mode"),
headscale.DisabledClientAuth,
headscale.RelaxedClientAuth,
headscale.EnforcedClientAuth)
2022-01-29 20:15:33 +01:00
}
2021-06-05 11:13:28 +02:00
if errorText != "" {
2021-11-15 20:18:14 +01:00
//nolint
2021-06-05 11:13:28 +02:00
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
2021-10-22 18:55:14 +02:00
}
func GetDERPConfig() headscale.DERPConfig {
2022-03-06 17:25:21 +01:00
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
serverRegionName := viper.GetString("derp.server.region_name")
stunAddr := viper.GetString("derp.server.stun_listen_addr")
if serverEnabled && stunAddr == "" {
log.Fatal().Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
}
2021-10-22 18:55:14 +02:00
urlStrs := viper.GetStringSlice("derp.urls")
urls := make([]url.URL, len(urlStrs))
for index, urlStr := range urlStrs {
urlAddr, err := url.Parse(urlStr)
if err != nil {
log.Error().
Str("url", urlStr).
Err(err).
Msg("Failed to parse url, ignoring...")
}
urls[index] = *urlAddr
}
2021-08-24 08:09:47 +02:00
2021-10-22 18:55:14 +02:00
paths := viper.GetStringSlice("derp.paths")
autoUpdate := viper.GetBool("derp.auto_update_enabled")
updateFrequency := viper.GetDuration("derp.update_frequency")
return headscale.DERPConfig{
2022-03-06 17:25:21 +01:00
ServerEnabled: serverEnabled,
ServerRegionID: serverRegionID,
ServerRegionCode: serverRegionCode,
ServerRegionName: serverRegionName,
STUNAddr: stunAddr,
URLs: urls,
Paths: paths,
AutoUpdate: autoUpdate,
UpdateFrequency: updateFrequency,
2021-10-22 18:55:14 +02:00
}
2021-08-24 08:09:47 +02:00
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
2021-08-24 08:09:47 +02:00
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := make([]netaddr.IP, len(nameserversStr))
2021-09-14 23:46:16 +02:00
resolvers := make([]dnstype.Resolver, len(nameserversStr))
2021-08-24 08:09:47 +02:00
for index, nameserverStr := range nameserversStr {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers[index] = nameserver
2021-09-14 23:46:16 +02:00
resolvers[index] = dnstype.Resolver{
2021-08-25 19:43:13 +02:00
Addr: nameserver.String(),
2021-08-25 08:04:48 +02:00
}
2021-08-24 08:09:47 +02:00
}
dnsConfig.Nameservers = nameservers
2021-08-25 08:04:48 +02:00
dnsConfig.Resolvers = resolvers
2021-08-24 08:09:47 +02:00
}
if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
2021-11-13 09:36:45 +01:00
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
2021-11-13 09:36:45 +01:00
restrictedResolvers := make(
[]dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
}
restrictedResolvers[index] = dnstype.Resolver{
Addr: nameserver.String(),
}
}
dnsConfig.Routes[domain] = restrictedResolvers
}
} else {
log.Warn().
Msg("Warning: dns_config.restricted_nameservers is set, but no nameservers are configured. Ignoring restricted_nameservers.")
}
}
2021-08-24 08:09:47 +02:00
if viper.IsSet("dns_config.domains") {
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
}
if viper.IsSet("dns_config.magic_dns") {
magicDNS := viper.GetBool("dns_config.magic_dns")
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Proxied = magicDNS
} else if magicDNS {
log.Warn().
Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
}
2021-09-28 00:22:29 +02:00
}
var baseDomain string
if viper.IsSet("dns_config.base_domain") {
baseDomain = viper.GetString("dns_config.base_domain")
} else {
baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
}
2021-10-02 13:03:08 +02:00
return dnsConfig, baseDomain
2021-08-24 08:09:47 +02:00
}
return nil, ""
2021-06-05 11:13:28 +02:00
}
func absPath(path string) string {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}
2021-11-14 16:46:09 +01:00
return path
}
func getHeadscaleConfig() headscale.Config {
dnsConfig, baseDomain := GetDNSConfig()
2021-10-22 18:55:14 +02:00
derpConfig := GetDERPConfig()
2022-01-16 14:16:59 +01:00
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
legacyPrefixField := viper.GetString("ip_prefix")
if len(legacyPrefixField) > 0 {
log.
Warn().
Msgf(
"%s, %s",
"use of 'ip_prefix' for configuration is deprecated",
"please see 'ip_prefixes' in the shipped example.",
)
legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefix: %w", err))
}
parsedPrefixes = append(parsedPrefixes, legacyPrefix)
}
2022-01-16 14:16:59 +01:00
for i, prefixInConfig := range configuredPrefixes {
prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
}
parsedPrefixes = append(parsedPrefixes, prefix)
}
prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes))
{
// dedup
normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
for i, p := range parsedPrefixes {
normalized, _ := p.Range().Prefix()
normalizedPrefixes[normalized.String()] = i
}
// convert back to list
for _, i := range normalizedPrefixes {
prefixes = append(prefixes, parsedPrefixes[i])
}
}
if len(prefixes) < 1 {
prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10"))
2022-01-25 23:11:15 +01:00
log.Warn().
Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
2022-01-16 14:16:59 +01:00
}
2022-02-21 16:09:23 +01:00
tlsClientAuthMode, _ := headscale.LookupTLSClientAuthMode(
viper.GetString("tls_client_auth_mode"),
)
2022-02-20 15:06:14 +01:00
return headscale.Config{
2022-02-13 10:08:46 +01:00
ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"),
MetricsAddr: viper.GetString("metrics_listen_addr"),
2022-02-13 10:08:46 +01:00
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
2022-01-16 14:16:59 +01:00
IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain,
2021-10-22 18:55:14 +02:00
DERP: derpConfig,
2021-11-13 09:36:45 +01:00
EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
DBtype: viper.GetString("db_type"),
2021-05-19 01:28:47 +02:00
DBpath: absPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
2021-11-13 09:36:45 +01:00
TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
2022-01-29 20:15:33 +01:00
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
2022-02-20 15:06:14 +01:00
TLSClientAuthMode: tlsClientAuthMode,
2021-08-24 08:09:47 +02:00
DNSConfig: dnsConfig,
2021-10-03 22:02:44 +02:00
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
UnixSocket: viper.GetString("unix_socket"),
UnixSocketPermission: GetFileMode("unix_socket_permission"),
2021-10-31 10:40:43 +01:00
2021-10-18 21:27:52 +02:00
OIDC: headscale.OIDCConfig{
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
2021-10-18 21:27:52 +02:00
},
2021-10-08 11:43:52 +02:00
CLI: headscale.CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}
}
func getHeadscaleApp() (*headscale.Headscale, error) {
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
2021-11-15 20:18:14 +01:00
// TODO: Find a better way to return this text
//nolint
err := fmt.Errorf(
"ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout,
)
2021-11-14 16:46:09 +01:00
return nil, err
}
cfg := getHeadscaleConfig()
app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
2021-07-04 13:24:05 +02:00
// We are doing this here, as in the future could be cool to have it also hot-reload
2021-07-11 15:10:11 +02:00
if viper.GetString("acl_policy_path") != "" {
2021-08-05 19:26:49 +02:00
aclPath := absPath(viper.GetString("acl_policy_path"))
err = app.LoadACLPolicy(aclPath)
2021-07-11 15:10:11 +02:00
if err != nil {
log.Fatal().
2021-08-05 21:57:47 +02:00
Str("path", aclPath).
2021-08-05 19:26:49 +02:00
Err(err).
Msg("Could not load the ACL policy")
2021-07-11 15:10:11 +02:00
}
2021-07-04 13:24:05 +02:00
}
return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg := getHeadscaleConfig()
log.Debug().
Dur("timeout", cfg.CLI.Timeout).
Msgf("Setting timeout")
ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
grpcOptions := []grpc.DialOption{
grpc.WithBlock(),
}
address := cfg.CLI.Address
// If the address is not set, we assume that we are on the server hosting headscale.
if address == "" {
log.Debug().
Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
address = cfg.UnixSocket
grpcOptions = append(
grpcOptions,
2022-02-12 20:08:33 +01:00
grpc.WithTransportCredentials(insecure.NewCredentials()),
2021-10-30 16:29:03 +02:00
grpc.WithContextDialer(headscale.GrpcSocketDialer),
)
} else {
// If we are not connecting to a local server, require an API key for authentication
apiKey := cfg.CLI.APIKey
if apiKey == "" {
2022-01-25 23:11:15 +01:00
log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
}
grpcOptions = append(grpcOptions,
grpc.WithPerRPCCredentials(tokenAuth{
token: apiKey,
}),
)
if cfg.CLI.Insecure {
tlsConfig := &tls.Config{
2022-02-13 09:46:35 +01:00
// turn of gosec as we are intentionally setting
// insecure.
//nolint:gosec
InsecureSkipVerify: true,
}
grpcOptions = append(grpcOptions,
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
} else {
grpcOptions = append(grpcOptions,
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")),
)
}
}
log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
2021-10-29 19:36:11 +02:00
conn, err := grpc.DialContext(ctx, address, grpcOptions...)
if err != nil {
2022-01-25 23:11:15 +01:00
log.Fatal().Caller().Err(err).Msgf("Could not connect: %v", err)
}
client := v1.NewHeadscaleServiceClient(conn)
return ctx, client, conn, cancel
}
func SuccessOutput(result interface{}, override string, outputFormat string) {
2022-02-12 20:08:41 +01:00
var jsonBytes []byte
var err error
switch outputFormat {
case "json":
2022-02-12 20:08:41 +01:00
jsonBytes, err = json.MarshalIndent(result, "", "\t")
if err != nil {
log.Fatal().Err(err)
}
case "json-line":
2022-02-12 20:08:41 +01:00
jsonBytes, err = json.Marshal(result)
if err != nil {
log.Fatal().Err(err)
}
case "yaml":
2022-02-12 20:08:41 +01:00
jsonBytes, err = yaml.Marshal(result)
if err != nil {
log.Fatal().Err(err)
}
default:
2021-11-15 19:36:02 +01:00
//nolint
fmt.Println(override)
2021-11-14 16:46:09 +01:00
return
}
2021-11-15 19:36:02 +01:00
//nolint
2022-02-12 20:08:41 +01:00
fmt.Println(string(jsonBytes))
}
// ErrorOutput reports errResult via SuccessOutput. For the machine-readable
// formats it emits an object with a single "error" field holding
// errResult's message; otherwise override is printed. errResult must be
// non-nil, since its Error method is always called.
func ErrorOutput(errResult error, override string, outputFormat string) {
	type errOutput struct {
		Error string `json:"error"`
	}

	payload := errOutput{Error: errResult.Error()}
	SuccessOutput(payload, override, outputFormat)
}
// HasMachineOutputFlag reports whether any element of os.Args is exactly one
// of the machine-readable output format names "json", "json-line" or "yaml".
// Any argument with one of those values matches, not just an output flag's
// value.
func HasMachineOutputFlag() bool {
	for _, arg := range os.Args {
		switch arg {
		case "json", "json-line", "yaml":
			return true
		}
	}

	return false
}
2021-10-29 19:08:21 +02:00
// tokenAuth is a gRPC per-RPC credential that attaches an API key as a
// bearer token to every request.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the headers added to each request: a single
// "authorization" entry of the form "Bearer <token>". The context and URI
// arguments are not used, and the error is always nil.
func (t tokenAuth) GetRequestMetadata(
	ctx context.Context,
	in ...string,
) (map[string]string, error) {
	bearer := "Bearer " + t.token
	metadata := map[string]string{"authorization": bearer}

	return metadata, nil
}

// RequireTransportSecurity reports that the token must only be sent over a
// secure transport.
func (tokenAuth) RequireTransportSecurity() bool {
	return true
}
2021-10-31 10:40:43 +01:00
// GetFileMode reads the configuration value at key as an octal permission,
// written with or without Go's "0o" prefix (e.g. "0770", "770", "0o770"),
// and returns it as an fs.FileMode. Values that are empty, malformed or do
// not fit in 32 bits are logged and replaced by PermissionFallback.
func GetFileMode(key string) fs.FileMode {
	// fs.FileMode is a uint32, so parse at that width to avoid silent
	// truncation on conversion.
	const fileModeBitSize = 32

	rawMode := viper.GetString(key)

	// strconv.ParseUint only understands the "0o" prefix with base 0; with an
	// explicit base 8 it rejects it, so strip it before parsing.
	modeStr := rawMode
	for _, prefix := range []string{"0o", "0O"} {
		if strings.HasPrefix(modeStr, prefix) {
			modeStr = modeStr[len(prefix):]

			break
		}
	}

	mode, err := strconv.ParseUint(modeStr, headscale.Base8, fileModeBitSize)
	if err != nil {
		log.Warn().
			Str("key", key).
			Str("value", rawMode).
			Err(err).
			Msgf("Invalid file mode, falling back to %#o", PermissionFallback)

		return PermissionFallback
	}

	return fs.FileMode(mode)
}