diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..7e059045 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,6 +253,7 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf + - TestSSHOneUserToOneCheckMode - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 57a79b96..9a8814ca 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -116,6 +116,8 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/register", ns.RegistrationHandler) r.Post("/map", ns.PollNetMapHandler) + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}/ssh_user/{ssh_user}/local_user/{local_user}", ns.SSHAction) + // Not implemented yet // // /whoami is a debug endpoint to validate that the client can communicate over the connection, @@ -147,12 +149,10 @@ func (h *Headscale) NoiseUpgradeHandler( r.Post("/update-health", ns.NotImplementedHandler) r.Route("/webclient", func(r chi.Router) {}) + + r.Post("/c2n", ns.NotImplementedHandler) }) - r.Post("/c2n", ns.NotImplementedHandler) - - r.Get("/ssh-action", ns.SSHAction) - ns.httpBaseConfig = &http.Server{ Handler: r, ReadHeaderTimeout: types.HTTPTimeout, @@ -246,7 +246,39 @@ func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *ht // SSHAction handles the /ssh-action endpoint, it returns a [tailcfg.SSHAction] // to the client with the verdict of an SSH access request. func (ns *noiseServer) SSHAction(writer http.ResponseWriter, req *http.Request) { - log.Trace().Caller().Str("path", req.URL.String()).Msg("got SSH action request") + srcNodeID := chi.URLParam(req, "src_node_id") + dstNodeID := chi.URLParam(req, "dst_node_id") + sshUser := chi.URLParam(req, "ssh_user") + localUser := chi.URLParam(req, "local_user") + log.Trace().Caller(). + Str("path", req.URL.String()). + Str("src_node_id", srcNodeID). + Str("dst_node_id", dstNodeID). + Str("ssh_user", sshUser). + Str("local_user", localUser). + Msg("got SSH action request") + + accept := tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + err := json.NewEncoder(writer).Encode(accept) + if err != nil { + log.Error().Caller().Err(err).Msg("failed to encode SSH action response") + return + } + + // Ensure response is flushed to client + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } } // PollNetMapHandler takes care of /machine/:id/map using the Noise protocol diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9c2c5f17..d75e1914 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -12,6 +12,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" + "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -319,11 +320,27 @@ func (pol *Policy) compileACLWithAutogroupSelf( return rules, nil } -func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { +var sshAccept = tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, +} + +func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { return tailcfg.SSHAction{ - Reject: !accept, - Accept: accept, - SessionDuration: duration, + Reject: false, + Accept: false, + SessionDuration: duration, + // Replaced in the client: + // * $SRC_NODE_IP (URL escaped) + // * $SRC_NODE_ID (Node.ID as int64 string) + // * $DST_NODE_IP (URL escaped) + // * $DST_NODE_ID (Node.ID as int64 string) + // * $SSH_USER (URL escaped, ssh user requested) + // * $LOCAL_USER (URL escaped, local user mapped) + HoldAndDelegate: fmt.Sprintf("%s/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID/ssh_user/$SSH_USER/local_user/$LOCAL_USER", baseURL), AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -375,11 +392,14 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction + // HACK HACK HACK + serverURL := viper.GetString("server_url") + switch rule.Action { case SSHActionAccept: - action = sshAction(true, 0) + action = sshAccept case SSHActionCheck: - action = sshAction(true, time.Duration(rule.CheckPeriod)) + action = sshCheck(serverURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 45bc2dc7..15867579 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -579,3 +579,75 @@ func TestSSHAutogroupSelf(t *testing.T) { } } } + +func TestSSHOneUserToOneCheckMode(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + // Use autogroup:member and autogroup:tagged instead of wildcard + // since wildcard (*) is no longer supported for SSH destinations + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + }, + 1, + ) + // defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +}