diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index a27b105c..034ad5ae 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,26 +1,17 @@ package integration import ( - "errors" - "fmt" - "log" "net/netip" - "net/url" - "strings" "testing" + "slices" + "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var errParseAuthPage = errors.New("failed to parse auth page") - -type AuthWebFlowScenario struct { - *Scenario -} - func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) @@ -29,17 +20,14 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { Users: []string{"user1", "user2"}, } - baseScenario, err := NewScenario(spec) + scenario, err := NewScenario(spec) if err != nil { t.Fatalf("failed to create scenario: %s", err) } - - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, - } defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( + nil, hsic.WithTestName("webauthping"), hsic.WithEmbeddedDERPServerOnly(), hsic.WithTLS(), @@ -74,15 +62,12 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { Users: []string{"user1", "user2"}, } - baseScenario, err := NewScenario(spec) + scenario, err := NewScenario(spec) assertNoErr(t, err) - - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, - } defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( + nil, hsic.WithTestName("weblogout"), hsic.WithTLS(), ) @@ -136,7 +121,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged out") for _, userName := range spec.Users { - err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) + err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint()) if err != nil { t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) } @@ -170,14 +155,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } for _, ip := range ips { - found := false - for _, oldIP := range clientIPs[client] { - if ip == oldIP { - found = true - - break - } - } + found := slices.Contains(clientIPs[client], ip) if !found { t.Fatalf( @@ -192,121 +170,3 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients IPs are the same") } - -func (s *AuthWebFlowScenario) CreateHeadscaleEnv( - opts ...hsic.Option, -) error { - headscale, err := s.Headscale(opts...) - if err != nil { - return err - } - - err = headscale.WaitForRunning() - if err != nil { - return err - } - - for _, userName := range s.spec.Users { - log.Printf("creating user %s with %d clients", userName, s.spec.NodesPerUser) - err = s.CreateUser(userName) - if err != nil { - return err - } - - err = s.CreateTailscaleNodesInUser(userName, "all", s.spec.NodesPerUser) - if err != nil { - return err - } - - err = s.runTailscaleUp(userName, headscale.GetEndpoint()) - if err != nil { - return err - } - } - - return nil -} - -func (s *AuthWebFlowScenario) runTailscaleUp( - userStr, loginServer string, -) error { - log.Printf("running tailscale up for user %q", userStr) - if user, ok := s.users[userStr]; ok { - for _, client := range user.Clients { - c := client - user.joinWaitGroup.Go(func() error { - log.Printf("logging %q into %q", c.Hostname(), loginServer) - loginURL, err := c.LoginWithURL(loginServer) - if err != nil { - log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err) - - return err - } - - err = s.runHeadscaleRegister(userStr, loginURL) - if err != nil { - log.Printf("failed to register client (%s): %s", c.Hostname(), err) - - return err - } - - return nil - }) - - err := client.WaitForRunning() - if err != nil { - log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err) - } - } - - if err := user.joinWaitGroup.Wait(); err != nil { - return err - } - - for _, client := range user.Clients { - err := client.WaitForRunning() - if err != nil { - return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err) - } - } - - return nil - } - - return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) -} - -func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error { - body, err := doLoginURL("web-auth-not-set", loginURL) - if err != nil { - return err - } - - // see api.go HTML template - codeSep := strings.Split(string(body), "") - if len(codeSep) != 2 { - return errParseAuthPage - } - - keySep := strings.Split(codeSep[0], "key ") - if len(keySep) != 2 { - return errParseAuthPage - } - key := keySep[1] - log.Printf("registering node %s", key) - - if headscale, err := s.Headscale(); err == nil { - _, err = headscale.Execute( - []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, - ) - if err != nil { - log.Printf("failed to register node: %s", err) - - return err - } - - return nil - } - - return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) -} diff --git a/integration/scenario.go b/integration/scenario.go index c65f92af..9825f0dd 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -16,6 +16,7 @@ import ( "os" "sort" "strconv" + "strings" "sync" "testing" "time" @@ -666,11 +667,17 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) } - _, err = doLoginURL(tsc.Hostname(), loginURL) + body, err := doLoginURL(tsc.Hostname(), loginURL) if err != nil { return err } + // If the URL is not a OIDC URL, then we need to + // run the register command to fully log in the client. + if !strings.Contains(loginURL.String(), "/oidc/") { + s.runHeadscaleRegister(userStr, body) + } + return nil }) @@ -741,6 +748,38 @@ func doLoginURL(hostname string, loginURL *url.URL) (string, error) { return string(body), nil } +var errParseAuthPage = errors.New("failed to parse auth page") + +func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { + // see api.go HTML template + codeSep := strings.Split(string(body), "") + if len(codeSep) != 2 { + return errParseAuthPage + } + + keySep := strings.Split(codeSep[0], "key ") + if len(keySep) != 2 { + return errParseAuthPage + } + key := keySep[1] + log.Printf("registering node %s", key) + + if headscale, err := s.Headscale(); err == nil { + _, err = headscale.Execute( + []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, + ) + if err != nil { + log.Printf("failed to register node: %s", err) + + return err + } + + return nil + } + + return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) +} + type LoggingRoundTripper struct{} func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {