From 996330b2a8314049e7a1b913ed3cafa098d38945 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 9 Feb 2026 16:27:37 +0100 Subject: [PATCH 01/17] app: change gorilla to chi mux, add dummy endpoints Signed-off-by: Kristoffer Dalby --- flake.nix | 2 +- go.mod | 2 + go.sum | 4 ++ hscontrol/app.go | 71 +++++++++++++++++------------- hscontrol/noise.go | 106 +++++++++++++++++++++++++++++++++++---------- 5 files changed, 130 insertions(+), 55 deletions(-) diff --git a/flake.nix b/flake.nix index 210b888e..7a47787c 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ let pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; buildGo = pkgs.buildGo126Module; - vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0="; + vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU="; in { headscale = buildGo { diff --git a/go.mod b/go.mod index c99d4ddd..3adc7e48 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,8 @@ require ( github.com/docker/docker v28.5.2+incompatible github.com/fsnotify/fsnotify v1.9.0 github.com/glebarez/sqlite v1.11.0 + github.com/go-chi/chi/v5 v5.2.5 + github.com/go-chi/metrics v0.1.1 github.com/go-gormigrate/gormigrate/v2 v2.1.5 github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e github.com/gofrs/uuid/v5 v5.4.0 diff --git a/go.sum b/go.sum index e9c39e36..4c5f48ac 100644 --- a/go.sum +++ b/go.sum @@ -181,6 +181,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk= +github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4= github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= diff --git a/hscontrol/app.go b/hscontrol/app.go index abd29a45..4affb6e0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -20,7 +20,9 @@ import ( "github.com/cenkalti/backoff/v5" "github.com/davecgh/go-spew/spew" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -457,50 +459,57 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { return os.Remove(h.cfg.UnixSocket) } -func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { - router := mux.NewRouter() - router.Use(prometheusMiddleware) +func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) - router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler). - Methods(http.MethodPost, http.MethodGet) + r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler) - router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet) - router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) - router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet) - router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler). - Methods(http.MethodGet) + r.Get("/robots.txt", h.RobotsHandler) + r.Get("/health", h.HealthHandler) + r.Get("/version", h.VersionHandler) + r.Get("/key", h.KeyHandler) + r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { - router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) + r.Get("/oidc/callback", provider.OIDCCallbackHandler) } - router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). - Methods(http.MethodGet) - router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) + r.Get("/apple", h.AppleConfigMessage) + r.Get("/apple/{platform}", h.ApplePlatformConfig) + r.Get("/windows", h.WindowsConfigMessage) // TODO(kristoffer): move swagger into a package - router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet) - router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). - Methods(http.MethodGet) + r.Get("/swagger", headscale.SwaggerUI) + r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1) - router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost) + r.Post("/verify", h.VerifyHandler) if h.cfg.DERP.ServerEnabled { - router.HandleFunc("/derp", h.DERPServer.DERPHandler) - router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) - router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) - router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) + r.HandleFunc("/derp", h.DERPServer.DERPHandler) + r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) + r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) + r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) } - apiRouter := router.PathPrefix("/api").Subrouter() - apiRouter.Use(h.httpAuthenticationMiddleware) - apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP) - router.HandleFunc("/favicon.ico", FaviconHandler) - router.PathPrefix("/").HandlerFunc(BlankHandler) + r.Route("/api", func(r chi.Router) { + r.Use(h.httpAuthenticationMiddleware) + r.HandleFunc("/v1/*", grpcMux.ServeHTTP) + }) + r.Get("/favicon.ico", FaviconHandler) + r.Get("/", BlankHandler) - return router + return r } // Serve launches the HTTP and gRPC server service Headscale and the API. diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 2880f33a..57a79b96 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -8,7 +8,9 @@ import ( "io" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" @@ -69,7 +71,7 @@ func (h *Headscale) NoiseUpgradeHandler( return } - noiseServer := noiseServer{ + ns := noiseServer{ headscale: h, challenge: key.NewChallenge(), } @@ -79,42 +81,88 @@ func (h *Headscale) NoiseUpgradeHandler( writer, req, *h.noisePrivateKey, - noiseServer.earlyNoise, + ns.earlyNoise, ) if err != nil { httpError(writer, fmt.Errorf("upgrading noise connection: %w", err)) return } - noiseServer.conn = noiseConn - noiseServer.machineKey = noiseServer.conn.Peer() - noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion() + ns.conn = noiseConn + ns.machineKey = ns.conn.Peer() + ns.protocolVersion = ns.conn.ProtocolVersion() // This router is served only over the Noise connection, and exposes only the new API. // // The HTTP2 server that exposes this router is created for // a single hijacked connection from /ts2021, using netutil.NewOneConnListener - router := mux.NewRouter() - router.Use(prometheusMiddleware) - router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). - Methods(http.MethodPost) + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) - // Endpoints outside of the register endpoint must use getAndValidateNode to - // get the node to ensure that the MachineKey matches the Node setting up the - // connection. - router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) + r.Handle("/metrics", metrics.Handler()) - noiseServer.httpBaseConfig = &http.Server{ - Handler: router, + r.Route("/machine", func(r chi.Router) { + r.Post("/register", ns.RegistrationHandler) + r.Post("/map", ns.PollNetMapHandler) + + // Not implemented yet + // + // /whoami is a debug endpoint to validate that the client can communicate over the connection, + // not clear if there is a specific response, it looks like it is just logged. + // https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138 + r.Get("/whoami", ns.NotImplementedHandler) + + // client sends a [tailcfg.SetDNSRequest] to this endpoints and expect + // the server to create or update this DNS record "somewhere". + // It is typically a TXT record for an ACME challenge. + r.Post("/set-dns", ns.NotImplementedHandler) + + // A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes. + // We currently do not support device attributes. + r.Patch("/set-device-attr", ns.NotImplementedHandler) + + // A [tailcfg.AuditLogRequest] to send audit log entries to the server. + // The server is expected to store them "somewhere". + // We currently do not support device attributes. + r.Post("/audit-log", ns.NotImplementedHandler) + + // handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest]. + r.Post("/id-token", ns.NotImplementedHandler) + + // Asks the server if a feature is available and receive information about how to enable it. + // Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse]. + r.Post("/feature/query", ns.NotImplementedHandler) + + r.Post("/update-health", ns.NotImplementedHandler) + + r.Route("/webclient", func(r chi.Router) {}) + }) + + r.Post("/c2n", ns.NotImplementedHandler) + + r.Get("/ssh-action", ns.SSHAction) + + ns.httpBaseConfig = &http.Server{ + Handler: r, ReadHeaderTimeout: types.HTTPTimeout, } - noiseServer.http2Server = &http2.Server{} + ns.http2Server = &http2.Server{} - noiseServer.http2Server.ServeConn( + ns.http2Server.ServeConn( noiseConn, &http2.ServeConnOpts{ - BaseConfig: noiseServer.httpBaseConfig, + BaseConfig: ns.httpBaseConfig, }, ) } @@ -189,7 +237,19 @@ func rejectUnsupported( return false } -// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol +func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) { + d, _ := io.ReadAll(req.Body) + log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit") + http.Error(writer, "Not implemented yet", http.StatusNotImplemented) +} + +// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] +// to the client with the verdict of an SSH access request. +func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { + log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request") +} + +// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol // // This is the busiest endpoint, as it keeps the HTTP long poll that updates // the clients when something in the network changes. @@ -198,7 +258,7 @@ func rejectUnsupported( // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (ns *noiseServer) NoisePollNetMapHandler( +func (ns *noiseServer) PollNetMapHandler( writer http.ResponseWriter, req *http.Request, ) { @@ -237,8 +297,8 @@ func regErr(err error) *tailcfg.RegisterResponse { return &tailcfg.RegisterResponse{Error: err.Error()} } -// NoiseRegistrationHandler handles the actual registration process of a node. -func (ns *noiseServer) NoiseRegistrationHandler( +// RegistrationHandler handles the actual registration process of a node. +func (ns *noiseServer) RegistrationHandler( writer http.ResponseWriter, req *http.Request, ) { From badbb7550db0ee3a8f44e736ec698609283e668b Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Feb 2026 16:46:17 +0100 Subject: [PATCH 02/17] build: update golangci-lint and gopls in flake --- .pre-commit-config.yaml | 2 +- flake.nix | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b05f2566..f0242a4e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: # golangci-lint for Go code quality - id: golangci-lint name: golangci-lint - entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix + entry: nix develop --command -- golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix language: system types: [go] pass_filenames: false diff --git a/flake.nix b/flake.nix index 7a47787c..ae02d0ff 100644 --- a/flake.nix +++ b/flake.nix @@ -135,11 +135,6 @@ }; }; - # The package uses buildGo125Module, not the convention. - # goreleaser = prev.goreleaser.override { - # buildGoModule = buildGo; - # }; - gotestsum = prev.gotestsum.override { buildGoModule = buildGo; }; @@ -152,9 +147,9 @@ buildGoModule = buildGo; }; - # gopls = prev.gopls.override { - # buildGoModule = buildGo; - # }; + gopls = prev.gopls.override { + buildGoLatestModule = buildGo; + }; }; } // flake-utils.lib.eachDefaultSystem From 0291fa8644f25e0530a95ff86904a2ae887758e2 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Feb 2026 13:45:14 +0100 Subject: [PATCH 03/17] {policy, noise}: initial SSH check poc This is a rudimental version, it will call out to headscale to ask what to do over internal noise connection and log the request. For now we always return an accept, meaning that the test will pass ass we essentially have implemented "accept" with an extra step. Next is to actually "check something" Updates #1850 Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 1 + hscontrol/noise.go | 42 +++++++++++++-- hscontrol/policy/v2/filter.go | 32 ++++++++--- integration/ssh_test.go | 72 +++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..7e059045 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,6 +253,7 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf + - TestSSHOneUserToOneCheckMode - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 57a79b96..9a8814ca 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -116,6 +116,8 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/register", ns.RegistrationHandler) r.Post("/map", ns.PollNetMapHandler) + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}/ssh_user/{ssh_user}/local_user/{local_user}", ns.SSHAction) + // Not implemented yet // // /whoami is a debug endpoint to validate that the client can communicate over the connection, @@ -147,12 +149,10 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/update-health", ns.NotImplementedHandler) r.Route("/webclient", func(r chi.Router) {}) + + r.Post("/c2n", ns.NotImplementedHandler) }) - r.Post("/c2n", ns.NotImplementedHandler) - - r.Get("/ssh-action", ns.SSHAction) - ns.httpBaseConfig = &http.Server{ Handler: r, ReadHeaderTimeout: types.HTTPTimeout, @@ -246,7 +246,39 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht // SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] // to the client with the verdict of an SSH access request. func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { - log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request") + srcNodeID := chi.URLParam(req, "src_node_id") + dstNodeID := chi.URLParam(req, "dst_node_id") + sshUser := chi.URLParam(req, "ssh_user") + localUser := chi.URLParam(req, "local_user") + log.Trace().Caller(). + Str("path", req.URL.String()). + Str("src_node_id", srcNodeID). + Str("dst_node_id", dstNodeID). + Str("ssh_user", sshUser). + Str("local_user", localUser). + Msg("got SSH action request") + + accept := tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + err := json.NewEncoder(writer).Encode(accept) + if err != nil { + log.Error().Caller().Err(err).Msg("failed to encode SSH action response") + return + } + + // Ensure response is flushed to client + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } } // PollNetMapHandler takes care of /machine/:id/map using the Noise protocol diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9c2c5f17..d75e1914 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" + "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -319,11 +320,27 @@ func (pol *Policy) compileACLWithAutogroupSelf( return rules, nil } -func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { +var sshAccept = tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, +} + +func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { return tailcfg.SSHAction{ - Reject: !accept, - Accept: accept, - SessionDuration: duration, + Reject: false, + Accept: false, + SessionDuration: duration, + // Replaced in the client: + // * $SRC_NODE_IP (URL escaped) + // * $SRC_NODE_ID (Node.ID as int64 string) + // * $DST_NODE_IP (URL escaped) + // * $DST_NODE_ID (Node.ID as int64 string) + // * $SSH_USER (URL escaped, ssh user requested) + // * $LOCAL_USER (URL escaped, local user mapped) + HoldAndDelegate: fmt.Sprintf("%s/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID/ssh_user/$SSH_USER/local_user/$LOCAL_USER", baseURL), AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -375,11 +392,14 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction + // HACK HACK HACK + serverURL := viper.GetString("server_url") + switch rule.Action { case SSHActionAccept: - action = sshAction(true, 0) + action = sshAccept case SSHActionCheck: - action = sshAction(true, time.Duration(rule.CheckPeriod)) + action = sshCheck(serverURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 45bc2dc7..15867579 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -579,3 +579,75 @@ func TestSSHAutogroupSelf(t *testing.T) { } } } + +func TestSSHOneUserToOneCheckMode(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + // Use autogroup:member and autogroup:tagged instead of wildcard + // since wildcard (*) is no longer supported for SSH destinations + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + }, + 1, + ) + // defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} From d1364194ef1f3eecd7d4126deef629030c140d17 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Feb 2026 13:59:12 +0100 Subject: [PATCH 04/17] policy: patch serverURL into ssh policy Signed-off-by: Kristoffer Dalby --- hscontrol/policy/pm.go | 2 +- hscontrol/policy/policy_test.go | 5 ++-- hscontrol/policy/v2/filter.go | 7 ++---- hscontrol/policy/v2/filter_test.go | 37 ++++++++++++++++-------------- hscontrol/policy/v2/policy.go | 4 ++-- hscontrol/state/state.go | 2 +- 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 6dfacd91..2de2e8dd 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,7 +19,7 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9c97e39c..536c86f3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) { "root": "", }, Action: &tailcfg.SSHAction{ - Accept: true, + Accept: false, SessionDuration: 24 * time.Hour, + HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) { require.NoError(t, err) - got, err := pm.SSHPolicy(tt.targetNode.View()) + got, err := pm.SSHPolicy("unused-url", tt.targetNode.View()) require.NoError(t, err) if diff := cmp.Diff(tt.wantSSH, got); diff != "" { diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index d75e1914..526a0cb1 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -12,7 +12,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" - "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -349,6 +348,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { //nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( + baseURL string, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], @@ -392,14 +392,11 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction - // HACK HACK HACK - serverURL := viper.GetString("server_url") - switch rule.Action { case SSHActionAccept: action = sshAccept case SSHActionCheck: - action = sshCheck(serverURL, time.Duration(rule.CheckPeriod)) + action = sshCheck(baseURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index cdf7c131..1c15f732 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { require.NoError(t, err) // Compile SSH policy - sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice()) + sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) require.NoError(t, err) if tt.wantEmpty { @@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -704,8 +704,11 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } assert.Equal(t, expectedUsers, rule.SSHUsers) - // Verify check action with session duration - assert.True(t, rule.Action.Accept) + // Verify check action: Accept is false, HoldAndDelegate is set + assert.False(t, rule.Action.Accept) + assert.False(t, rule.Action.Reject) + assert.NotEmpty(t, rule.Action.HoldAndDelegate) + assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/") assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) } @@ -756,7 +759,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { require.NoError(t, err) // Test SSH policy compilation for node2 (owned by user2, who is in the group) - sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -806,7 +809,7 @@ func TestSSHJSONSerialization(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) @@ -1413,7 +1416,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user1's first node node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1432,7 +1435,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user2's first node node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy2) require.Len(t, sshPolicy2.Rules, 1) @@ -1451,7 +1454,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() - sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy3 != nil { @@ -1491,7 +1494,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user1's node: should allow SSH from user1's devices node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1508,7 +1511,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1551,7 +1554,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user1's node: should allow SSH from user1's devices only (not user2's) node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1568,7 +1571,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1610,7 +1613,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For untagged node: should only get principals from other untagged nodes node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1628,7 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For tagged node: should get no SSH rules node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1671,7 +1674,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 1: Compile for user1's device (should only match autogroup:self destination) node1 := nodes[0].View() - sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy1) require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") @@ -1690,7 +1693,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 2: Compile for router (should only match tag:router destination) routerNode := nodes[3].View() // user2-router - sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicyRouter) require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 74b7ba6a..744f52c7 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { +func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err return sshPol, nil } - sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes) if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e421d5bd..f546f7a4 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -851,7 +851,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // SSHPolicy returns the SSH access policy for a node. func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - return s.polMan.SSHPolicy(node) + return s.polMan.SSHPolicy(s.cfg.ServerURL, node) } // Filter returns the current network filter rules and matches. From e45cf30867c42ed9d9e962c7f1d4e63ee8853c54 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 11 Feb 2026 15:31:06 +0100 Subject: [PATCH 05/17] auth: add /auth dummy, tighten AuthRequest, generalise This commit generalise the "Registration" pipeline to a more general auth pipeline supporting both registrations and general auth requests. This means we have renamed the RegistrationID to AuthID. Fields from AuthRequest has been unexported and made read only. Added dummy /auth endpoints to be filled. Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 3 +- hscontrol/auth.go | 93 ++++++++-------- hscontrol/auth_tags_test.go | 24 ++--- hscontrol/auth_test.go | 118 +++++++++------------ hscontrol/db/db.go | 4 +- hscontrol/db/db_test.go | 4 +- hscontrol/grpcv1.go | 29 +++-- hscontrol/handlers.go | 45 ++++++-- hscontrol/mapper/batcher_test.go | 4 +- hscontrol/oidc.go | 118 +++++++++++++-------- hscontrol/state/state.go | 135 ++++++++++++------------ hscontrol/templates/register_web.go | 2 +- hscontrol/templates_consistency_test.go | 8 +- hscontrol/types/common.go | 100 +++++++++++++----- hscontrol/util/util.go | 6 +- hscontrol/util/util_test.go | 10 +- integration/cli_test.go | 34 +++--- 17 files changed, 403 insertions(+), 334 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 4affb6e0..87b37510 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { r.Get("/health", h.HealthHandler) r.Get("/version", h.VersionHandler) r.Get("/key", h.KeyHandler) - r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) + r.Get("/register/{auth_id}", h.authProvider.RegisterHandler) + r.Get("/auth/{auth_id}", h.authProvider.AuthHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { r.Get("/oidc/callback", provider.OIDCCallbackHandler) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..fd1b231b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -20,7 +20,9 @@ import ( type AuthProvider interface { RegisterHandler(w http.ResponseWriter, r *http.Request) - AuthURL(regID types.RegistrationID) string + AuthHandler(w http.ResponseWriter, r *http.Request) + RegisterURL(authID types.AuthID) string + AuthURL(authID types.AuthID) string } func (h *Headscale) handleRegister( @@ -261,22 +263,22 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) } - followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) if err != nil { return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) } - if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { + if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok { select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.Registered: - if node == nil { + case node := <-reg.WaitForRegistration(): + if !node.Valid() { // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } - return nodeToRegisterResponse(node.View()), nil + return nodeToRegisterResponse(node), nil } } @@ -291,14 +293,14 @@ func (h *Headscale) reqToNewRegisterResponse( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - newRegID, err := types.NewRegistrationID() + newAuthID, err := types.NewAuthID() if err != nil { return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -307,25 +309,25 @@ func (h *Headscale) reqToNewRegisterResponse( hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - log.Info().Msgf("new followup node registration using key: %s", newRegID) - h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + log.Info().Msgf("new followup node registration using key: %s", newAuthID) + h.state.SetAuthCacheEntry(newAuthID, authRegReq) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(newRegID), + AuthURL: h.authProvider.RegisterURL(newAuthID), }, nil } @@ -376,13 +378,6 @@ func (h *Headscale) handleRegisterWithAuthKey( // Send both changes. Empty changes are ignored by Change(). h.Change(changed, routesChange) - // TODO(kradalby): I think this is covered above, but we need to validate that. - // // If policy changed due to node registration, send a separate policy change - // if policyChanged { - // policyChange := change.PolicyChange() - // h.Change(policyChange) - // } - resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), @@ -404,14 +399,14 @@ func (h *Headscale) handleRegisterInteractive( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - registrationId, err := types.NewRegistrationID() + authID, err := types.NewAuthID() if err != nil { return nil, fmt.Errorf("generating registration ID: %w", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -434,28 +429,28 @@ func (h *Headscale) handleRegisterInteractive( hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - h.state.SetRegistrationCacheEntry( - registrationId, - nodeToRegister, + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + h.state.SetAuthCacheEntry( + authID, + authRegReq, ) - log.Info().Msgf("starting node registration using key: %s", registrationId) + log.Info().Msgf("starting node registration using key: %s", authID) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(registrationId), + AuthURL: h.authProvider.RegisterURL(authID), }, nil } diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go index e7b74b75..7016af31 100644 --- a/hscontrol/auth_tags_test.go +++ b/hscontrol/auth_tags_test.go @@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 1: Create user-owned node WITH expiry set clientExpiry := time.Now().Add(24 * time.Hour) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "personal-to-tagged", @@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 2: Re-auth with tags (Personal → Tagged conversion) nodeKey2 := key.NewNode() - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "personal-to-tagged", @@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client still sends expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", @@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Create tagged node (expiry should be nil) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "tagged-to-personal", @@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { RequestTags: []string{"tag:server"}, // Tagged node }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { // Step 2: Re-auth with empty tags (Tagged → Personal conversion) nodeKey2 := key.NewNode() clientExpiry := time.Now().Add(48 * time.Hour) - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "tagged-to-personal", @@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client requests expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..8215b07c 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_success", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-success-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-success-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate successful registration - send to buffered channel - // The channel is buffered (size 1), so this can complete immediately - // and handleRegister will receive the value when it starts waiting + // Simulate successful registration + // handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - registered <- node + nodeToRegister.FinishRegistration(node.View()) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_timeout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-timeout-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) - // Don't send anything on channel - will timeout + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-timeout-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) + // Don't call FinishRegistration - will timeout return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil }, @@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_node_nil_response", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "nil-response-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "nil-response-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate registration that returns nil (cache expired during auth) - // The channel is buffered (size 1), so this can complete immediately + // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - registered <- nil // Nil indicates cache expiry + nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Generate a registration ID that doesn't exist in cache // This simulates an expired/missing cache entry - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } @@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) { // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") - newRegID, err := types.RegistrationIDFromString(newRegIDStr) + newRegID, err := types.AuthIDFromString(newRegIDStr) assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure // Verify new registration entry exists in cache - _, found := app.state.GetRegistrationCacheEntry(newRegID) + _, found := app.state.GetAuthCacheEntry(newRegID) assert.True(t, found, "new registration should exist in cache") }, }, @@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify cache entry exists - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) assert.True(t, found, "registration cache entry should exist initially") assert.NotNil(t, cacheEntry) @@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern // Cache entry should still exist after auth error (for retry scenarios) - _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) + _, stillFound := app.state.GetAuthCacheEntry(registrationID) assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") }, }, @@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) { assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") // Both cache entries should exist simultaneously - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify both exist - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) { } // First registration should still be in cache (not completed) - _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + _, stillFound := app.state.GetAuthCacheEntry(regID1) assert.True(t, stillFound, "first registration should still be pending") }, }, @@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { var ( initialResp *tailcfg.RegisterResponse authURL string - registrationID types.RegistrationID + registrationID types.AuthID finalResp *tailcfg.RegisterResponse err error ) @@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if step.expectCacheEntry { // Verify registration cache entry was created - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) require.True(t, found, "registration cache entry should exist") require.NotNil(t, cacheEntry, "cache entry should not be nil") - require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key") } case stepTypeAuthCompletion: @@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { // Check cache cleanup expectation for this step if step.expectCacheEntry == false && registrationID != "" { // Verify cache entry was cleaned up - _, found := app.state.GetRegistrationCacheEntry(registrationID) + _, found := app.state.GetAuthCacheEntry(registrationID) require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) } } @@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { } // extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. -func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { +func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" @@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err idStr := authURL[idx+len(registerPrefix):] - return types.RegistrationIDFromString(idStr) + return types.AuthIDFromString(idStr) } // validateCompleteRegistrationResponse performs comprehensive validation of a registration response. @@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { nodeKey := key.NewNode() // Simulate a registration cache entry (as would be created during web auth) - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "webauth-tags-node", @@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete the web auth - should fail because tag is unauthorized _, _, err := app.state.HandleNodeFromAuthPath( @@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Initial registration with tags - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "reauth-untag-node", @@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{"tag:valid-owned", "tag:second"}, }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) // Complete initial registration with tags node, _, err := app.state.HandleNodeFromAuthPath( @@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 2: Reauth with EMPTY tags to untag nodeKey2 := key.NewNode() // New node key for reauth - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "reauth-untag-node", @@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned nodeKey2 := key.NewNode() // New node key for reauth - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "authkey-tagged-node", @@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { // Step 4: Re-register the node to alice via HandleNodeFromAuthPath // This is what happens when running: headscale nodes register --user alice --key ... nodeKey2 := key.NewNode() - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key as the tagged node NodeKey: nodeKey2.Public(), Hostname: "tagged-orphan-node", @@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { RequestTags: []string{}, // Empty - transition to user-owned }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // This should NOT panic - before the fix, this would panic with: // panic: runtime error: invalid memory address or nil pointer dereference diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 6841f446..69f71e36 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -47,7 +47,7 @@ const ( type HSDatabase struct { DB *gorm.DB cfg *types.Config - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + regCache *zcache.Cache[types.AuthID, types.AuthRequest] } // NewHeadscaleDatabase creates a new database connection and runs migrations. @@ -56,7 +56,7 @@ type HSDatabase struct { //nolint:gocyclo // complex database initialization with many migrations func NewHeadscaleDatabase( cfg *types.Config, - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], + regCache *zcache.Cache[types.AuthID, types.AuthRequest], ) (*HSDatabase, error) { dbConn, err := openDB(cfg.Database) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3c687b39..151d9966 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } } -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 073c6677..d7c192a6 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode( Str(zf.RegistrationKey, registrationKey). Msg("registering node") - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: request.GetName(), } - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } - newNode := types.NewRegisterNode( - types.Node{ - NodeKey: key.NewNode().Public(), - MachineKey: key.NewMachine().Public(), - Hostname: request.GetName(), - User: user, + newNode := types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Hostinfo: &hostinfo, - }, - ) + Hostinfo: &hostinfo, + } log.Debug(). Caller(). Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") - api.h.state.SetRegistrationCacheEntry(registrationId, newNode) + authRegReq := types.NewRegisterAuthRequest(newNode) + api.h.state.SetAuthCacheEntry(registrationId, authRegReq) - return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil } func (api headscaleV1APIServer) Health( diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7c45f1ec..b7aa8460 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/assets" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { +func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationId.String()) + authID.String()) +} + +func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderWeb) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { +} + +func authIDFromRequest(req *http.Request) (types.AuthID, error) { + registrationId, err := urlParam[types.AuthID](req, "auth_id") + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + err = registrationId.Validate() + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + return registrationId, nil } // RegisterHandler shows a simple message in the browser to point to the CLI @@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] - - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) + registrationId, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 9e544633..6f3fbccb 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{ } // emptyCache creates an empty registration cache for testing. -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } // Test configuration constants. diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 9d284921..2bc62fa9 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -12,7 +12,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -26,8 +25,8 @@ import ( const ( randomByteSize = 16 defaultOAuthOptionsCount = 3 - registerCacheExpiration = time.Minute * 15 - registerCacheCleanup = time.Minute * 20 + authCacheExpiration = time.Minute * 15 + authCacheCleanup = time.Minute * 20 ) var ( @@ -44,17 +43,21 @@ var ( errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") ) -// RegistrationInfo contains both machine key and verifier information for OIDC validation. -type RegistrationInfo struct { - RegistrationID types.RegistrationID - Verifier *string +// AuthInfo contains both auth ID and verifier information for OIDC validation. +type AuthInfo struct { + AuthID types.AuthID + Verifier *string + Registration bool } type AuthProviderOIDC struct { - h *Headscale - serverURL string - cfg *types.OIDCConfig - registrationCache *zcache.Cache[string, RegistrationInfo] + h *Headscale + serverURL string + cfg *types.OIDCConfig + + // authCache holds auth information between + // the auth and the callback steps. + authCache *zcache.Cache[string, AuthInfo] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -81,45 +84,63 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - registrationCache := zcache.New[string, RegistrationInfo]( - registerCacheExpiration, - registerCacheCleanup, + authCache := zcache.New[string, AuthInfo]( + authCacheExpiration, + authCacheCleanup, ) return &AuthProviderOIDC{ - h: h, - serverURL: serverURL, - cfg: cfg, - registrationCache: registrationCache, + h: h, + serverURL: serverURL, + cfg: cfg, + authCache: authCache, oidcProvider: oidcProvider, oauth2Config: oauth2Config, }, nil } -func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { +func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderOIDC) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + a.authHandler(writer, req, false) +} + +func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationID.String()) + authID.String()) } // RegisterHandler registers the OIDC callback handler with the given router. // It puts NodeKey in cache so the callback can retrieve it using the oidc state param. -// Listens in /register/:registration_id. +// Listens in /register/:auth_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] + a.authHandler(writer, req, true) +} - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) +// authHandler takes an incoming request that needs to be authenticated and +// validates and prepares it for the OIDC flow. +func (a *AuthProviderOIDC) authHandler( + writer http.ResponseWriter, + req *http.Request, + registration bool, +) { + authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } @@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler( return } - // Initialize registration info with machine key - registrationInfo := RegistrationInfo{ - RegistrationID: registrationId, + registrationInfo := AuthInfo{ + AuthID: authID, + Registration: registration, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler( extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.registrationCache.Set(state, registrationInfo) + a.authCache.Set(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) @@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - registrationId := a.getRegistrationIDFromState(state) + authInfo := a.getAuthInfoFromState(state) + if authInfo == nil { + log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) - // Register the node if it does not exist. - if registrationId != nil { + return + } + + // If this is a registration flow, then we need to register the node. + if authInfo.Registration { verb := "Reauthenticated" - newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) + newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed") httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) return @@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Neither node nor machine key was found in the state cache meaning - // that we could not reauth nor register the node. - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + // TODO(kradalby): handle login flow (without registration) if needed. + // We need to send an update here to whatever might be waiting for this auth flow. } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token( var exchangeOpts []oauth2.AuthCodeOption if a.cfg.PKCE.Enabled { - regInfo, ok := a.registrationCache.Get(state) + regInfo, ok := a.authCache.Get(state) if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } @@ -507,14 +533,14 @@ func doOIDCAuthorization( return nil } -// getRegistrationIDFromState retrieves the registration ID from the state. -func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { - regInfo, ok := a.registrationCache.Get(state) +// getAuthInfoFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo { + authInfo, ok := a.authCache.Get(state) if !ok { return nil } - return ®Info.RegistrationID + return &authInfo } func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( @@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) handleRegistration( user *types.User, - registrationID types.RegistrationID, + registrationID types.AuthID, expiry time.Time, ) (bool, error) { node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index f546f7a4..eb927750 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -82,8 +82,10 @@ type State struct { derpMap atomic.Pointer[tailcfg.DERPMap] // polMan handles policy evaluation and management polMan policy.PolicyManager - // registrationCache caches node registration data to reduce database load - registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + + // authCache caches any pending authentication requests, from either auth type (Web and OIDC). + authCache *zcache.Cache[types.AuthID, types.AuthRequest] + // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes } @@ -101,20 +103,20 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup = cfg.Tuning.RegisterCacheCleanup } - registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( + authCache := zcache.New[types.AuthID, types.AuthRequest]( cacheExpiration, cacheCleanup, ) - registrationCache.OnEvicted( - func(id types.RegistrationID, rn types.RegisterNode) { - rn.SendAndClose(nil) + authCache.OnEvicted( + func(id types.AuthID, rn types.AuthRequest) { + rn.FinishRegistration(types.NodeView{}) }, ) db, err := hsdb.NewHeadscaleDatabase( cfg, - registrationCache, + authCache, ) if err != nil { return nil, fmt.Errorf("initializing database: %w", err) @@ -178,12 +180,12 @@ func NewState(cfg *types.Config) (*State, error) { return &State{ cfg: cfg, - db: db, - ipAlloc: ipAlloc, - polMan: polMan, - registrationCache: registrationCache, - primaryRoutes: routes.New(), - nodeStore: nodeStore, + db: db, + ipAlloc: ipAlloc, + polMan: polMan, + authCache: authCache, + primaryRoutes: routes.New(), + nodeStore: nodeStore, }, nil } @@ -1042,9 +1044,9 @@ func (s *State) DeletePreAuthKey(id uint64) error { return s.db.DeletePreAuthKey(id) } -// GetRegistrationCacheEntry retrieves a node registration from cache. -func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { - entry, found := s.registrationCache.Get(id) +// GetAuthCacheEntry retrieves a node registration from cache. +func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) { + entry, found := s.authCache.Get(id) if !found { return nil, false } @@ -1052,26 +1054,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis return &entry, true } -// SetRegistrationCacheEntry stores a node registration in cache. -func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { - s.registrationCache.Set(id, entry) +// SetAuthCacheEntry stores a node registration in cache. +func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { + s.authCache.Set(id, entry) } // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. -func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { - if hostinfo == nil { +func logHostinfoValidation(nv types.NodeView, username, hostname string) { + if !nv.Hostinfo().Valid() { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had nil hostinfo, generated default hostname") - } else if hostinfo.Hostname == "" { + } else if nv.Hostinfo().Hostname() == "" { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had empty hostname, generated default") @@ -1113,7 +1113,7 @@ type authNodeUpdateParams struct { // Node to update; must be valid and in NodeStore. ExistingNode types.NodeView // Client data: keys, hostinfo, endpoints. - RegEntry *types.RegisterNode + RegEntry *types.AuthRequest // Pre-validated hostinfo; NetInfo preserved from ExistingNode. ValidHostinfo *tailcfg.Hostinfo // Hostname from hostinfo, or generated from keys if client omits it. @@ -1132,6 +1132,7 @@ type authNodeUpdateParams struct { // an existing node. It updates the node in NodeStore, processes RequestTags, and // persists changes to the database. func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { + regNv := params.RegEntry.Node() // Log the operation type if params.IsConvertFromTag { log.Info(). @@ -1140,16 +1141,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView Msg("Converting tagged node to user-owned node") } else { log.Info(). - EmbedObject(params.ExistingNode). - Interface("hostinfo", params.RegEntry.Node.Hostinfo). + Object("existing", params.ExistingNode). + Object("incoming", regNv). Msg("Updating existing node registration via reauth") } // Process RequestTags during reauth (#2979) // Due to json:",omitempty", we treat empty/nil as "clear tags" var requestTags []string - if params.RegEntry.Node.Hostinfo != nil { - requestTags = params.RegEntry.Node.Hostinfo.RequestTags + if regNv.Hostinfo().Valid() { + requestTags = regNv.Hostinfo().RequestTags().AsSlice() } oldTags := params.ExistingNode.Tags().AsSlice() @@ -1167,8 +1168,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // Update existing node in NodeStore - validation passed, safe to mutate updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { - node.NodeKey = params.RegEntry.Node.NodeKey - node.DiscoKey = params.RegEntry.Node.DiscoKey + node.NodeKey = regNv.NodeKey() + node.DiscoKey = regNv.DiscoKey() node.Hostname = params.Hostname // Preserve NetInfo from existing node when re-registering @@ -1179,7 +1180,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView params.ValidHostinfo, ) - node.Endpoints = params.RegEntry.Node.Endpoints + node.Endpoints = regNv.Endpoints().AsSlice() node.IsOnline = new(false) node.LastSeen = new(time.Now()) @@ -1188,7 +1189,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.IsConvertFromTag { node.RegisterMethod = params.RegisterMethod } else { - node.RegisterMethod = params.RegEntry.Node.RegisterMethod + node.RegisterMethod = regNv.RegisterMethod() } // Track tagged status BEFORE processing tags @@ -1208,7 +1209,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !wasTagged && isTagged: // Personal → Tagged: clear expiry (tagged nodes don't expire) @@ -1218,14 +1219,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !isTagged: // Personal → Personal: update expiry from client if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } } // Tagged → Tagged: keep existing expiry (nil) - no action needed @@ -1511,13 +1512,13 @@ func (s *State) processReauthTags( // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). func (s *State) HandleNodeFromAuthPath( - registrationID types.RegistrationID, + authID types.AuthID, userID types.UserID, expiry *time.Time, registrationMethod string, ) (types.NodeView, change.Change, error) { // Get the registration entry from cache - regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + regEntry, ok := s.GetAuthCacheEntry(authID) if !ok { return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache } @@ -1530,25 +1531,27 @@ func (s *State) HandleNodeFromAuthPath( // Ensure we have a valid hostname from the registration cache entry hostname := util.EnsureHostname( - regEntry.Node.Hostinfo, - regEntry.Node.MachineKey.String(), - regEntry.Node.NodeKey.String(), + regEntry.Node().Hostinfo(), + regEntry.Node().MachineKey().String(), + regEntry.Node().NodeKey().String(), ) // Ensure we have valid hostinfo - validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) - validHostinfo.Hostname = hostname + hostinfo := &tailcfg.Hostinfo{} + if regEntry.Node().Hostinfo().Valid() { + hostinfo = regEntry.Node().Hostinfo().AsStruct() + } + + hostinfo.Hostname = hostname logHostinfoValidation( - regEntry.Node.MachineKey.ShortString(), - regEntry.Node.NodeKey.String(), + regEntry.Node(), user.Name, hostname, - regEntry.Node.Hostinfo, ) // Lookup existing nodes - machineKey := regEntry.Node.MachineKey + machineKey := regEntry.Node().MachineKey() existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) @@ -1562,7 +1565,7 @@ func (s *State) HandleNodeFromAuthPath( // Create logger with common fields for all auth operations logger := log.With(). - Str(zf.RegistrationID, registrationID.String()). + Str(zf.RegistrationID, authID.String()). Str(zf.UserName, user.Name). Str(zf.MachineKey, machineKey.ShortString()). Str(zf.Method, registrationMethod). @@ -1571,7 +1574,7 @@ func (s *State) HandleNodeFromAuthPath( // Common params for update operations updateParams := authNodeUpdateParams{ RegEntry: regEntry, - ValidHostinfo: validHostinfo, + ValidHostinfo: hostinfo, Hostname: hostname, User: user, Expiry: expiry, @@ -1605,7 +1608,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Creating new node for different user (same machine key exists for another user)") finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, existingNodeAnyUser, ) if err != nil { @@ -1613,7 +1616,7 @@ func (s *State) HandleNodeFromAuthPath( } } else { finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, types.NodeView{}, ) if err != nil { @@ -1622,10 +1625,10 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.SendAndClose(finalNode.AsStruct()) + regEntry.FinishRegistration(finalNode) // Delete from registration cache - s.registrationCache.Delete(registrationID) + s.authCache.Delete(authID) // Update policy managers usersChange, err := s.updatePolicyManagerUsers() @@ -1654,7 +1657,7 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) createNewNodeFromAuth( logger zerolog.Logger, user *types.User, - regEntry *types.RegisterNode, + regEntry *types.AuthRequest, hostname string, validHostinfo *tailcfg.Hostinfo, expiry *time.Time, @@ -1667,13 +1670,13 @@ func (s *State) createNewNodeFromAuth( return s.createAndSaveNewNode(newNodeParams{ User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, + MachineKey: regEntry.Node().MachineKey(), + NodeKey: regEntry.Node().NodeKey(), + DiscoKey: regEntry.Node().DiscoKey(), Hostname: hostname, Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + Endpoints: regEntry.Node().Endpoints().AsSlice(), + Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()), RegisterMethod: registrationMethod, ExistingNodeForNetinfo: existingNodeForNetinfo, }) @@ -1759,7 +1762,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Ensure we have a valid hostname - handle nil/empty cases hostname := util.EnsureHostname( - regReq.Hostinfo, + regReq.Hostinfo.View(), machineKey.String(), regReq.NodeKey.String(), ) @@ -1768,14 +1771,6 @@ func (s *State) HandleNodeFromPreAuthKey( validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) validHostinfo.Hostname = hostname - logHostinfoValidation( - machineKey.ShortString(), - regReq.NodeKey.ShortString(), - pakUsername(), - hostname, - regReq.Hostinfo, - ) - log.Debug(). Caller(). Str(zf.NodeName, hostname). diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go index 829af7fb..cdede03b 100644 --- a/hscontrol/templates/register_web.go +++ b/hscontrol/templates/register_web.go @@ -7,7 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" ) -func RegisterWeb(registrationID types.RegistrationID) *elem.Element { +func RegisterWeb(registrationID types.AuthID) *elem.Element { return HtmlStructure( elem.Title(nil, elem.Text("Registration - Headscale")), mdTypesetBody( diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go index 369639cc..0464fb88 100644 --- a/hscontrol/templates_consistency_test.go +++ b/hscontrol/templates_consistency_test.go @@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", @@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), externalURLs: []string{}, // No external links }, { @@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) { }, { name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), }, { name: "Windows Config", diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d852753e..66bbf619 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -22,8 +22,8 @@ const ( // Common errors. var ( - ErrCannotParsePrefix = errors.New("cannot parse prefix") - ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidAuthIDLength = errors.New("registration ID has invalid length") ) type StateUpdateType int @@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -const RegistrationIDLength = 24 +const AuthIDLength = 24 -type RegistrationID string +type AuthID string -func NewRegistrationID() (RegistrationID, error) { - rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) +func NewAuthID() (AuthID, error) { + rid, err := util.GenerateRandomStringURLSafe(AuthIDLength) if err != nil { return "", err } - return RegistrationID(rid), nil + return AuthID(rid), nil } -func MustRegistrationID() RegistrationID { - rid, err := NewRegistrationID() +func MustAuthID() AuthID { + rid, err := NewAuthID() if err != nil { panic(err) } @@ -181,43 +181,87 @@ func MustRegistrationID() RegistrationID { return rid } -func RegistrationIDFromString(str string) (RegistrationID, error) { - if len(str) != RegistrationIDLength { - return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str)) +func AuthIDFromString(str string) (AuthID, error) { + r := AuthID(str) + + err := r.Validate() + if err != nil { + return "", err } - return RegistrationID(str), nil + return r, nil } -func (r RegistrationID) String() string { +func (r AuthID) String() string { return string(r) } -type RegisterNode struct { - Node Node - Registered chan *Node - closed *atomic.Bool +func (r AuthID) Validate() error { + if len(r) != AuthIDLength { + return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r)) + } + + return nil } -func NewRegisterNode(node Node) RegisterNode { - return RegisterNode{ - Node: node, - Registered: make(chan *Node), - closed: &atomic.Bool{}, +// AuthRequest represent a pending authentication request from a user or a node. +// If it is a registration request, the node field will be populate with the node that is trying to register. +// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel. +// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. +type AuthRequest struct { + node *Node + finished chan NodeView + closed *atomic.Bool +} + +func NewRegisterAuthRequest(node Node) AuthRequest { + return AuthRequest{ + node: &node, + finished: make(chan NodeView), + closed: &atomic.Bool{}, } } -func (rn *RegisterNode) SendAndClose(node *Node) { +// Node returns the node that is trying to register. +// It will panic if the AuthRequest is not a registration request. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) Node() NodeView { + if rn.node == nil { + panic("Node can only be used in registration requests") + } + + return rn.node.View() +} + +func (rn *AuthRequest) FinishAuth() { + rn.FinishRegistration(NodeView{}) +} + +func (rn *AuthRequest) FinishRegistration(node NodeView) { if rn.closed.Swap(true) { return } - select { - case rn.Registered <- node: - default: + if node.Valid() { + select { + case rn.finished <- node: + default: + } } - close(rn.Registered) + close(rn.finished) +} + +// WaitForRegistration waits for the authentication process to finish +// and returns the authenticated node. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) WaitForRegistration() <-chan NodeView { + return rn.finished +} + +// WaitForAuth waits until a authentication request has been finished. +func (rn *AuthRequest) WaitForAuth() { + <-rn.WaitForRegistration() } // DefaultBatcherWorkers returns the default number of batcher workers. diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index cbce663b..034779b5 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -295,8 +295,8 @@ func IsCI() bool { // 3. If normalisation fails → generate invalid- replacement // // Returns the guaranteed-valid hostname to use. -func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { - if hostinfo == nil || hostinfo.Hostname == "" { +func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string { + if !hostinfo.Valid() || hostinfo.Hostname() == "" { key := cmp.Or(machineKey, nodeKey) if key == "" { return "unknown-node" @@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri return "node-" + keyPrefix } - lowercased := strings.ToLower(hostinfo.Hostname) + lowercased := strings.ToLower(hostinfo.Hostname()) err := ValidateHostname(lowercased) if err == nil { diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 5cca4990..6e7a0630 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { if !strings.HasPrefix(got, "invalid-") { @@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { if !strings.HasPrefix(gotHostname, "invalid-") { @@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} - result := EnsureHostname(hostinfo, "mkey", "nkey") + result := EnsureHostname(hostinfo.View(), "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) } @@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) { OS: "linux", } - hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") - hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") + hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") + hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") if hostname1 != hostname2 { t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) diff --git a/integration/cli_test.go b/integration/cli_test.go index a1174277..c46361d4 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-5", listAll[4].GetName()) otherUserRegIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) @@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) From c4428d80b05aa1388e2c5048319d51998a7daaf2 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:07:26 +0100 Subject: [PATCH 06/17] types: introduce AuthVerdict, unify auth finish API Replace the separate FinishRegistration(NodeView) and FinishAuth() methods with a single FinishAuth(AuthVerdict) that carries both an optional error and the authenticated node. WaitForRegistration is renamed to WaitForAuth returning <-chan AuthVerdict. This allows the auth flow to propagate structured outcomes (accept/reject with reason) rather than inferring meaning from whether a NodeView is valid. --- hscontrol/auth.go | 14 ++++++++------ hscontrol/auth_test.go | 4 ++-- hscontrol/state/state.go | 7 +++++-- hscontrol/types/common.go | 40 ++++++++++++++++++++------------------- 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fd1b231b..ee301242 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -272,13 +272,15 @@ func (h *Headscale) waitForFollowup( select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.WaitForRegistration(): - if !node.Valid() { - // registration is expired in the cache, instruct the client to try a new registration - return h.reqToNewRegisterResponse(req, machineKey) - } + case verdict := <-reg.WaitForAuth(): + if verdict.Accept() { + if !verdict.Node.Valid() { + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) + } - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(verdict.Node), nil + } } } diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 8215b07c..321b55fa 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -692,7 +692,7 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - nodeToRegister.FinishRegistration(node.View()) + nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()}) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1348,7 +1348,7 @@ func TestAuthenticationFlows(t *testing.T) { // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry + nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index eb927750..1ec3eedf 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") // ErrNodeNameNotUnique is returned when a node name is not unique. var ErrNodeNameNotUnique = errors.New("node name is not unique") +// ErrRegistrationExpired is returned when a registration has expired. +var ErrRegistrationExpired = errors.New("registration expired") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -110,7 +113,7 @@ func NewState(cfg *types.Config) (*State, error) { authCache.OnEvicted( func(id types.AuthID, rn types.AuthRequest) { - rn.FinishRegistration(types.NodeView{}) + rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) }, ) @@ -1625,7 +1628,7 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.FinishRegistration(finalNode) + regEntry.FinishAuth(types.AuthVerdict{Node: finalNode}) // Delete from registration cache s.authCache.Delete(authID) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 66bbf619..891969d3 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -210,14 +210,14 @@ func (r AuthID) Validate() error { // The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. type AuthRequest struct { node *Node - finished chan NodeView + finished chan AuthVerdict closed *atomic.Bool } func NewRegisterAuthRequest(node Node) AuthRequest { return AuthRequest{ node: &node, - finished: make(chan NodeView), + finished: make(chan AuthVerdict), closed: &atomic.Bool{}, } } @@ -233,35 +233,37 @@ func (rn *AuthRequest) Node() NodeView { return rn.node.View() } -func (rn *AuthRequest) FinishAuth() { - rn.FinishRegistration(NodeView{}) -} - -func (rn *AuthRequest) FinishRegistration(node NodeView) { +func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return } - if node.Valid() { - select { - case rn.finished <- node: - default: - } + select { + case rn.finished <- verdict: + default: } close(rn.finished) } -// WaitForRegistration waits for the authentication process to finish -// and returns the authenticated node. -// Can _only_ be used in the registration path. -func (rn *AuthRequest) WaitForRegistration() <-chan NodeView { +func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict { return rn.finished } -// WaitForAuth waits until a authentication request has been finished. -func (rn *AuthRequest) WaitForAuth() { - <-rn.WaitForRegistration() +type AuthVerdict struct { + // Err is the error that occurred during the authentication process, if any. + // If Err is nil, the authentication process has succeeded. + // If Err is not nil, the authentication process has failed and the node should not be authenticated. + Err error + + // Node is the node that has been authenticated. + // Node is only valid if the auth request was a registration request + // and the authentication process has succeeded. + Node NodeView +} + +func (v AuthVerdict) Accept() bool { + return v.Err == nil } // DefaultBatcherWorkers returns the default number of batcher workers. From 4525734d252c5720c8a2ee4463af6df7ab553492 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:09:06 +0100 Subject: [PATCH 07/17] templates, oidc, handlers: generalise auth templates Replace the single-purpose OIDCCallback and RegisterWeb templates with two reusable templates: - AuthSuccess: configurable success page used for node registration, reauthentication, and SSH session authorisation. - AuthWeb: CLI command instruction page used for both node registration and auth approval flows. Move successBox and checkboxIcon into design.go as shared primitives. Also handle the non-registration OIDC callback path: look up the auth session, send an accept verdict, and render an SSH authorisation success page. --- hscontrol/handlers.go | 23 ++++- hscontrol/oidc.go | 67 +++++++++++---- hscontrol/oidc_template_test.go | 53 ++++++++---- hscontrol/templates/auth_success.go | 62 ++++++++++++++ hscontrol/templates/auth_web.go | 21 +++++ hscontrol/templates/design.go | 41 +++++++++ hscontrol/templates/oidc_callback.go | 69 --------------- hscontrol/templates/register_web.go | 21 ----- hscontrol/templates_consistency_test.go | 106 ++++++++++++++++++++---- 9 files changed, 324 insertions(+), 139 deletions(-) create mode 100644 hscontrol/templates/auth_success.go create mode 100644 hscontrol/templates/auth_web.go delete mode 100644 hscontrol/templates/oidc_callback.go delete mode 100644 hscontrol/templates/register_web.go diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index b7aa8460..9f544f8d 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -262,6 +262,23 @@ func (a *AuthProviderWeb) AuthHandler( writer http.ResponseWriter, req *http.Request, ) { + authID, err := authIDFromRequest(req) + if err != nil { + httpError(writer, err) + return + } + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + _, err = writer.Write([]byte(templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id "+authID.String(), + ).Render())) + if err != nil { + log.Error().Err(err).Msg("failed to write auth response") + } } func authIDFromRequest(req *http.Request) (types.AuthID, error) { @@ -299,7 +316,11 @@ func (a *AuthProviderWeb) RegisterHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + _, err = writer.Write([]byte(templates.AuthWeb( + "Node registration", + "Run the command below in the headscale server to add this node to your network:", + fmt.Sprintf("headscale auth register --auth-id %s --user USERNAME", registrationId.String()), + ).Render())) if err != nil { log.Error().Err(err).Msg("failed to write register response") } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 2bc62fa9..ee6dbeb9 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -333,8 +333,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // If this is a registration flow, then we need to register the node. if authInfo.Registration { - verb := "Reauthenticated" - newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { @@ -349,12 +347,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - if newNode { - verb = "Authenticated" - } - - // TODO(kradalby): replace with go-elem - content := renderOIDCCallbackTemplate(user, verb) + content := renderRegistrationSuccessTemplate(user, newNode) writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -366,8 +359,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // TODO(kradalby): handle login flow (without registration) if needed. - // We need to send an update here to whatever might be waiting for this auth flow. + // If this is not a registration callback, then its a regular authentication callback + // and we need to send a response and confirm that the access was allowed. + + authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID) + if !ok { + log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("auth session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + + return + } + + // Send a finish auth verdict with no errors to let the CLI know that the authentication was successful. + authReq.FinishAuth(types.AuthVerdict{}) + + content := renderAuthSuccessTemplate(user) + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr + util.LogErr(err, "Failed to write HTTP response") + } } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -623,12 +636,38 @@ func (a *AuthProviderOIDC) handleRegistration( return !nodeChange.IsEmpty(), nil } -func renderOIDCCallbackTemplate( +func renderRegistrationSuccessTemplate( user *types.User, - verb string, + newNode bool, ) *bytes.Buffer { - html := templates.OIDCCallback(user.Display(), verb).Render() - return bytes.NewBufferString(html) + result := templates.AuthSuccessResult{ + Title: "Headscale - Node Reauthenticated", + Heading: "Node reauthenticated", + Verb: "Reauthenticated", + User: user.Display(), + Message: "You can now close this window.", + } + if newNode { + result.Title = "Headscale - Node Registered" + result.Heading = "Node registered" + result.Verb = "Registered" + } + + return bytes.NewBufferString(templates.AuthSuccess(result).Render()) +} + +func renderAuthSuccessTemplate( + user *types.User, +) *bytes.Buffer { + result := templates.AuthSuccessResult{ + Title: "Headscale - SSH Session Authorized", + Heading: "SSH session authorized", + Verb: "Authorized", + User: user.Display(), + Message: "You may return to your terminal.", + } + + return bytes.NewBufferString(templates.AuthSuccess(result).Render()) } // getCookieName generates a unique cookie name based on a cookie value. diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go index 367451b1..24dfc0b0 100644 --- a/hscontrol/oidc_template_test.go +++ b/hscontrol/oidc_template_test.go @@ -7,35 +7,54 @@ import ( "github.com/stretchr/testify/assert" ) -func TestOIDCCallbackTemplate(t *testing.T) { +func TestAuthSuccessTemplate(t *testing.T) { tests := []struct { - name string - userName string - verb string + name string + result templates.AuthSuccessResult }{ { - name: "logged_in_user", - userName: "test@example.com", - verb: "Logged in", + name: "node_registered", + result: templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "newuser@example.com", + Message: "You can now close this window.", + }, }, { - name: "registered_user", - userName: "newuser@example.com", - verb: "Registered", + name: "node_reauthenticated", + result: templates.AuthSuccessResult{ + Title: "Headscale - Node Reauthenticated", + Heading: "Node reauthenticated", + Verb: "Reauthenticated", + User: "test@example.com", + Message: "You can now close this window.", + }, + }, + { + name: "ssh_session_authorized", + result: templates.AuthSuccessResult{ + Title: "Headscale - SSH Session Authorized", + Heading: "SSH session authorized", + Verb: "Authorized", + User: "test@example.com", + Message: "You may return to your terminal.", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Render using the elem-go template - html := templates.OIDCCallback(tt.userName, tt.verb).Render() + html := templates.AuthSuccess(tt.result).Render() - // Verify the HTML contains expected elements + // Verify the HTML contains expected structural elements assert.Contains(t, html, "") - assert.Contains(t, html, "Headscale Authentication Succeeded") - assert.Contains(t, html, tt.verb) - assert.Contains(t, html, tt.userName) - assert.Contains(t, html, "You can now close this window") + assert.Contains(t, html, ""+tt.result.Title+"") + assert.Contains(t, html, tt.result.Heading) + assert.Contains(t, html, tt.result.Verb+" as ") + assert.Contains(t, html, tt.result.User) + assert.Contains(t, html, tt.result.Message) // Verify Material for MkDocs design system CSS is present assert.Contains(t, html, "Material for MkDocs") diff --git a/hscontrol/templates/auth_success.go b/hscontrol/templates/auth_success.go new file mode 100644 index 00000000..1a212b6e --- /dev/null +++ b/hscontrol/templates/auth_success.go @@ -0,0 +1,62 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" +) + +// AuthSuccessResult contains the text content for an authentication success page. +// Each field controls a distinct piece of user-facing text so that every auth +// flow (node registration, reauthentication, SSH check, …) can clearly +// communicate what just happened. +type AuthSuccessResult struct { + // Title is the browser tab / page title, + // e.g. "Headscale - Node Registered". + Title string + + // Heading is the bold green text inside the success box, + // e.g. "Node registered". + Heading string + + // Verb is the action prefix in the body text before "as ", + // e.g. "Registered", "Reauthenticated", "Authorized". + Verb string + + // User is the display name shown in bold in the body text, + // e.g. "user@example.com". + User string + + // Message is the follow-up instruction shown after the user name, + // e.g. "You can now close this window." + Message string +} + +// AuthSuccess renders an authentication / authorisation success page. +// The caller controls every user-visible string via [AuthSuccessResult] so the +// page clearly describes what succeeded (registration, reauth, SSH check, …). +func AuthSuccess(result AuthSuccessResult) *elem.Element { + box := successBox( + result.Heading, + elem.Text(result.Verb+" as "), + elem.Strong(nil, elem.Text(result.User)), + elem.Text(". "+result.Message), + ) + + return HtmlStructure( + elem.Title(nil, elem.Text(result.Title)), + mdTypesetBody( + headscaleLogo(), + box, + H2(elem.Text("Getting started")), + P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")), + Ul( + elem.Li(nil, + externalLink("https://headscale.net/stable/", "Headscale documentation"), + ), + elem.Li(nil, + externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"), + ), + ), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/auth_web.go b/hscontrol/templates/auth_web.go new file mode 100644 index 00000000..8b6d6f97 --- /dev/null +++ b/hscontrol/templates/auth_web.go @@ -0,0 +1,21 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" +) + +// AuthWeb renders a page that instructs an administrator to run a CLI command +// to complete an authentication or registration flow. +// It is used by both the registration and auth-approve web handlers. +func AuthWeb(title, description, command string) *elem.Element { + return HtmlStructure( + elem.Title(nil, elem.Text(title+" - Headscale")), + mdTypesetBody( + headscaleLogo(), + H1(elem.Text(title)), + P(elem.Text(description)), + Pre(PreCode(command)), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/design.go b/hscontrol/templates/design.go index 615c0e41..221eaf11 100644 --- a/hscontrol/templates/design.go +++ b/hscontrol/templates/design.go @@ -365,6 +365,47 @@ func orDivider() *elem.Element { ) } +// successBox creates a green success feedback box with a checkmark icon. +// The heading is displayed as bold green text, and children are rendered below it. +// Pairs with warningBox for consistent feedback styling. +// +//nolint:unused // Used in auth_success.go template. +func successBox(heading string, children ...elem.Node) *elem.Element { + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.AlignItems: "center", + styles.Gap: spaceM, + styles.Padding: spaceL, + styles.BackgroundColor: colorSuccessLight, + styles.Border: "1px solid " + colorSuccess, + styles.BorderRadius: "0.5rem", + styles.MarginBottom: spaceXL, + }.ToInline(), + }, + checkboxIcon(), + elem.Div(nil, + append([]elem.Node{ + elem.Strong(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "block", + styles.Color: colorSuccess, + styles.FontSize: fontSizeH3, + styles.MarginBottom: spaceXS, + }.ToInline(), + }, elem.Text(heading)), + }, children...)..., + ), + ) +} + +// checkboxIcon returns the success checkbox SVG icon as raw HTML. +func checkboxIcon() elem.Node { + return elem.Raw(``) +} + // warningBox creates a warning message box with icon and content. // //nolint:unused // Used in apple.go template. diff --git a/hscontrol/templates/oidc_callback.go b/hscontrol/templates/oidc_callback.go deleted file mode 100644 index 16c08fde..00000000 --- a/hscontrol/templates/oidc_callback.go +++ /dev/null @@ -1,69 +0,0 @@ -package templates - -import ( - "github.com/chasefleming/elem-go" - "github.com/chasefleming/elem-go/attrs" - "github.com/chasefleming/elem-go/styles" -) - -// checkboxIcon returns the success checkbox SVG icon as raw HTML. -func checkboxIcon() elem.Node { - return elem.Raw(``) -} - -// OIDCCallback renders the OIDC authentication success callback page. -func OIDCCallback(user, verb string) *elem.Element { - // Success message box - successBox := elem.Div(attrs.Props{ - attrs.Style: styles.Props{ - styles.Display: "flex", - styles.AlignItems: "center", - styles.Gap: spaceM, - styles.Padding: spaceL, - styles.BackgroundColor: colorSuccessLight, - styles.Border: "1px solid " + colorSuccess, - styles.BorderRadius: "0.5rem", - styles.MarginBottom: spaceXL, - }.ToInline(), - }, - checkboxIcon(), - elem.Div(nil, - elem.Strong(attrs.Props{ - attrs.Style: styles.Props{ - styles.Display: "block", - styles.Color: colorSuccess, - styles.FontSize: fontSizeH3, - styles.MarginBottom: spaceXS, - }.ToInline(), - }, elem.Text("Signed in successfully")), - elem.P(attrs.Props{ - attrs.Style: styles.Props{ - styles.Margin: "0", - styles.Color: colorTextPrimary, - styles.FontSize: fontSizeBase, - }.ToInline(), - }, elem.Text(verb), elem.Text(" as "), elem.Strong(nil, elem.Text(user)), elem.Text(". You can now close this window.")), - ), - ) - - return HtmlStructure( - elem.Title(nil, elem.Text("Headscale Authentication Succeeded")), - mdTypesetBody( - headscaleLogo(), - successBox, - H2(elem.Text("Getting started")), - P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")), - Ul( - elem.Li(nil, - externalLink("https://headscale.net/stable/", "Headscale documentation"), - ), - elem.Li(nil, - externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"), - ), - ), - pageFooter(), - ), - ) -} diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go deleted file mode 100644 index cdede03b..00000000 --- a/hscontrol/templates/register_web.go +++ /dev/null @@ -1,21 +0,0 @@ -package templates - -import ( - "fmt" - - "github.com/chasefleming/elem-go" - "github.com/juanfont/headscale/hscontrol/types" -) - -func RegisterWeb(registrationID types.AuthID) *elem.Element { - return HtmlStructure( - elem.Title(nil, elem.Text("Registration - Headscale")), - mdTypesetBody( - headscaleLogo(), - H1(elem.Text("Machine registration")), - P(elem.Text("Run the command below in the headscale server to add this machine to your network:")), - Pre(PreCode(fmt.Sprintf("headscale nodes register --key %s --user USERNAME", registrationID.String()))), - pageFooter(), - ), - ) -} diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go index 0464fb88..4836c1d1 100644 --- a/hscontrol/templates_consistency_test.go +++ b/hscontrol/templates_consistency_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol/templates" - "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" ) @@ -16,12 +15,30 @@ func TestTemplateHTMLConsistency(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", @@ -72,12 +89,30 @@ func TestTemplateModernHTMLFeatures(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", @@ -116,16 +151,35 @@ func TestTemplateExternalLinkSecurity(t *testing.T) { externalURLs []string // URLs that should have security attributes }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), externalURLs: []string{ "https://headscale.net/stable/", "https://tailscale.com/kb/", }, }, { - name: "Register Web", - html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + externalURLs: []string{}, // No external links + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), externalURLs: []string{}, // No external links }, { @@ -185,12 +239,30 @@ func TestTemplateAccessibilityAttributes(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", From d8c34ba7f0bb4e5b53c424faf468eef8e7fba930 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Feb 2026 16:54:56 +0100 Subject: [PATCH 08/17] noise, policy: implement SSH check action handler --- .github/workflows/test-integration.yaml | 3 +- hscontrol/noise.go | 132 +++++++++++++++++++++--- hscontrol/policy/v2/filter.go | 2 +- hscontrol/types/common.go | 7 ++ integration/scenario.go | 21 +++- 5 files changed, 144 insertions(+), 21 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 7e059045..e9483adf 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,7 +253,8 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf - - TestSSHOneUserToOneCheckMode + - TestSSHOneUserToOneCheckModeCLI + - TestSSHOneUserToOneCheckModeOIDC - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 9a8814ca..c232d5d2 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -24,6 +25,12 @@ import ( // ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. var ErrUnsupportedClientVersion = errors.New("unsupported client version") +// ErrMissingURLParameter is returned when a required URL parameter is not provided. +var ErrMissingURLParameter = errors.New("missing URL parameter") + +// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type. +var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -116,7 +123,8 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/register", ns.RegistrationHandler) r.Post("/map", ns.PollNetMapHandler) - r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}/ssh_user/{ssh_user}/local_user/{local_user}", ns.SSHAction) + // SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected. + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler) // Not implemented yet // @@ -243,33 +251,125 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht http.Error(writer, "Not implemented yet", http.StatusNotImplemented) } -// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] +func urlParam[T any](req *http.Request, key string) (T, error) { + var zero T + + param := chi.URLParam(req, key) + if param == "" { + return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) + } + + var value T + switch any(value).(type) { + case string: + v, ok := any(param).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + default: + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + return value, nil +} + +// SSHActionHandler handles the /ssh-action endpoint, it returns a [tailcfg.SSHActionHandler] // to the client with the verdict of an SSH access request. -func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { - srcNodeID := chi.URLParam(req, "src_node_id") - dstNodeID := chi.URLParam(req, "dst_node_id") - sshUser := chi.URLParam(req, "ssh_user") - localUser := chi.URLParam(req, "local_user") +func (ns *noiseServer) SSHActionHandler(writer http.ResponseWriter, req *http.Request) { + srcNodeID, _ := urlParam[types.NodeID](req, "src_node_id") + dstNodeID, _ := urlParam[types.NodeID](req, "dst_node_id") + + sshUser := req.URL.Query().Get("ssh_user") + localUser := req.URL.Query().Get("local_user") + + // Set if this is a follow up request. + authIDStr := req.URL.Query().Get("auth_id") + log.Trace().Caller(). Str("path", req.URL.String()). - Str("src_node_id", srcNodeID). - Str("dst_node_id", dstNodeID). + Uint64("src_node_id", srcNodeID.Uint64()). + Uint64("dst_node_id", dstNodeID.Uint64()). Str("ssh_user", sshUser). Str("local_user", localUser). + Str("auth_id", authIDStr). Msg("got SSH action request") - accept := tailcfg.SSHAction{ - Reject: false, - Accept: true, - AllowAgentForwarding: true, - AllowLocalPortForwarding: true, - AllowRemotePortForwarding: true, + var action tailcfg.SSHAction + + action.AllowAgentForwarding = true + action.AllowLocalPortForwarding = true + action.AllowRemotePortForwarding = true + + if authIDStr == "" { + holdURL, err := url.Parse(ns.headscale.cfg.ServerURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER") + if err != nil { + log.Error().Caller().Err(err).Msg("failed to parse SSH action URL") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + authID, err := types.NewAuthID() + if err != nil { + log.Error().Caller().Err(err).Msg("failed to generate auth ID for SSH action") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest()) + + authURL := ns.headscale.authProvider.AuthURL(authID) + + q := holdURL.Query() + q.Set("auth_id", authID.String()) + holdURL.RawQuery = q.Encode() + + action.HoldAndDelegate = holdURL.String() + // TODO(kradalby): here we can also send a very tiny mapresponse + // "popping" the url and opening it for the user. + action.Message = fmt.Sprintf(`# Headscale SSH requires an additional check. +# To authenticate, visit: %s +# Authentication checked with Headscale SSH. +`, authURL) + } else { + authID, err := types.AuthIDFromString(authIDStr) + if err != nil { + log.Error().Caller().Err(err).Str("auth_id", authIDStr).Msg("invalid auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + log.Trace().Caller().Str("auth_id", authID.String()).Msg("SSH action follow-up request with auth_id") + + auth, ok := ns.headscale.state.GetAuthCacheEntry(authID) + if !ok { + log.Error().Caller().Str("auth_id", authID.String()).Msg("no auth session found for auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + verdict := <-auth.WaitForAuth() + + if verdict.Accept() { + action.Reject = false + action.Accept = true + } else { + action.Reject = true + action.Accept = false + + log.Trace().Caller().Str("auth_id", authID.String()).Err(verdict.Err).Msg("SSH action authentication rejected") + } } writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - err := json.NewEncoder(writer).Encode(accept) + err := json.NewEncoder(writer).Encode(action) if err != nil { log.Error().Caller().Err(err).Msg("failed to encode SSH action response") return diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 526a0cb1..f15093aa 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -339,7 +339,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { // * $DST_NODE_ID (Node.ID as int64 string) // * $SSH_USER (URL escaped, ssh user requested) // * $LOCAL_USER (URL escaped, local user mapped) - HoldAndDelegate: fmt.Sprintf("%s/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID/ssh_user/$SSH_USER/local_user/$LOCAL_USER", baseURL), + HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 891969d3..01429dc9 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -214,6 +214,13 @@ type AuthRequest struct { closed *atomic.Bool } +func NewAuthRequest() AuthRequest { + return AuthRequest{ + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, + } +} + func NewRegisterAuthRequest(node Node) AuthRequest { return AuthRequest{ node: &node, diff --git a/integration/scenario.go b/integration/scenario.go index cd43b78f..e769bd73 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -141,6 +141,12 @@ type ScenarioSpec struct { // Versions is specific list of versions to use for the test. Versions []string + // OIDCSkipUserCreation, if true, skips creating users via headscale CLI + // during environment setup. Useful for OIDC tests where the SSH policy + // references users by name, since OIDC login creates users automatically + // and pre-creating them via CLI causes duplicate user records. + OIDCSkipUserCreation bool + // OIDCUsers, if populated, will start a Mock OIDC server and populate // the user login stack with the given users. // If the NodesPerUser is set, it should align with this list to ensure @@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags( } for _, user := range s.spec.Users { - u, err := s.CreateUser(user) - if err != nil { - return err + var u *v1.User + + if s.spec.OIDCSkipUserCreation { + // Only register locally — OIDC login will create the headscale user. + s.mu.Lock() + s.users[user] = &User{Clients: make(map[string]TailscaleClient)} + s.mu.Unlock() + } else { + u, err = s.CreateUser(user) + if err != nil { + return err + } } var userOpts []tsic.Option From ecc82f25d96431de55da573db3716fae52a82ee7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:57:06 +0100 Subject: [PATCH 09/17] proto: add AuthRegister and AuthApprove RPCs Add auth.proto with AuthRegister{Request,Response} and AuthApprove{Request,Response} messages. Add AuthRegister and AuthApprove RPCs to the HeadscaleService in headscale.proto. These RPCs align the gRPC API with the new CLI commands (headscale auth register, headscale auth approve) that the HTML templates already reference. Updates #1850 --- proto/headscale/v1/auth.proto | 20 ++++++++++++++++++++ proto/headscale/v1/headscale.proto | 17 +++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 proto/headscale/v1/auth.proto diff --git a/proto/headscale/v1/auth.proto b/proto/headscale/v1/auth.proto new file mode 100644 index 00000000..8292400e --- /dev/null +++ b/proto/headscale/v1/auth.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; +package headscale.v1; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; + +import "headscale/v1/node.proto"; + +message AuthRegisterRequest { + string user = 1; + string auth_id = 2; +} + +message AuthRegisterResponse { + Node node = 1; +} + +message AuthApproveRequest { + string auth_id = 1; +} + +message AuthApproveResponse {} diff --git a/proto/headscale/v1/headscale.proto b/proto/headscale/v1/headscale.proto index 5e556255..5a0dd288 100644 --- a/proto/headscale/v1/headscale.proto +++ b/proto/headscale/v1/headscale.proto @@ -8,6 +8,7 @@ import "headscale/v1/user.proto"; import "headscale/v1/preauthkey.proto"; import "headscale/v1/node.proto"; import "headscale/v1/apikey.proto"; +import "headscale/v1/auth.proto"; import "headscale/v1/policy.proto"; service HeadscaleService { @@ -139,6 +140,22 @@ service HeadscaleService { // --- Node end --- + // --- Auth start --- + rpc AuthRegister(AuthRegisterRequest) returns (AuthRegisterResponse) { + option (google.api.http) = { + post : "/api/v1/auth/register" + body : "*" + }; + } + + rpc AuthApprove(AuthApproveRequest) returns (AuthApproveResponse) { + option (google.api.http) = { + post : "/api/v1/auth/approve" + body : "*" + }; + } + // --- Auth end --- + // --- ApiKeys start --- rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse) { option (google.api.http) = { From 7204c0dfe0651dfdeb37cba65ac6f82172bbebd5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:57:18 +0100 Subject: [PATCH 10/17] gen: regenerate from auth proto changes Regenerated protobuf, gRPC gateway, and OpenAPI code from the new auth.proto and updated headscale.proto. Updates #1850 --- gen/go/headscale/v1/auth.pb.go | 266 ++++++++++++++++++ gen/go/headscale/v1/headscale.pb.go | 145 +++++----- gen/go/headscale/v1/headscale.pb.gw.go | 132 +++++++++ gen/go/headscale/v1/headscale_grpc.pb.go | 78 +++++ gen/openapiv2/headscale/v1/auth.swagger.json | 44 +++ .../headscale/v1/headscale.swagger.json | 95 +++++++ 6 files changed, 693 insertions(+), 67 deletions(-) create mode 100644 gen/go/headscale/v1/auth.pb.go create mode 100644 gen/openapiv2/headscale/v1/auth.swagger.json diff --git a/gen/go/headscale/v1/auth.pb.go b/gen/go/headscale/v1/auth.pb.go new file mode 100644 index 00000000..c4017b10 --- /dev/null +++ b/gen/go/headscale/v1/auth.pb.go @@ -0,0 +1,266 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: headscale/v1/auth.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type AuthRegisterRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + AuthId string `protobuf:"bytes,2,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthRegisterRequest) Reset() { + *x = AuthRegisterRequest{} + mi := &file_headscale_v1_auth_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthRegisterRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthRegisterRequest) ProtoMessage() {} + +func (x *AuthRegisterRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthRegisterRequest.ProtoReflect.Descriptor instead. +func (*AuthRegisterRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{0} +} + +func (x *AuthRegisterRequest) GetUser() string { + if x != nil { + return x.User + } + return "" +} + +func (x *AuthRegisterRequest) GetAuthId() string { + if x != nil { + return x.AuthId + } + return "" +} + +type AuthRegisterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthRegisterResponse) Reset() { + *x = AuthRegisterResponse{} + mi := &file_headscale_v1_auth_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthRegisterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthRegisterResponse) ProtoMessage() {} + +func (x *AuthRegisterResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthRegisterResponse.ProtoReflect.Descriptor instead. +func (*AuthRegisterResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{1} +} + +func (x *AuthRegisterResponse) GetNode() *Node { + if x != nil { + return x.Node + } + return nil +} + +type AuthApproveRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + AuthId string `protobuf:"bytes,1,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthApproveRequest) Reset() { + *x = AuthApproveRequest{} + mi := &file_headscale_v1_auth_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthApproveRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthApproveRequest) ProtoMessage() {} + +func (x *AuthApproveRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthApproveRequest.ProtoReflect.Descriptor instead. +func (*AuthApproveRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{2} +} + +func (x *AuthApproveRequest) GetAuthId() string { + if x != nil { + return x.AuthId + } + return "" +} + +type AuthApproveResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthApproveResponse) Reset() { + *x = AuthApproveResponse{} + mi := &file_headscale_v1_auth_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthApproveResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthApproveResponse) ProtoMessage() {} + +func (x *AuthApproveResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthApproveResponse.ProtoReflect.Descriptor instead. +func (*AuthApproveResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{3} +} + +var File_headscale_v1_auth_proto protoreflect.FileDescriptor + +const file_headscale_v1_auth_proto_rawDesc = "" + + "\n" + + "\x17headscale/v1/auth.proto\x12\fheadscale.v1\x1a\x17headscale/v1/node.proto\"B\n" + + "\x13AuthRegisterRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\tR\x04user\x12\x17\n" + + "\aauth_id\x18\x02 \x01(\tR\x06authId\">\n" + + "\x14AuthRegisterResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"-\n" + + "\x12AuthApproveRequest\x12\x17\n" + + "\aauth_id\x18\x01 \x01(\tR\x06authId\"\x15\n" + + "\x13AuthApproveResponseB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" + +var ( + file_headscale_v1_auth_proto_rawDescOnce sync.Once + file_headscale_v1_auth_proto_rawDescData []byte +) + +func file_headscale_v1_auth_proto_rawDescGZIP() []byte { + file_headscale_v1_auth_proto_rawDescOnce.Do(func() { + file_headscale_v1_auth_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc))) + }) + return file_headscale_v1_auth_proto_rawDescData +} + +var file_headscale_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_headscale_v1_auth_proto_goTypes = []any{ + (*AuthRegisterRequest)(nil), // 0: headscale.v1.AuthRegisterRequest + (*AuthRegisterResponse)(nil), // 1: headscale.v1.AuthRegisterResponse + (*AuthApproveRequest)(nil), // 2: headscale.v1.AuthApproveRequest + (*AuthApproveResponse)(nil), // 3: headscale.v1.AuthApproveResponse + (*Node)(nil), // 4: headscale.v1.Node +} +var file_headscale_v1_auth_proto_depIdxs = []int32{ + 4, // 0: headscale.v1.AuthRegisterResponse.node:type_name -> headscale.v1.Node + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_headscale_v1_auth_proto_init() } +func file_headscale_v1_auth_proto_init() { + if File_headscale_v1_auth_proto != nil { + return + } + file_headscale_v1_node_proto_init() + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_headscale_v1_auth_proto_goTypes, + DependencyIndexes: file_headscale_v1_auth_proto_depIdxs, + MessageInfos: file_headscale_v1_auth_proto_msgTypes, + }.Build() + File_headscale_v1_auth_proto = out.File + file_headscale_v1_auth_proto_goTypes = nil + file_headscale_v1_auth_proto_depIdxs = nil +} diff --git a/gen/go/headscale/v1/headscale.pb.go b/gen/go/headscale/v1/headscale.pb.go index 3d16778c..f52ca7e0 100644 --- a/gen/go/headscale/v1/headscale.pb.go +++ b/gen/go/headscale/v1/headscale.pb.go @@ -106,10 +106,10 @@ var File_headscale_v1_headscale_proto protoreflect.FileDescriptor const file_headscale_v1_headscale_proto_rawDesc = "" + "\n" + - "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + + "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x17headscale/v1/auth.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + "\rHealthRequest\"E\n" + "\x0eHealthResponse\x123\n" + - "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" + + "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\xfa\x18\n" + "\x10HeadscaleService\x12h\n" + "\n" + "CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" + @@ -134,7 +134,9 @@ const file_headscale_v1_headscale_proto_rawDesc = "" + "\n" + "RenameNode\x12\x1f.headscale.v1.RenameNodeRequest\x1a .headscale.v1.RenameNodeResponse\"0\x82\xd3\xe4\x93\x02*\"(/api/v1/node/{node_id}/rename/{new_name}\x12b\n" + "\tListNodes\x12\x1e.headscale.v1.ListNodesRequest\x1a\x1f.headscale.v1.ListNodesResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/node\x12\x80\x01\n" + - "\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12p\n" + + "\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12w\n" + + "\fAuthRegister\x12!.headscale.v1.AuthRegisterRequest\x1a\".headscale.v1.AuthRegisterResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/auth/register\x12s\n" + + "\vAuthApprove\x12 .headscale.v1.AuthApproveRequest\x1a!.headscale.v1.AuthApproveResponse\"\x1f\x82\xd3\xe4\x93\x02\x19:\x01*\"\x14/api/v1/auth/approve\x12p\n" + "\fCreateApiKey\x12!.headscale.v1.CreateApiKeyRequest\x1a\".headscale.v1.CreateApiKeyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\"\x0e/api/v1/apikey\x12w\n" + "\fExpireApiKey\x12!.headscale.v1.ExpireApiKeyRequest\x1a\".headscale.v1.ExpireApiKeyResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/apikey/expire\x12j\n" + "\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" + @@ -177,36 +179,40 @@ var file_headscale_v1_headscale_proto_goTypes = []any{ (*RenameNodeRequest)(nil), // 17: headscale.v1.RenameNodeRequest (*ListNodesRequest)(nil), // 18: headscale.v1.ListNodesRequest (*BackfillNodeIPsRequest)(nil), // 19: headscale.v1.BackfillNodeIPsRequest - (*CreateApiKeyRequest)(nil), // 20: headscale.v1.CreateApiKeyRequest - (*ExpireApiKeyRequest)(nil), // 21: headscale.v1.ExpireApiKeyRequest - (*ListApiKeysRequest)(nil), // 22: headscale.v1.ListApiKeysRequest - (*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest - (*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest - (*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest - (*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse - (*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse - (*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse - (*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse - (*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse - (*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse - (*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse - (*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse - (*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse - (*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse - (*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse - (*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse - (*RegisterNodeResponse)(nil), // 38: headscale.v1.RegisterNodeResponse - (*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse - (*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse - (*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse - (*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse - (*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse - (*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse - (*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse - (*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse - (*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse - (*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse - (*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse + (*AuthRegisterRequest)(nil), // 20: headscale.v1.AuthRegisterRequest + (*AuthApproveRequest)(nil), // 21: headscale.v1.AuthApproveRequest + (*CreateApiKeyRequest)(nil), // 22: headscale.v1.CreateApiKeyRequest + (*ExpireApiKeyRequest)(nil), // 23: headscale.v1.ExpireApiKeyRequest + (*ListApiKeysRequest)(nil), // 24: headscale.v1.ListApiKeysRequest + (*DeleteApiKeyRequest)(nil), // 25: headscale.v1.DeleteApiKeyRequest + (*GetPolicyRequest)(nil), // 26: headscale.v1.GetPolicyRequest + (*SetPolicyRequest)(nil), // 27: headscale.v1.SetPolicyRequest + (*CreateUserResponse)(nil), // 28: headscale.v1.CreateUserResponse + (*RenameUserResponse)(nil), // 29: headscale.v1.RenameUserResponse + (*DeleteUserResponse)(nil), // 30: headscale.v1.DeleteUserResponse + (*ListUsersResponse)(nil), // 31: headscale.v1.ListUsersResponse + (*CreatePreAuthKeyResponse)(nil), // 32: headscale.v1.CreatePreAuthKeyResponse + (*ExpirePreAuthKeyResponse)(nil), // 33: headscale.v1.ExpirePreAuthKeyResponse + (*DeletePreAuthKeyResponse)(nil), // 34: headscale.v1.DeletePreAuthKeyResponse + (*ListPreAuthKeysResponse)(nil), // 35: headscale.v1.ListPreAuthKeysResponse + (*DebugCreateNodeResponse)(nil), // 36: headscale.v1.DebugCreateNodeResponse + (*GetNodeResponse)(nil), // 37: headscale.v1.GetNodeResponse + (*SetTagsResponse)(nil), // 38: headscale.v1.SetTagsResponse + (*SetApprovedRoutesResponse)(nil), // 39: headscale.v1.SetApprovedRoutesResponse + (*RegisterNodeResponse)(nil), // 40: headscale.v1.RegisterNodeResponse + (*DeleteNodeResponse)(nil), // 41: headscale.v1.DeleteNodeResponse + (*ExpireNodeResponse)(nil), // 42: headscale.v1.ExpireNodeResponse + (*RenameNodeResponse)(nil), // 43: headscale.v1.RenameNodeResponse + (*ListNodesResponse)(nil), // 44: headscale.v1.ListNodesResponse + (*BackfillNodeIPsResponse)(nil), // 45: headscale.v1.BackfillNodeIPsResponse + (*AuthRegisterResponse)(nil), // 46: headscale.v1.AuthRegisterResponse + (*AuthApproveResponse)(nil), // 47: headscale.v1.AuthApproveResponse + (*CreateApiKeyResponse)(nil), // 48: headscale.v1.CreateApiKeyResponse + (*ExpireApiKeyResponse)(nil), // 49: headscale.v1.ExpireApiKeyResponse + (*ListApiKeysResponse)(nil), // 50: headscale.v1.ListApiKeysResponse + (*DeleteApiKeyResponse)(nil), // 51: headscale.v1.DeleteApiKeyResponse + (*GetPolicyResponse)(nil), // 52: headscale.v1.GetPolicyResponse + (*SetPolicyResponse)(nil), // 53: headscale.v1.SetPolicyResponse } var file_headscale_v1_headscale_proto_depIdxs = []int32{ 2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest @@ -227,40 +233,44 @@ var file_headscale_v1_headscale_proto_depIdxs = []int32{ 17, // 15: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest 18, // 16: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest 19, // 17: headscale.v1.HeadscaleService.BackfillNodeIPs:input_type -> headscale.v1.BackfillNodeIPsRequest - 20, // 18: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest - 21, // 19: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest - 22, // 20: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest - 23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest - 24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest - 25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest - 0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest - 26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse - 27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse - 28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse - 29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse - 30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse - 31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse - 32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse - 33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse - 34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse - 35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse - 36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse - 37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse - 38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse - 39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse - 40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse - 41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse - 42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse - 43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse - 44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse - 45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse - 46, // 45: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse - 47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse - 48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse - 49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse - 1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse - 25, // [25:50] is the sub-list for method output_type - 0, // [0:25] is the sub-list for method input_type + 20, // 18: headscale.v1.HeadscaleService.AuthRegister:input_type -> headscale.v1.AuthRegisterRequest + 21, // 19: headscale.v1.HeadscaleService.AuthApprove:input_type -> headscale.v1.AuthApproveRequest + 22, // 20: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest + 23, // 21: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest + 24, // 22: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest + 25, // 23: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest + 26, // 24: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest + 27, // 25: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest + 0, // 26: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest + 28, // 27: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse + 29, // 28: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse + 30, // 29: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse + 31, // 30: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse + 32, // 31: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse + 33, // 32: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse + 34, // 33: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse + 35, // 34: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse + 36, // 35: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse + 37, // 36: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse + 38, // 37: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse + 39, // 38: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse + 40, // 39: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse + 41, // 40: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse + 42, // 41: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse + 43, // 42: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse + 44, // 43: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse + 45, // 44: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse + 46, // 45: headscale.v1.HeadscaleService.AuthRegister:output_type -> headscale.v1.AuthRegisterResponse + 47, // 46: headscale.v1.HeadscaleService.AuthApprove:output_type -> headscale.v1.AuthApproveResponse + 48, // 47: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse + 49, // 48: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse + 50, // 49: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse + 51, // 50: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse + 52, // 51: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse + 53, // 52: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse + 1, // 53: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse + 27, // [27:54] is the sub-list for method output_type + 0, // [0:27] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -275,6 +285,7 @@ func file_headscale_v1_headscale_proto_init() { file_headscale_v1_preauthkey_proto_init() file_headscale_v1_node_proto_init() file_headscale_v1_apikey_proto_init() + file_headscale_v1_auth_proto_init() file_headscale_v1_policy_proto_init() type x struct{} out := protoimpl.TypeBuilder{ diff --git a/gen/go/headscale/v1/headscale.pb.gw.go b/gen/go/headscale/v1/headscale.pb.gw.go index ab851614..1f769ed9 100644 --- a/gen/go/headscale/v1/headscale.pb.gw.go +++ b/gen/go/headscale/v1/headscale.pb.gw.go @@ -709,6 +709,60 @@ func local_request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marsh return msg, metadata, err } +func request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthRegisterRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.AuthRegister(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthRegisterRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.AuthRegister(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthApproveRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.AuthApprove(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthApproveRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.AuthApprove(ctx, &protoReq) + return msg, metadata, err +} + func request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( protoReq CreateApiKeyRequest @@ -1272,6 +1326,46 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1758,6 +1852,40 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1899,6 +2027,8 @@ var ( pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, "")) pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, "")) pattern_HeadscaleService_BackfillNodeIPs_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "backfillips"}, "")) + pattern_HeadscaleService_AuthRegister_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "register"}, "")) + pattern_HeadscaleService_AuthApprove_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "approve"}, "")) pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, "")) pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) @@ -1927,6 +2057,8 @@ var ( forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage forward_HeadscaleService_BackfillNodeIPs_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_AuthRegister_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_AuthApprove_0 = runtime.ForwardResponseMessage forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage diff --git a/gen/go/headscale/v1/headscale_grpc.pb.go b/gen/go/headscale/v1/headscale_grpc.pb.go index a3963935..e763d9af 100644 --- a/gen/go/headscale/v1/headscale_grpc.pb.go +++ b/gen/go/headscale/v1/headscale_grpc.pb.go @@ -37,6 +37,8 @@ const ( HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode" HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes" HeadscaleService_BackfillNodeIPs_FullMethodName = "/headscale.v1.HeadscaleService/BackfillNodeIPs" + HeadscaleService_AuthRegister_FullMethodName = "/headscale.v1.HeadscaleService/AuthRegister" + HeadscaleService_AuthApprove_FullMethodName = "/headscale.v1.HeadscaleService/AuthApprove" HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey" HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey" HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys" @@ -71,6 +73,9 @@ type HeadscaleServiceClient interface { RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error) ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error) BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error) + // --- Auth start --- + AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) + AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) // --- ApiKeys start --- CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error) @@ -271,6 +276,26 @@ func (c *headscaleServiceClient) BackfillNodeIPs(ctx context.Context, in *Backfi return out, nil } +func (c *headscaleServiceClient) AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthRegisterResponse) + err := c.cc.Invoke(ctx, HeadscaleService_AuthRegister_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthApproveResponse) + err := c.cc.Invoke(ctx, HeadscaleService_AuthApprove_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CreateApiKeyResponse) @@ -366,6 +391,9 @@ type HeadscaleServiceServer interface { RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) + // --- Auth start --- + AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) + AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) // --- ApiKeys start --- CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error) @@ -440,6 +468,12 @@ func (UnimplementedHeadscaleServiceServer) ListNodes(context.Context, *ListNodes func (UnimplementedHeadscaleServiceServer) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) { return nil, status.Error(codes.Unimplemented, "method BackfillNodeIPs not implemented") } +func (UnimplementedHeadscaleServiceServer) AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) { + return nil, status.Error(codes.Unimplemented, "method AuthRegister not implemented") +} +func (UnimplementedHeadscaleServiceServer) AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) { + return nil, status.Error(codes.Unimplemented, "method AuthApprove not implemented") +} func (UnimplementedHeadscaleServiceServer) CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) { return nil, status.Error(codes.Unimplemented, "method CreateApiKey not implemented") } @@ -806,6 +840,42 @@ func _HeadscaleService_BackfillNodeIPs_Handler(srv interface{}, ctx context.Cont return interceptor(ctx, in, info, handler) } +func _HeadscaleService_AuthRegister_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthRegisterRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).AuthRegister(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_AuthRegister_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).AuthRegister(ctx, req.(*AuthRegisterRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _HeadscaleService_AuthApprove_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthApproveRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).AuthApprove(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_AuthApprove_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).AuthApprove(ctx, req.(*AuthApproveRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _HeadscaleService_CreateApiKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateApiKeyRequest) if err := dec(in); err != nil { @@ -1011,6 +1081,14 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "BackfillNodeIPs", Handler: _HeadscaleService_BackfillNodeIPs_Handler, }, + { + MethodName: "AuthRegister", + Handler: _HeadscaleService_AuthRegister_Handler, + }, + { + MethodName: "AuthApprove", + Handler: _HeadscaleService_AuthApprove_Handler, + }, { MethodName: "CreateApiKey", Handler: _HeadscaleService_CreateApiKey_Handler, diff --git a/gen/openapiv2/headscale/v1/auth.swagger.json b/gen/openapiv2/headscale/v1/auth.swagger.json new file mode 100644 index 00000000..2e99e1a7 --- /dev/null +++ b/gen/openapiv2/headscale/v1/auth.swagger.json @@ -0,0 +1,44 @@ +{ + "swagger": "2.0", + "info": { + "title": "headscale/v1/auth.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": {}, + "definitions": { + "protobufAny": { + "type": "object", + "properties": { + "@type": { + "type": "string" + } + }, + "additionalProperties": {} + }, + "rpcStatus": { + "type": "object", + "properties": { + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index 1db1db94..533bd73d 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -138,6 +138,71 @@ ] } }, + "/api/v1/auth/approve": { + "post": { + "operationId": "HeadscaleService_AuthApprove", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1AuthApproveResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/v1AuthApproveRequest" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, + "/api/v1/auth/register": { + "post": { + "summary": "--- Auth start ---", + "operationId": "HeadscaleService_AuthRegister", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1AuthRegisterResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/v1AuthRegisterRequest" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/debug/node": { "post": { "summary": "--- Node start ---", @@ -888,6 +953,36 @@ } } }, + "v1AuthApproveRequest": { + "type": "object", + "properties": { + "authId": { + "type": "string" + } + } + }, + "v1AuthApproveResponse": { + "type": "object" + }, + "v1AuthRegisterRequest": { + "type": "object", + "properties": { + "user": { + "type": "string" + }, + "authId": { + "type": "string" + } + } + }, + "v1AuthRegisterResponse": { + "type": "object", + "properties": { + "node": { + "$ref": "#/definitions/v1Node" + } + } + }, "v1BackfillNodeIPsResponse": { "type": "object", "properties": { From 353127b6e73b92250b5a7477a86377930f4b6e4c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:58:51 +0100 Subject: [PATCH 11/17] hscontrol: implement AuthRegister and AuthApprove gRPC handlers AuthRegister delegates to the existing RegisterNode logic, mapping auth_id to key. AuthApprove looks up a pending auth session and sends a success verdict, following the same pattern as the OIDC callback. Also fix authIDFromRequest to extract the URL parameter as a plain string before converting to AuthID. The urlParam generic function's type switch only matches raw string, not the named type AuthID, causing all /register and /auth endpoints to return 400. Updates #1850 --- hscontrol/grpcv1.go | 34 ++++++++++++++++++++++++++++++++++ hscontrol/handlers.go | 4 ++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index d7c192a6..4c953454 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -828,4 +828,38 @@ func (api headscaleV1APIServer) Health( return response, healthErr } +func (api headscaleV1APIServer) AuthRegister( + ctx context.Context, + request *v1.AuthRegisterRequest, +) (*v1.AuthRegisterResponse, error) { + resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{ + Key: request.GetAuthId(), + User: request.GetUser(), + }) + if err != nil { + return nil, err + } + + return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil +} + +func (api headscaleV1APIServer) AuthApprove( + ctx context.Context, + request *v1.AuthApproveRequest, +) (*v1.AuthApproveResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{}) + + return &v1.AuthApproveResponse{}, nil +} + func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 9f544f8d..57469ce0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -282,7 +282,7 @@ func (a *AuthProviderWeb) AuthHandler( } func authIDFromRequest(req *http.Request) (types.AuthID, error) { - registrationId, err := urlParam[types.AuthID](req, "auth_id") + raw, err := urlParam[string](req, "auth_id") if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) } @@ -290,7 +290,7 @@ func authIDFromRequest(req *http.Request) (types.AuthID, error) { // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - err = registrationId.Validate() + registrationId, err := types.AuthIDFromString(raw) if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) } From 7ef844bbc189d509ee6a30479a47003003e477d6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:59:09 +0100 Subject: [PATCH 12/17] cli: add headscale auth register/approve commands Add a new 'headscale auth' command group with two subcommands: headscale auth register --auth-id --user headscale auth approve --auth-id These replace the old 'headscale nodes register --key' workflow. The old command is marked deprecated with a pointer to the new one. Updates #1850 --- cmd/headscale/cli/auth.go | 70 ++++++++++++++++++++++++++++++++++++++ cmd/headscale/cli/debug.go | 2 +- cmd/headscale/cli/nodes.go | 5 +-- 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 cmd/headscale/cli/auth.go diff --git a/cmd/headscale/cli/auth.go b/cmd/headscale/cli/auth.go new file mode 100644 index 00000000..cc854805 --- /dev/null +++ b/cmd/headscale/cli/auth.go @@ -0,0 +1,70 @@ +package cli + +import ( + "context" + "fmt" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(authCmd) + + authRegisterCmd.Flags().StringP("user", "u", "", "User") + authRegisterCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRegisterCmd, "user", "auth-id") + authCmd.AddCommand(authRegisterCmd) + + authApproveCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authApproveCmd, "auth-id") + authCmd.AddCommand(authApproveCmd) +} + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Manage node authentication and approval", +} + +var authRegisterCmd = &cobra.Command{ + Use: "register", + Short: "Register a node to your network", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + user, _ := cmd.Flags().GetString("user") + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRegisterRequest{ + AuthId: authID, + User: user, + } + + response, err := client.AuthRegister(ctx, request) + if err != nil { + return fmt.Errorf("registering node: %w", err) + } + + return printOutput( + cmd, + response.GetNode(), + fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName())) + }), +} + +var authApproveCmd = &cobra.Command{ + Use: "approve", + Short: "Approve a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthApproveRequest{ + AuthId: authID, + } + + response, err := client.AuthApprove(ctx, request) + if err != nil { + return fmt.Errorf("approving auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request approved") + }), +} diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index fac317fc..9e4a67fd 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{ name, _ := cmd.Flags().GetString("name") registrationID, _ := cmd.Flags().GetString("key") - _, err := types.RegistrationIDFromString(registrationID) + _, err := types.AuthIDFromString(registrationID) if err != nil { return fmt.Errorf("parsing machine key: %w", err) } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index dbc7e8bf..fa71034f 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -63,8 +63,9 @@ var nodeCmd = &cobra.Command{ } var registerNodeCmd = &cobra.Command{ - Use: "register", - Short: "Registers a node to your network", + Use: "register", + Short: "Registers a node to your network", + Deprecated: "use 'headscale auth register --auth-id --user ' instead", RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { user, _ := cmd.Flags().GetString("user") registrationID, _ := cmd.Flags().GetString("key") From ec8b217b9ee81febf046e6d689e44a15602ff82a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:59:23 +0100 Subject: [PATCH 13/17] integration: use headscale auth register in tests Update integration test helpers and CLI test commands to use the new 'headscale auth register --auth-id' instead of the deprecated 'headscale nodes register --key'. Update test comments to reference the new command syntax. Updates #1850 --- hscontrol/auth_test.go | 6 +++--- integration/auth_web_flow_test.go | 2 +- integration/cli_test.go | 24 ++++++++++++------------ integration/scenario.go | 4 ++-- integration/tags_test.go | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 321b55fa..4c70cda4 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -2948,7 +2948,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Scenario: // 1. Node registers with user1 via pre-auth key // 2. Node logs out (expires) -// 3. Admin runs: headscale nodes register --user user2 --key +// 3. Admin runs: headscale auth register --auth-id --user user2 // // Expected behavior: // - User1's original node should STILL EXIST (expired) @@ -3027,7 +3027,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { require.NotEmpty(t, regID, "Should have valid registration ID") // Step 4: Admin completes authentication via CLI - // This simulates: headscale nodes register --user user2 --key + // This simulates: headscale auth register --auth-id --user user2 node, _, err := app.state.HandleNodeFromAuthPath( regID, types.UserID(user2.ID), // Register to user2, not user1! @@ -3942,7 +3942,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { require.NotNil(t, alice, "Alice user should be created") // Step 4: Re-register the node to alice via HandleNodeFromAuthPath - // This is what happens when running: headscale nodes register --user alice --key ... + // This is what happens when running: headscale auth register --auth-id --user alice nodeKey2 := key.NewNode() registrationID := types.MustAuthID() regEntry := types.NewRegisterAuthRequest(types.Node{ diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index eba2ebbf..d00c5fdd 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -312,7 +312,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } // Register all clients as user1 (this is where cross-user registration happens) - // This simulates: headscale nodes register --user user1 --key + // This simulates: headscale auth register --auth-id --user user1 _ = scenario.runHeadscaleRegister("user1", body) } diff --git a/integration/cli_test.go b/integration/cli_test.go index c46361d4..a7696bb4 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1100,11 +1100,11 @@ func TestNodeCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1185,11 +1185,11 @@ func TestNodeCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "other-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1359,11 +1359,11 @@ func TestNodeExpireCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-expire-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1496,11 +1496,11 @@ func TestNodeRenameCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-rename-command", - "register", - "--key", + "--auth-id", regID, "--output", "json", diff --git a/integration/scenario.go b/integration/scenario.go index e769bd73..ba99a392 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -1184,7 +1184,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { return errParseAuthPage } - keySep := strings.Split(codeSep[0], "key ") + keySep := strings.Split(codeSep[0], "--auth-id ") if len(keySep) != 2 { return errParseAuthPage } @@ -1195,7 +1195,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr _, err = headscale.Execute( - []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, + []string{"headscale", "auth", "register", "--user", userStr, "--auth-id", key}, ) if err != nil { log.Printf("registering node: %s", err) diff --git a/integration/tags_test.go b/integration/tags_test.go index b4fe678b..617f688d 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -3122,7 +3122,7 @@ func TestTagsAuthKeyWithoutUserRejectsAdvertisedTags(t *testing.T) { // TestTagsAuthKeyConvertToUserViaCLIRegister reproduces the panic from // issue #3038: register a node with a tags-only preauthkey (no user), then -// convert it to a user-owned node via "headscale nodes register --user --key ...". +// convert it to a user-owned node via "headscale auth register --auth-id --user ". // The crash happens in the mapper's generateUserProfiles when node.User is nil // after the tag→user conversion in processReauthTags. // From d4e0e92ed12febcf469743ab74548478825607e8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 15:10:42 +0100 Subject: [PATCH 14/17] doc: add CHANGELOG entries for SSH check action and auth commands Updates #1850 --- CHANGELOG.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e01d43e..203e7292 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,19 @@ to understand how the packet filter should be generated. We discovered a few dif overall our implementation was very close. [#3036](https://github.com/juanfont/headscale/pull/3036) +### SSH check action + +SSH rules with `"action": "check"` are now supported. When a client initiates an SSH connection to a node +with a `check` action policy, the user is prompted to authenticate via OIDC or CLI approval before access +is granted. + +A new `headscale auth` CLI command group supports the approval flow: + +- `headscale auth approve --auth-id ` approves a pending authentication request (SSH check or web auth) +- `headscale auth register --auth-id --user ` registers a node (replaces deprecated `headscale nodes register`) + +[#1850](https://github.com/juanfont/headscale/pull/1850) + ### BREAKING - **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036) @@ -26,6 +39,8 @@ overall our implementation was very close. - **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036) - Previously, `proto:icmp` included both ICMPv4 and ICMPv6 - Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6 +- **CLI**: `headscale nodes register` is deprecated in favour of `headscale auth register --auth-id --user ` [#1850](https://github.com/juanfont/headscale/pull/1850) + - The old command continues to work but will be removed in a future release ### Changes @@ -35,6 +50,11 @@ overall our implementation was very close. - **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036) - Remove deprecated `--namespace` flag from `nodes list`, `nodes register`, and `debug create-node` commands (use `--user` instead) [#3093](https://github.com/juanfont/headscale/pull/3093) - Remove deprecated `namespace`/`ns` command aliases for `users` and `machine`/`machines` aliases for `nodes` [#3093](https://github.com/juanfont/headscale/pull/3093) +- Add SSH `check` action support with OIDC and CLI-based approval flows [#1850](https://github.com/juanfont/headscale/pull/1850) +- Add `headscale auth register` and `headscale auth approve` CLI commands [#1850](https://github.com/juanfont/headscale/pull/1850) +- Deprecate `headscale nodes register --key` in favour of `headscale auth register --auth-id` [#1850](https://github.com/juanfont/headscale/pull/1850) +- Generalise auth templates into reusable `AuthSuccess` and `AuthWeb` components [#1850](https://github.com/juanfont/headscale/pull/1850) +- Unify auth pipeline with `AuthVerdict` type, supporting registration, reauthentication, and SSH checks [#1850](https://github.com/juanfont/headscale/pull/1850) ## 0.28.0 (2026-02-04) From 9f1cb6fdc5784f55dcbab734d90c2d21cb90bbf1 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 15:47:21 +0100 Subject: [PATCH 15/17] hsic: add ReadLog method for container log inspection Add ReadLog() to HeadscaleInContainer and the ControlServer interface so integration tests can inspect container logs at runtime. This is needed by the SSH check mode tests to extract the auth-id from headscale log output. Updates #1850 --- integration/control.go | 1 + integration/hsic/hsic.go | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/integration/control.go b/integration/control.go index f390d080..d9273ae6 100644 --- a/integration/control.go +++ b/integration/control.go @@ -16,6 +16,7 @@ import ( type ControlServer interface { Shutdown() (string, string, error) SaveLog(path string) (string, string, error) + ReadLog() (string, string, error) SaveProfile(path string) error Execute(command []string) (string, error) WriteFile(path string, content []byte) error diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 3ef4d5d4..cd60c20d 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -699,6 +699,18 @@ func (t *HeadscaleInContainer) WriteLogs(stdout, stderr io.Writer) error { return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr) } +// ReadLog returns the current stdout and stderr logs from the headscale container. +func (t *HeadscaleInContainer) ReadLog() (string, string, error) { + var stdout, stderr bytes.Buffer + + err := dockertestutil.WriteLog(t.pool, t.container, &stdout, &stderr) + if err != nil { + return "", "", fmt.Errorf("reading container logs: %w", err) + } + + return stdout.String(), stderr.String(), nil +} + // SaveLog saves the current stdout log of the container to a path // on the host system. func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) { From e96f232ed6c2db0d873ab953daade523dc8df212 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 15:47:42 +0100 Subject: [PATCH 16/17] integration: split SSH check mode test into CLI and OIDC variants Replace TestSSHOneUserToOneCheckMode with two variants that implement the full expected check flow: - TestSSHOneUserToOneCheckModeCLI: SSH blocks on check, test extracts auth-id from headscale logs, approves via "headscale auth approve", then verifies SSH completes. - TestSSHOneUserToOneCheckModeOIDC: SSH blocks on check, test extracts auth-id, visits /auth/{id} URL triggering mock OIDC authentication, then verifies SSH completes. Both tests will fail at whichever step of the check flow is not yet implemented, driving the implementation forward incrementally. Add helper functions: doSSHCheck (async SSH with long timeout), findSSHCheckAuthID (log polling), sshCheckPolicy (shared policy). Updates #1850 --- integration/ssh_test.go | 289 +++++++++++++++++++++++++++++++++++----- 1 file changed, 259 insertions(+), 30 deletions(-) diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 15867579..75c42af0 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -3,13 +3,16 @@ package integration import ( "fmt" "log" + "net/url" "strings" "testing" "time" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" @@ -580,40 +583,121 @@ func TestSSHAutogroupSelf(t *testing.T) { } } -func TestSSHOneUserToOneCheckMode(t *testing.T) { - IntegrationSkip(t) +type sshCheckResult struct { + stdout string + stderr string + err error +} - scenario := sshScenario(t, - &policyv2.Policy{ - Groups: policyv2.Groups{ - policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, +// doSSHCheck runs SSH in a goroutine with a longer timeout, returning a channel +// for the result. The SSH command will block while waiting for auth approval in +// check mode. +func doSSHCheck( + t *testing.T, + client TailscaleClient, + peer TailscaleClient, +) chan sshCheckResult { + t.Helper() + + peerFQDN, _ := peer.FQDN() + + command := []string{ + "/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=30", + fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN), + "'hostname'", + } + + log.Printf( + "[SSH check] Running from %s to %s", + client.Hostname(), + peer.Hostname(), + ) + + ch := make(chan sshCheckResult, 1) + + go func() { + stdout, stderr, err := client.Execute( + command, + dockertestutil.ExecuteCommandTimeout(60*time.Second), + ) + ch <- sshCheckResult{stdout, stderr, err} + }() + + return ch +} + +// findSSHCheckAuthID polls headscale container logs for the SSH action auth-id. +// The SSH action handler logs "SSH action follow-up" with the auth_id on the +// follow-up request (where auth_id is non-empty). +func findSSHCheckAuthID(t *testing.T, headscale ControlServer) string { + t.Helper() + + var authID string + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, stderr, err := headscale.ReadLog() + assert.NoError(c, err) + + for line := range strings.SplitSeq(stderr, "\n") { + if !strings.Contains(line, "SSH action follow-up") { + continue + } + + if idx := strings.Index(line, "auth_id="); idx != -1 { + start := idx + len("auth_id=") + + end := strings.IndexByte(line[start:], ' ') + if end == -1 { + end = len(line[start:]) + } + + authID = line[start : start+end] + } + } + + assert.NotEmpty(c, authID, "auth-id not found in headscale logs") + }, 10*time.Second, 500*time.Millisecond, "waiting for SSH check auth-id in headscale logs") + + return authID +} + +// sshCheckPolicy returns a policy with SSH "check" mode for group:integration-test +// targeting autogroup:member and autogroup:tagged destinations. +func sshCheckPolicy() *policyv2.Policy { + return &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{ + policyv2.Username("user1@"), }, - ACLs: []policyv2.ACL{ - { - Action: "accept", - Protocol: "tcp", - Sources: []policyv2.Alias{wildcard()}, - Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(wildcard(), tailcfg.PortRangeAny), - }, - }, - }, - SSHs: []policyv2.SSH{ - { - Action: "check", - Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, - // Use autogroup:member and autogroup:tagged instead of wildcard - // since wildcard (*) is no longer supported for SSH destinations - Destinations: policyv2.SSHDstAliases{ - new(policyv2.AutoGroupMember), - new(policyv2.AutoGroupTagged), - }, - Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), }, }, }, - 1, - ) + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + } +} + +func TestSSHOneUserToOneCheckModeCLI(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, sshCheckPolicy(), 1) // defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() @@ -625,22 +709,167 @@ func TestSSHOneUserToOneCheckMode(t *testing.T) { user2Clients, err := scenario.ListTailscaleClients("user2") requireNoErrListClients(t, err) + headscale, err := scenario.Headscale() + require.NoError(t, err) + err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() requireNoErrListFQDN(t, err) + // user1 can SSH (via check) to all peers for _, client := range user1Clients { for _, peer := range allClients { if client.Hostname() == peer.Hostname() { continue } - assertSSHHostname(t, client, peer) + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Approve via CLI + _, err := headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", authID, + }, + ) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after auth approval") + } } } + // user2 cannot SSH — not in the check policy group + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCSkipUserCreation: true, + OIDCUsers: []mockoidc.MockUser{ + // First 2: consumed during node registration + oidcMockUser("user1", true), + oidcMockUser("user2", true), + // Extra: consumed during SSH check auth flows. + // Each SSH check pops one user from the queue. + oidcMockUser("user1", true), + }, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + // defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(sshCheckPolicy()), + hsic.WithTestName("sshcheckoidc"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer( + "/tmp/hs_client_oidc_secret", + []byte(scenario.mockOIDC.ClientSecret()), + ), + ) + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 can SSH (via check) to all peers + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Build auth URL and visit it to trigger OIDC flow. + // The mock OIDC server auto-authenticates from the user queue. + authURL := headscale.GetEndpoint() + "/auth/" + authID + parsedURL, err := url.Parse(authURL) + require.NoError(t, err) + + _, err = doLoginURL("ssh-check-oidc", parsedURL) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after OIDC auth") + } + } + } + + // user2 cannot SSH — not in the check policy group for _, client := range user2Clients { for _, peer := range allClients { if client.Hostname() == peer.Hostname() { From 731c8f948e11529d3b82efd0a285cfeaeeaf311e Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 21:25:09 +0100 Subject: [PATCH 17/17] integration: add negative and check-period SSH check mode tests Add two new integration tests for SSH check mode: - TestSSHCheckModeUnapprovedTimeout: verifies that SSH is rejected when the check auth request is never approved and the registration cache entry expires. Uses short cache TTL (15s) to avoid long waits. - TestSSHCheckModeCheckPeriodCLI: verifies that after approval with a 1-minute checkPeriod, the session expires and the next SSH connection requires re-authentication through a new check flow. Also adds helper functions sshCheckPolicyWithPeriod (policy with CheckPeriod) and findNewSSHCheckAuthID (finds auth-id excluding a known one for re-auth verification). Updates #1850 --- .github/workflows/test-integration.yaml | 2 + integration/ssh_test.go | 258 ++++++++++++++++++++++++ 2 files changed, 260 insertions(+) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index e9483adf..f836734d 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -255,6 +255,8 @@ jobs: - TestSSHAutogroupSelf - TestSSHOneUserToOneCheckModeCLI - TestSSHOneUserToOneCheckModeOIDC + - TestSSHCheckModeUnapprovedTimeout + - TestSSHCheckModeCheckPeriodCLI - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 75c42af0..5a46f598 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -13,6 +13,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/oauth2-proxy/mockoidc" + "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" @@ -694,6 +695,82 @@ func sshCheckPolicy() *policyv2.Policy { } } +// sshCheckPolicyWithPeriod returns a policy with SSH "check" mode and a +// specified checkPeriod for session duration. +func sshCheckPolicyWithPeriod(period time.Duration) *policyv2.Policy { + return &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{ + policyv2.Username("user1@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + CheckPeriod: model.Duration(period), + }, + }, + } +} + +// findNewSSHCheckAuthID polls headscale logs for an SSH check auth-id +// that differs from excludeID. Used to verify re-authentication after +// session expiry. +func findNewSSHCheckAuthID( + t *testing.T, + headscale ControlServer, + excludeID string, +) string { + t.Helper() + + var authID string + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, stderr, err := headscale.ReadLog() + assert.NoError(c, err) + + for line := range strings.SplitSeq(stderr, "\n") { + if !strings.Contains(line, "SSH action follow-up") { + continue + } + + if idx := strings.Index(line, "auth_id="); idx != -1 { + start := idx + len("auth_id=") + + end := strings.IndexByte(line[start:], ' ') + if end == -1 { + end = len(line[start:]) + } + + id := line[start : start+end] + if id != excludeID { + authID = id + } + } + } + + assert.NotEmpty(c, authID, "new auth-id not found in headscale logs") + }, 10*time.Second, 500*time.Millisecond, "waiting for new SSH check auth-id") + + return authID +} + func TestSSHOneUserToOneCheckModeCLI(t *testing.T) { IntegrationSkip(t) @@ -880,3 +957,184 @@ func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) { } } } + +// TestSSHCheckModeUnapprovedTimeout verifies that SSH in check mode is rejected +// when nobody approves the auth request and the registration cache entry expires. +func TestSSHCheckModeUnapprovedTimeout(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(sshCheckPolicy()), + hsic.WithTestName("sshchecktimeout"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "15s", + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "5s", + }), + ) + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 attempts SSH — enters check flow, but nobody approves + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + sshResult := doSSHCheck(t, client, peer) + + // Confirm the check flow was entered + _ = findSSHCheckAuthID(t, headscale) + + // Do NOT approve — wait for cache expiry and SSH rejection + select { + case result := <-sshResult: + require.Error(t, result.err, "SSH should be rejected when unapproved") + assert.Empty(t, result.stdout, "no command output expected on rejection") + case <-time.After(60 * time.Second): + t.Fatal("SSH did not complete after cache expiry timeout") + } + } + } + + // user2 still gets immediate Permission Denied + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +// TestSSHCheckModeCheckPeriodCLI verifies that after approval with a short +// checkPeriod, the session expires and the next SSH connection requires +// re-authentication via a new check flow. +func TestSSHCheckModeCheckPeriodCLI(t *testing.T) { + IntegrationSkip(t) + + // 1 minute is the documented minimum checkPeriod + scenario := sshScenario(t, sshCheckPolicyWithPeriod(time.Minute), 1) + defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // === Phase 1: First SSH check — approve, verify success === + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + sshResult := doSSHCheck(t, client, peer) + firstAuthID := findSSHCheckAuthID(t, headscale) + + _, err := headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", firstAuthID, + }, + ) + require.NoError(t, err) + + select { + case result := <-sshResult: + require.NoError(t, result.err, "first SSH should succeed after approval") + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("first SSH did not complete after auth approval") + } + + // === Phase 2: Wait for checkPeriod to expire === + //nolint:forbidigo // Intentional sleep: waiting for the check period session + // to expire. This is a time-based expiry, not a pollable condition — the + // Tailscale client caches the approval for SessionDuration and only + // re-triggers the check flow after it elapses. + time.Sleep(70 * time.Second) + + // === Phase 3: Second SSH — must re-authenticate === + sshResult2 := doSSHCheck(t, client, peer) + secondAuthID := findNewSSHCheckAuthID(t, headscale, firstAuthID) + + require.NotEqual( + t, + firstAuthID, + secondAuthID, + "second SSH should trigger a new auth flow after checkPeriod expiry", + ) + + _, err = headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", secondAuthID, + }, + ) + require.NoError(t, err) + + select { + case result := <-sshResult2: + require.NoError(t, result.err, "second SSH should succeed after re-approval") + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("second SSH did not complete after re-auth approval") + } + } + } +}