diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 15867579..75c42af0 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -3,13 +3,16 @@ package integration import ( "fmt" "log" + "net/url" "strings" "testing" "time" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" @@ -580,40 +583,121 @@ func TestSSHAutogroupSelf(t *testing.T) { } } -func TestSSHOneUserToOneCheckMode(t *testing.T) { - IntegrationSkip(t) +type sshCheckResult struct { + stdout string + stderr string + err error +} - scenario := sshScenario(t, - &policyv2.Policy{ - Groups: policyv2.Groups{ - policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, +// doSSHCheck runs SSH in a goroutine with a longer timeout, returning a channel +// for the result. The SSH command will block while waiting for auth approval in +// check mode. +func doSSHCheck( + t *testing.T, + client TailscaleClient, + peer TailscaleClient, +) chan sshCheckResult { + t.Helper() + + peerFQDN, _ := peer.FQDN() + + command := []string{ + "/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=30", + fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN), + "'hostname'", + } + + log.Printf( + "[SSH check] Running from %s to %s", + client.Hostname(), + peer.Hostname(), + ) + + ch := make(chan sshCheckResult, 1) + + go func() { + stdout, stderr, err := client.Execute( + command, + dockertestutil.ExecuteCommandTimeout(60*time.Second), + ) + ch <- sshCheckResult{stdout, stderr, err} + }() + + return ch +} + +// findSSHCheckAuthID polls headscale container logs for the SSH action auth-id. +// The SSH action handler logs "SSH action follow-up" with the auth_id on the +// follow-up request (where auth_id is non-empty). +func findSSHCheckAuthID(t *testing.T, headscale ControlServer) string { + t.Helper() + + var authID string + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, stderr, err := headscale.ReadLog() + assert.NoError(c, err) + + for line := range strings.SplitSeq(stderr, "\n") { + if !strings.Contains(line, "SSH action follow-up") { + continue + } + + if idx := strings.Index(line, "auth_id="); idx != -1 { + start := idx + len("auth_id=") + + end := strings.IndexByte(line[start:], ' ') + if end == -1 { + end = len(line[start:]) + } + + authID = line[start : start+end] + } + } + + assert.NotEmpty(c, authID, "auth-id not found in headscale logs") + }, 10*time.Second, 500*time.Millisecond, "waiting for SSH check auth-id in headscale logs") + + return authID +} + +// sshCheckPolicy returns a policy with SSH "check" mode for group:integration-test +// targeting autogroup:member and autogroup:tagged destinations. +func sshCheckPolicy() *policyv2.Policy { + return &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{ + policyv2.Username("user1@"), }, - ACLs: []policyv2.ACL{ - { - Action: "accept", - Protocol: "tcp", - Sources: []policyv2.Alias{wildcard()}, - Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(wildcard(), tailcfg.PortRangeAny), - }, - }, - }, - SSHs: []policyv2.SSH{ - { - Action: "check", - Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, - // Use autogroup:member and autogroup:tagged instead of wildcard - // since wildcard (*) is no longer supported for SSH destinations - Destinations: policyv2.SSHDstAliases{ - new(policyv2.AutoGroupMember), - new(policyv2.AutoGroupTagged), - }, - Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), }, }, }, - 1, - ) + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + } +} + +func TestSSHOneUserToOneCheckModeCLI(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, sshCheckPolicy(), 1) // defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() @@ -625,22 +709,167 @@ func TestSSHOneUserToOneCheckMode(t *testing.T) { user2Clients, err := scenario.ListTailscaleClients("user2") requireNoErrListClients(t, err) + headscale, err := scenario.Headscale() + require.NoError(t, err) + err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() requireNoErrListFQDN(t, err) + // user1 can SSH (via check) to all peers for _, client := range user1Clients { for _, peer := range allClients { if client.Hostname() == peer.Hostname() { continue } - assertSSHHostname(t, client, peer) + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Approve via CLI + _, err := headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", authID, + }, + ) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after auth approval") + } } } + // user2 cannot SSH — not in the check policy group + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCSkipUserCreation: true, + OIDCUsers: []mockoidc.MockUser{ + // First 2: consumed during node registration + oidcMockUser("user1", true), + oidcMockUser("user2", true), + // Extra: consumed during SSH check auth flows. + // Each SSH check pops one user from the queue. + oidcMockUser("user1", true), + }, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + // defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(sshCheckPolicy()), + hsic.WithTestName("sshcheckoidc"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer( + "/tmp/hs_client_oidc_secret", + []byte(scenario.mockOIDC.ClientSecret()), + ), + ) + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 can SSH (via check) to all peers + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Build auth URL and visit it to trigger OIDC flow. + // The mock OIDC server auto-authenticates from the user queue. + authURL := headscale.GetEndpoint() + "/auth/" + authID + parsedURL, err := url.Parse(authURL) + require.NoError(t, err) + + _, err = doLoginURL("ssh-check-oidc", parsedURL) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after OIDC auth") + } + } + } + + // user2 cannot SSH — not in the check policy group for _, client := range user2Clients { for _, peer := range allClients { if client.Hostname() == peer.Hostname() {