From d8c34ba7f0bb4e5b53c424faf468eef8e7fba930 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 10 Feb 2026 16:54:56 +0100 Subject: [PATCH] noise, policy: implement SSH check action handler --- .github/workflows/test-integration.yaml | 3 +- hscontrol/noise.go | 132 +++++++++++++++++++++--- hscontrol/policy/v2/filter.go | 2 +- hscontrol/types/common.go | 7 ++ integration/scenario.go | 21 +++- 5 files changed, 144 insertions(+), 21 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 7e059045..e9483adf 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,7 +253,8 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf - - TestSSHOneUserToOneCheckMode + - TestSSHOneUserToOneCheckModeCLI + - TestSSHOneUserToOneCheckModeOIDC - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 9a8814ca..c232d5d2 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" @@ -24,6 +25,12 @@ import ( // ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. var ErrUnsupportedClientVersion = errors.New("unsupported client version") +// ErrMissingURLParameter is returned when a required URL parameter is not provided. +var ErrMissingURLParameter = errors.New("missing URL parameter") + +// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type. +var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -116,7 +123,8 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/register", ns.RegistrationHandler) r.Post("/map", ns.PollNetMapHandler) - r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}/ssh_user/{ssh_user}/local_user/{local_user}", ns.SSHAction) + // SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected. + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler) // Not implemented yet // @@ -243,33 +251,125 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht http.Error(writer, "Not implemented yet", http.StatusNotImplemented) } -// SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] +func urlParam[T any](req *http.Request, key string) (T, error) { + var zero T + + param := chi.URLParam(req, key) + if param == "" { + return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) + } + + var value T + switch any(value).(type) { + case string: + v, ok := any(param).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + default: + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + return value, nil +} + +// SSHActionHandler handles the /ssh-action endpoint, it returns a [tailcfg.SSHActionHandler] // to the client with the verdict of an SSH access request. -func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { - srcNodeID := chi.URLParam(req, "src_node_id") - dstNodeID := chi.URLParam(req, "dst_node_id") - sshUser := chi.URLParam(req, "ssh_user") - localUser := chi.URLParam(req, "local_user") +func (ns *noiseServer) SSHActionHandler(writer http.ResponseWriter, req *http.Request) { + srcNodeID, _ := urlParam[types.NodeID](req, "src_node_id") + dstNodeID, _ := urlParam[types.NodeID](req, "dst_node_id") + + sshUser := req.URL.Query().Get("ssh_user") + localUser := req.URL.Query().Get("local_user") + + // Set if this is a follow up request. + authIDStr := req.URL.Query().Get("auth_id") + log.Trace().Caller(). Str("path", req.URL.String()). - Str("src_node_id", srcNodeID). - Str("dst_node_id", dstNodeID). + Uint64("src_node_id", srcNodeID.Uint64()). + Uint64("dst_node_id", dstNodeID.Uint64()). Str("ssh_user", sshUser). Str("local_user", localUser). + Str("auth_id", authIDStr). Msg("got SSH action request") - accept := tailcfg.SSHAction{ - Reject: false, - Accept: true, - AllowAgentForwarding: true, - AllowLocalPortForwarding: true, - AllowRemotePortForwarding: true, + var action tailcfg.SSHAction + + action.AllowAgentForwarding = true + action.AllowLocalPortForwarding = true + action.AllowRemotePortForwarding = true + + if authIDStr == "" { + holdURL, err := url.Parse(ns.headscale.cfg.ServerURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER") + if err != nil { + log.Error().Caller().Err(err).Msg("failed to parse SSH action URL") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + authID, err := types.NewAuthID() + if err != nil { + log.Error().Caller().Err(err).Msg("failed to generate auth ID for SSH action") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest()) + + authURL := ns.headscale.authProvider.AuthURL(authID) + + q := holdURL.Query() + q.Set("auth_id", authID.String()) + holdURL.RawQuery = q.Encode() + + action.HoldAndDelegate = holdURL.String() + // TODO(kradalby): here we can also send a very tiny mapresponse + // "popping" the url and opening it for the user. + action.Message = fmt.Sprintf(`# Headscale SSH requires an additional check. +# To authenticate, visit: %s +# Authentication checked with Headscale SSH. +`, authURL) + } else { + authID, err := types.AuthIDFromString(authIDStr) + if err != nil { + log.Error().Caller().Err(err).Str("auth_id", authIDStr).Msg("invalid auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + log.Trace().Caller().Str("auth_id", authID.String()).Msg("SSH action follow-up request with auth_id") + + auth, ok := ns.headscale.state.GetAuthCacheEntry(authID) + if !ok { + log.Error().Caller().Str("auth_id", authID.String()).Msg("no auth session found for auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + verdict := <-auth.WaitForAuth() + + if verdict.Accept() { + action.Reject = false + action.Accept = true + } else { + action.Reject = true + action.Accept = false + + log.Trace().Caller().Str("auth_id", authID.String()).Err(verdict.Err).Msg("SSH action authentication rejected") + } } writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - err := json.NewEncoder(writer).Encode(accept) + err := json.NewEncoder(writer).Encode(action) if err != nil { log.Error().Caller().Err(err).Msg("failed to encode SSH action response") return diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 526a0cb1..f15093aa 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -339,7 +339,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { // * $DST_NODE_ID (Node.ID as int64 string) // * $SSH_USER (URL escaped, ssh user requested) // * $LOCAL_USER (URL escaped, local user mapped) - HoldAndDelegate: fmt.Sprintf("%s/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID/ssh_user/$SSH_USER/local_user/$LOCAL_USER", baseURL), + HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 891969d3..01429dc9 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -214,6 +214,13 @@ type AuthRequest struct { closed *atomic.Bool } +func NewAuthRequest() AuthRequest { + return AuthRequest{ + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, + } +} + func NewRegisterAuthRequest(node Node) AuthRequest { return AuthRequest{ node: &node, diff --git a/integration/scenario.go b/integration/scenario.go index cd43b78f..e769bd73 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -141,6 +141,12 @@ type ScenarioSpec struct { // Versions is specific list of versions to use for the test. Versions []string + // OIDCSkipUserCreation, if true, skips creating users via headscale CLI + // during environment setup. Useful for OIDC tests where the SSH policy + // references users by name, since OIDC login creates users automatically + // and pre-creating them via CLI causes duplicate user records. + OIDCSkipUserCreation bool + // OIDCUsers, if populated, will start a Mock OIDC server and populate // the user login stack with the given users. // If the NodesPerUser is set, it should align with this list to ensure @@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags( } for _, user := range s.spec.Users { - u, err := s.CreateUser(user) - if err != nil { - return err + var u *v1.User + + if s.spec.OIDCSkipUserCreation { + // Only register locally — OIDC login will create the headscale user. + s.mu.Lock() + s.users[user] = &User{Clients: make(map[string]TailscaleClient)} + s.mu.Unlock() + } else { + u, err = s.CreateUser(user) + if err != nil { + return err + } } var userOpts []tsic.Option