diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index f9dcaa17..776ffb92 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "io" "log" @@ -249,7 +250,7 @@ var createPreAuthKeyCmd = &cobra.Command{ }, } -func loadConfig(path string) { +func loadConfig(path string) error { viper.SetConfigName("config") if path == "" { viper.AddConfigPath("/etc/headscale/") @@ -266,28 +267,38 @@ func loadConfig(path string) { err := viper.ReadInConfig() if err != nil { - log.Fatalf("Fatal error config file: %s \n", err) + return errors.New(fmt.Sprintf("Fatal error reading config file: %s \n", err)) } + // Collect any validation errors and return them all at once + var errorText string if (viper.GetString("tls_letsencrypt_hostname") != "") && ((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) { - log.Fatalf("Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both") + errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" } if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { - log.Fatalf("Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443") + errorText += "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443\n" } if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { - log.Fatalf("Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01") + errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n" } if !strings.HasPrefix(viper.GetString("server_url"), "http://") && !strings.HasPrefix(viper.GetString("server_url"), "https://") { - log.Fatalf("Fatal config error: server_url must start with https:// or http://") + errorText += "Fatal config error: server_url must start with https:// or http://\n" + } + if errorText != "" { + return errors.New(strings.TrimSuffix(errorText, "\n")) + } else { + return nil } } func main() { - loadConfig("") + err := loadConfig("") + if err != nil { + log.Fatalf(err.Error()) + } headscaleCmd.AddCommand(versionCmd) headscaleCmd.AddCommand(serveCmd) diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index c1fa3c07..a3894f62 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -1,9 +1,11 @@ package main import ( + "fmt" "io/ioutil" "os" "path/filepath" + "strings" "testing" "github.com/spf13/viper" @@ -43,8 +45,9 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Fatal(err) } - // Load config - loadConfig(tmpDir) + // Load example config, it should load without validation errors + err = loadConfig(tmpDir) + c.Assert(err, check.IsNil) // Test that config file was interpreted correctly c.Assert(viper.GetString("server_url"), check.Equals, "http://192.168.1.12:8000") @@ -54,3 +57,42 @@ func (*Suite) TestConfigLoading(c *check.C) { c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") } + +func writeConfig(c *check.C, tmpDir string, configYaml []byte) { + // Populate a custom config file + configFile := filepath.Join(tmpDir, "config.yaml") + err := ioutil.WriteFile(configFile, configYaml, 0644) + if err != nil { + c.Fatalf("Couldn't write file %s", configFile) + } +} + +func (*Suite) TestTLSConfigValidation(c *check.C) { + tmpDir, err := ioutil.TempDir("", "headscale") + if err != nil { + c.Fatal(err) + } + //defer os.RemoveAll(tmpDir) + fmt.Println(tmpDir) + + configYaml := []byte("---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"") + writeConfig(c, tmpDir, configYaml) + + // Check configuration validation errors (1) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + // check.Matches can not handle multiline strings + tmp := strings.ReplaceAll(err.Error(), "\n", "***") + c.Assert(tmp, check.Matches, ".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*") + c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*") + fmt.Println(tmp) + + // Check configuration validation errors (2) + configYaml = []byte("---\nserver_url: \"http://192.168.1.12:8000\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"") + fmt.Printf(string(configYaml)) + writeConfig(c, tmpDir, configYaml) + err = loadConfig(tmpDir) + c.Assert(err, check.NotNil) + c.Assert(err, check.ErrorMatches, "Fatal config error: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, listen_addr must end in :443.*") +}