1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-22 00:11:47 +01:00

move Config definitions into types

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-06 10:23:39 +02:00 committed by Kristoffer Dalby
parent c72401a99b
commit 2289a2acbf
8 changed files with 34 additions and 31 deletions

View File

@ -5,7 +5,7 @@ import (
"os" "os"
"runtime" "runtime"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -38,18 +38,18 @@ func initConfig() {
cfgFile = os.Getenv("HEADSCALE_CONFIG") cfgFile = os.Getenv("HEADSCALE_CONFIG")
} }
if cfgFile != "" { if cfgFile != "" {
err := hscontrol.LoadConfig(cfgFile, true) err := types.LoadConfig(cfgFile, true)
if err != nil { if err != nil {
log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile) log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile)
} }
} else { } else {
err := hscontrol.LoadConfig("", false) err := types.LoadConfig("", false)
if err != nil { if err != nil {
log.Fatal().Caller().Err(err).Msgf("Error loading config") log.Fatal().Caller().Err(err).Msgf("Error loading config")
} }
} }
cfg, err := hscontrol.GetHeadscaleConfig() cfg, err := types.GetHeadscaleConfig()
if err != nil { if err != nil {
log.Fatal().Caller().Err(err) log.Fatal().Caller().Err(err)
} }
@ -64,7 +64,7 @@ func initConfig() {
zerolog.SetGlobalLevel(zerolog.Disabled) zerolog.SetGlobalLevel(zerolog.Disabled)
} }
if cfg.Log.Format == hscontrol.JSONLogFormat { if cfg.Log.Format == types.JSONLogFormat {
log.Logger = log.Output(os.Stdout) log.Logger = log.Output(os.Stdout)
} }

View File

@ -11,6 +11,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -25,7 +26,7 @@ const (
) )
func getHeadscaleApp() (*hscontrol.Headscale, error) { func getHeadscaleApp() (*hscontrol.Headscale, error) {
cfg, err := hscontrol.GetHeadscaleConfig() cfg, err := types.GetHeadscaleConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w", "failed to load configuration while creating headscale instance: %w",
@ -57,7 +58,7 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
} }
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg, err := hscontrol.GetHeadscaleConfig() cfg, err := types.GetHeadscaleConfig()
if err != nil { if err != nil {
log.Fatal(). log.Fatal().
Err(err). Err(err).

View File

@ -7,7 +7,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/spf13/viper" "github.com/spf13/viper"
"gopkg.in/check.v1" "gopkg.in/check.v1"
@ -51,7 +51,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
} }
// Load example config, it should load without validation errors // Load example config, it should load without validation errors
err = hscontrol.LoadConfig(cfgFile, true) err = types.LoadConfig(cfgFile, true)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Test that config file was interpreted correctly // Test that config file was interpreted correctly
@ -94,7 +94,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
} }
// Load example config, it should load without validation errors // Load example config, it should load without validation errors
err = hscontrol.LoadConfig(tmpDir, false) err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Test that config file was interpreted correctly // Test that config file was interpreted correctly
@ -138,10 +138,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
} }
// Load example config, it should load without validation errors // Load example config, it should load without validation errors
err = hscontrol.LoadConfig(tmpDir, false) err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig, baseDomain := hscontrol.GetDNSConfig() dnsConfig, baseDomain := types.GetDNSConfig()
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
@ -173,7 +173,7 @@ noise:
writeConfig(c, tmpDir, configYaml) writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1) // Check configuration validation errors (1)
err = hscontrol.LoadConfig(tmpDir, false) err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
// check.Matches can not handle multiline strings // check.Matches can not handle multiline strings
tmp := strings.ReplaceAll(err.Error(), "\n", "***") tmp := strings.ReplaceAll(err.Error(), "\n", "***")
@ -202,6 +202,6 @@ tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01 tls_letsencrypt_challenge_type: TLS-ALPN-01
`) `)
writeConfig(c, tmpDir, configYaml) writeConfig(c, tmpDir, configYaml)
err = hscontrol.LoadConfig(tmpDir, false) err = types.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }

View File

