package cli import ( "context" "crypto/tls" "encoding/json" "fmt" "os" "reflect" "time" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" "github.com/spf13/viper" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "gopkg.in/yaml.v2" ) const ( HeadscaleDateTimeFormat = "2006-01-02 15:04:05" ) func getHeadscaleApp() (*headscale.Headscale, error) { // Minimum inactivity time out is keepalive timeout (60s) plus a few seconds // to avoid races minInactivityTimeout, _ := time.ParseDuration("65s") if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout { // TODO: Find a better way to return this text //nolint err := fmt.Errorf( "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s", viper.GetString("ephemeral_node_inactivity_timeout"), minInactivityTimeout, ) return nil, err } cfg := headscale.GetHeadscaleConfig() app, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err } // We are doing this here, as in the future could be cool to have it also hot-reload if cfg.ACL.PolicyPath != "" { aclPath := headscale.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) err = app.LoadACLPolicy(aclPath) if err != nil { log.Fatal(). Str("path", aclPath). Err(err). Msg("Could not load the ACL policy") } } return app, nil } func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { cfg := headscale.GetHeadscaleConfig() log.Debug(). Dur("timeout", cfg.CLI.Timeout). Msgf("Setting timeout") ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout) grpcOptions := []grpc.DialOption{ grpc.WithBlock(), } address := cfg.CLI.Address // If the address is not set, we assume that we are on the server hosting headscale. if address == "" { log.Debug(). Str("socket", cfg.UnixSocket). Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.") address = cfg.UnixSocket grpcOptions = append( grpcOptions, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(headscale.GrpcSocketDialer), ) } else { // If we are not connecting to a local server, require an API key for authentication apiKey := cfg.CLI.APIKey if apiKey == "" { log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.") } grpcOptions = append(grpcOptions, grpc.WithPerRPCCredentials(tokenAuth{ token: apiKey, }), ) if cfg.CLI.Insecure { tlsConfig := &tls.Config{ // turn of gosec as we are intentionally setting // insecure. //nolint:gosec InsecureSkipVerify: true, } grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), ) } else { grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")), ) } } log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC") conn, err := grpc.DialContext(ctx, address, grpcOptions...) if err != nil { log.Fatal().Caller().Err(err).Msgf("Could not connect: %v", err) } client := v1.NewHeadscaleServiceClient(conn) return ctx, client, conn, cancel } func SuccessOutput(result interface{}, override string, outputFormat string) { var jsonBytes []byte var err error switch outputFormat { case "json": jsonBytes, err = json.MarshalIndent(result, "", "\t") if err != nil { log.Fatal().Err(err) } case "json-line": jsonBytes, err = json.Marshal(result) if err != nil { log.Fatal().Err(err) } case "yaml": jsonBytes, err = yaml.Marshal(result) if err != nil { log.Fatal().Err(err) } default: //nolint fmt.Println(override) return } //nolint fmt.Println(string(jsonBytes)) } func ErrorOutput(errResult error, override string, outputFormat string) { type errOutput struct { Error string `json:"error"` } SuccessOutput(errOutput{errResult.Error()}, override, outputFormat) } func HasMachineOutputFlag() bool { for _, arg := range os.Args { if arg == "json" || arg == "json-line" || arg == "yaml" { return true } } return false } type tokenAuth struct { token string } // Return value is mapped to request headers. func (t tokenAuth) GetRequestMetadata( ctx context.Context, in ...string, ) (map[string]string, error) { return map[string]string{ "authorization": "Bearer " + t.token, }, nil } func (tokenAuth) RequireTransportSecurity() bool { return true } func contains[T string](ts []T, t T) bool { for _, v := range ts { if reflect.DeepEqual(v, t) { return true } } return false }