@ -75,7 +75,7 @@ const (
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg *Config cfg *types.Config
db *db.HSDatabase db *db.HSDatabase
dbString string dbString string
dbType string dbType string
@ -102,7 +102,7 @@ type Headscale struct {
cancelStateUpdateChan chan struct{} cancelStateUpdateChan chan struct{}
} }
func NewHeadscale(cfg *Config) (*Headscale, error) { func NewHeadscale(cfg *types.Config) (*Headscale, error) {
privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err) return nil, fmt.Errorf("failed to read or create private key: %w", err)
@ -778,13 +778,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
} }
switch h.cfg.TLS.LetsEncrypt.ChallengeType { switch h.cfg.TLS.LetsEncrypt.ChallengeType {
case tlsALPN01ChallengeType: case types.TlsALPN01ChallengeType:
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale // The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443. // must be reachable on port 443.
return certManager.TLSConfig(), nil return certManager.TLSConfig(), nil
case http01ChallengeType: case types.Http01ChallengeType:
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.

View File

@ -9,6 +9,7 @@ import (
"os" "os"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -80,7 +81,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
return &result return &result
} }
func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap { func GetDERPMap(cfg types.DERPConfig) *tailcfg.DERPMap {
derpMaps := make([]*tailcfg.DERPMap, 0) derpMaps := make([]*tailcfg.DERPMap, 0)
for _, path := range cfg.Paths { for _, path := range cfg.Paths {

View File

@ -5,6 +5,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/juanfont/headscale/hscontrol/types"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@ -38,7 +39,7 @@ func (s *Suite) ResetDB(c *check.C) {
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{ cfg := types.Config{
PrivateKeyPath: tmpDir + "/private.key", PrivateKeyPath: tmpDir + "/private.key",
NoisePrivateKeyPath: tmpDir + "/noise_private.key", NoisePrivateKeyPath: tmpDir + "/noise_private.key",
DBtype: "sqlite3", DBtype: "sqlite3",
@ -46,7 +47,7 @@ func (s *Suite) ResetDB(c *check.C) {
IPPrefixes: []netip.Prefix{ IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),
}, },
OIDC: OIDCConfig{ OIDC: types.OIDCConfig{
StripEmaildomain: false, StripEmaildomain: false,
}, },
} }

View File

@ -1,4 +1,4 @@
package hscontrol package types
import ( import (
"errors" "errors"
@ -23,8 +23,8 @@ import (
) )
const ( const (
tlsALPN01ChallengeType = "TLS-ALPN-01" TlsALPN01ChallengeType = "TLS-ALPN-01"
http01ChallengeType = "HTTP-01" Http01ChallengeType = "HTTP-01"
JSONLogFormat = "json" JSONLogFormat = "json"
TextLogFormat = "text" TextLogFormat = "text"
@ -165,7 +165,7 @@ func LoadConfig(path string, isFile bool) error {
viper.AutomaticEnv() viper.AutomaticEnv()
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType) viper.SetDefault("tls_letsencrypt_challenge_type", Http01ChallengeType)
viper.SetDefault("log.level", "info") viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", TextLogFormat) viper.SetDefault("log.format", TextLogFormat)
@ -222,15 +222,15 @@ func LoadConfig(path string, isFile bool) error {
} }
if (viper.GetString("tls_letsencrypt_hostname") != "") && if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) && (viper.GetString("tls_letsencrypt_challenge_type") == TlsALPN01ChallengeType) &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
log.Warn(). log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
} }
if (viper.GetString("tls_letsencrypt_challenge_type") != http01ChallengeType) && if (viper.GetString("tls_letsencrypt_challenge_type") != Http01ChallengeType) &&
(viper.GetString("tls_letsencrypt_challenge_type") != tlsALPN01ChallengeType) { (viper.GetString("tls_letsencrypt_challenge_type") != TlsALPN01ChallengeType) {
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
} }

View File

@ -14,7 +14,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
@ -214,7 +214,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
return nil return nil
} }
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDCConfig, error) { func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort() port, err := dockertestutil.RandomFreeHostPort()
if err != nil { if err != nil {
log.Fatalf("could not find an open port: %s", err) log.Fatalf("could not find an open port: %s", err)
@ -288,7 +288,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*hscontrol.OIDC
log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint)
return &hscontrol.OIDCConfig{ return &types.OIDCConfig{
Issuer: fmt.Sprintf( Issuer: fmt.Sprintf(
"http://%s/oidc", "http://%s/oidc",
net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)), net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)),