1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

integration: replace time.Sleep with Eventually

sleeping in tests is a big no no, its time to stop.

Sleeping only works well on the same machine under the same conditions
we rather wait for something as things take time on different machines
This commit is contained in:
Kristoffer Dalby 2025-07-09 11:15:48 +00:00
parent b904276f2b
commit 9b47f71f37
73 changed files with 675 additions and 612 deletions

View File

@ -48,5 +48,4 @@ jobs:
- name: Deploy stable docs from tag
if: startsWith(github.ref, 'refs/tags/v')
# This assumes that only newer tags are pushed
run:
mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest
run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest

View File

@ -75,7 +75,7 @@ jobs:
# Some of the jobs might still require manual restart as they are really
# slow and this will cause them to eventually be killed by Github actions.
attempt_delay: 300000 # 5 min
attempt_limit: 3
attempt_limit: 2
command: |
nix develop --command -- hi run "^${{ inputs.test }}$" \
--timeout=120m \

View File

@ -36,8 +36,7 @@ jobs:
- name: golangci-lint
if: steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- golangci-lint run
run: nix develop --command -- golangci-lint run
--new-from-rev=${{github.event.pull_request.base.sha}}
--format=colored-line-number
@ -75,8 +74,7 @@ jobs:
- name: Prettify code
if: steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- prettier --no-error-on-unmatched-pattern
run: nix develop --command -- prettier --no-error-on-unmatched-pattern
--ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
proto-lint:

View File

@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
"Cannot create node: "+status.Convert(err).Message(),
output,
)
}

View File

@ -2,6 +2,7 @@ package cli
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@ -68,7 +69,7 @@ func mockOIDC() error {
userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return fmt.Errorf("MOCKOIDC_USERS not defined")
return errors.New("MOCKOIDC_USERS not defined")
}
var users []mockoidc.MockUser

View File

@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
}
@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node node: %s",
status.Convert(err).Message(),
),
"Error getting node node: "+status.Convert(err).Message(),
output,
)
@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error deleting node: %s",
status.Convert(err).Message(),
),
"Error deleting node: "+status.Convert(err).Message(),
output,
)
@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error getting node: %s",
status.Convert(err).Message(),
),
"Error getting node: "+status.Convert(err).Message(),
output,
)
@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error moving node: %s",
status.Convert(err).Message(),
),
"Error moving node: "+status.Convert(err).Message(),
output,
)
@ -567,10 +555,7 @@ be assigned to nodes.`,
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Error backfilling IPs: %s",
status.Convert(err).Message(),
),
"Error backfilling IPs: "+status.Convert(err).Message(),
output,
)

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/url"
"strconv"
survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
err := errors.New("--name or --identifier flag is required")
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename user: %s",
status.Convert(err).Message(),
),
"Cannot rename user: "+status.Convert(err).Message(),
"",
)
}
@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot create user: %s",
status.Convert(err).Message(),
),
"Cannot create user: "+status.Convert(err).Message(),
output,
)
}
@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot destroy user: %s",
status.Convert(err).Message(),
),
"Cannot destroy user: "+status.Convert(err).Message(),
output,
)
}
@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
"Cannot get users: "+status.Convert(err).Message(),
output,
)
}
@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{
tableData = append(
tableData,
[]string{
fmt.Sprintf("%d", user.GetId()),
strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(),
user.GetName(),
user.GetEmail(),
@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
if len(users.GetUsers()) != 1 {
err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
fmt.Sprintf("Error: %s", status.Convert(err).Message()),
"Error: "+status.Convert(err).Message(),
output,
)
}
@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename user: %s",
status.Convert(err).Message(),
),
"Cannot rename user: "+status.Convert(err).Message(),
output,
)
}

View File

@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error {
if cont.State == "running" {
_ = cli.ContainerKill(ctx, cont.ID, "KILL")
}
// Then remove the container with retry logic
if removeContainerWithRetry(ctx, cli, cont.ID) {
removed++
@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error {
func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool {
maxRetries := 3
baseDelay := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
for attempt := range maxRetries {
err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
Force: true,
})
if err == nil {
return true
}
// If this is the last attempt, don't wait
if attempt == maxRetries-1 {
break
}
// Wait with exponential backoff
delay := baseDelay * time.Duration(1<<attempt)
time.Sleep(delay)
}
return false
}

View File

@ -156,10 +156,10 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
projectRoot := findProjectRoot(pwd)
runID := dockertestutil.ExtractRunIDFromContainerName(containerName)
env := []string{
fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)),
fmt.Sprintf("HEADSCALE_INTEGRATION_RUN_ID=%s", runID),
"HEADSCALE_INTEGRATION_RUN_ID=" + runID,
}
containerConfig := &container.Config{
Image: "golang:" + config.GoVersion,
@ -175,7 +175,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
// Get the correct Docker socket path from the current context
dockerSocketPath := getDockerSocketPath()
if config.Verbose {
log.Printf("Using Docker socket: %s", dockerSocketPath)
}
@ -184,7 +184,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
AutoRemove: false, // We'll remove manually for better control
Binds: []string{
fmt.Sprintf("%s:%s", projectRoot, projectRoot),
fmt.Sprintf("%s:/var/run/docker.sock", dockerSocketPath),
dockerSocketPath + ":/var/run/docker.sock",
logsDir + ":/tmp/control",
},
Mounts: []mount.Mount{
@ -237,7 +237,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
}
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
// Wait for all test containers to reach a final state
maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond
@ -254,7 +254,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
return nil
case <-ticker.C:
allFinalized := true
for _, testCont := range testContainers {
inspect, err := cli.ContainerInspect(ctx, testCont.ID)
if err != nil {
@ -263,17 +263,18 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
}
continue
}
// Check if container is in a final state
if !isContainerFinalized(inspect.State) {
allFinalized = false
if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
}
break
}
}
if allFinalized {
if verbose {
log.Printf("All test containers finalized, ready for artifact extraction")
@ -290,7 +291,6 @@ func isContainerFinalized(state *container.State) bool {
return !state.Running && state.FinishedAt != ""
}
// findProjectRoot locates the project root by finding the directory containing go.mod.
func findProjectRoot(startPath string) string {
current := startPath
@ -427,7 +427,7 @@ func listControlFiles(logsDir string) {
}
if entry.IsDir() {
// Include directories (pprof, mapresponses)
// Include directories (pprof, mapresponses)
if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") {
dataDirs = append(dataDirs, name)
}
@ -510,7 +510,7 @@ type testContainer struct {
// getCurrentTestContainers filters containers to only include those from the current test run.
func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer {
var testRunContainers []testContainer
// Find the test container to get its run ID label
var runID string
for _, cont := range containers {
@ -521,16 +521,16 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
break
}
}
if runID == "" {
log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12])
return testRunContainers
}
if verbose {
log.Printf("Looking for containers with run ID: %s", runID)
}
// Find all containers with the same run ID
for _, cont := range containers {
for _, name := range cont.Names {
@ -546,18 +546,19 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
log.Printf("Including container %s (run ID: %s)", containerName, runID)
}
}
break
}
}
}
return testRunContainers
}
// extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Ensure the logs directory exists
if err := os.MkdirAll(logsDir, 0755); err != nil {
if err := os.MkdirAll(logsDir, 0o755); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
@ -608,12 +609,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
}
// Write stdout logs
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0644); err != nil {
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stdout log: %w", err)
}
// Write stderr logs
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0644); err != nil {
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stderr log: %w", err)
}
@ -626,7 +627,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
// extractContainerFiles extracts database file and directories from headscale containers.
// Note: The actual file extraction is now handled by the integration tests themselves
// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go
// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go.
func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Files are now extracted directly by the integration tests
// This function is kept for potential future use or other file types
@ -677,7 +678,7 @@ func extractDirectory(ctx context.Context, cli *client.Client, containerID, sour
// Create target directory
targetDir := filepath.Join(logsDir, dirName)
if err := os.MkdirAll(targetDir, 0755); err != nil {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}

View File

@ -10,10 +10,8 @@ import (
"strings"
)
var (
// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
ErrFileNotFoundInTar = errors.New("file not found in tar")
)
// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
var ErrFileNotFoundInTar = errors.New("file not found in tar")
// extractFileFromTar extracts a single file from a tar reader.
func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error {
@ -42,6 +40,7 @@ func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error
if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
}
return nil
}
}
@ -98,4 +97,4 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
}
return nil
}
}

View File

@ -143,6 +143,7 @@
yq-go
ripgrep
postgresql
traceroute
# 'dot' is needed for pprof graphs
# go tool pprof -http=: <source>

View File

@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
return nil, nil
}
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
if perr, ok := err.(types.PAKError); ok {
var perr types.PAKError
if errors.As(err, &perr) {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
return nil, err
}

View File

@ -1,11 +1,10 @@
package capver
import (
"slices"
"sort"
"strings"
"slices"
xmaps "golang.org/x/exp/maps"
"tailscale.com/tailcfg"
"tailscale.com/util/set"

View File

@ -1,6 +1,6 @@
package capver
//Generated DO NOT EDIT
// Generated DO NOT EDIT
import "tailscale.com/tailcfg"
@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.82.5": 115,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
}

View File

@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
// Drop all indexes first to avoid conflicts
indexesToDrop := []string{
"idx_users_deleted_at",
"idx_provider_identifier",
"idx_provider_identifier",
"idx_name_provider_identifier",
"idx_name_no_provider_identifier",
"idx_api_keys_prefix",
"idx_policies_deleted_at",
}
for _, index := range indexesToDrop {
_ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
}
@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
}
log.Info().Msg("Schema recreation completed successfully")
return nil
},
Rollback: func(db *gorm.DB) error { return nil },

View File

@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
Avoid: false,
Nodes: []*tailcfg.DERPNode{
{
Name: fmt.Sprintf("%d", d.cfg.ServerRegionID),
Name: strconv.Itoa(d.cfg.ServerRegionID),
RegionID: d.cfg.ServerRegionID,
HostName: host,
DERPPort: port,

View File

@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
if err != nil {
log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
continue

View File

@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().

View File

@ -32,7 +32,7 @@ const (
reservedResponseHeaderSize = 4
)
// httpError logs an error and sends an HTTP error response with the given
// httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
resp := &tailcfg.DERPAdmitClientResponse{
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
}
return json.NewEncoder(writer).Encode(resp)
}

View File

@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {

View File

@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}
// mockState is a mock implementation that provides the required methods
// mockState is a mock implementation that provides the required methods.
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
}
}
}
return filtered, nil
}
// Return all peers except the node itself
@ -142,6 +143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
filtered = append(filtered, peer)
}
}
return filtered, nil
}
@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
}
}
}
return filtered, nil
}
return m.nodes, nil
}

View File

@ -11,7 +11,7 @@ import (
"tailscale.com/types/views"
)
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node types.NodeView, tag string) bool
}

View File

@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
}
n, err := r.ResponseWriter.Write(b)
r.written += int64(n)
return n, err
}

View File

@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
n.b = b
go b.doWork()
return n
}
@ -72,7 +73,7 @@ func (n *Notifier) Close() {
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel and panic recovers if already closed
// safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
}
@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return false
}
// LikelyConnectedMap returns a thread safe map of connected nodes
// LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}

View File

@ -1,17 +1,15 @@
package notifier
import (
"context"
"fmt"
"math/rand"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"slices"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
defer n.RemoveNode(1, ch)
for _, u := range tt.updates {
n.NotifyAll(context.Background(), u)
n.NotifyAll(t.Context(), u)
}
n.b.flush()
@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
if routineID%3 == 0 {
switch routineID % 3 {
case 0:
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
} else if routineID%3 == 1 {
case 1:
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
} else {
default:
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}

View File

@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(serverURL, "/"),
),
Scopes: cfg.Scope,
RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr, _ := vars["registration_id"]
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
if err != nil {
httpError(writer, err)
return
@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if err != nil {
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
}
return oauth2Token, err
}

View File

@ -2,9 +2,8 @@ package matcher
import (
"net/netip"
"strings"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/util"
"go4.org/netipx"
@ -28,6 +27,7 @@ func (m Match) DebugString() string {
for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}
return sb.String()
}
@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
for _, rule := range rules {
matches = append(matches, MatchFromFilterRule(rule))
}
return matches
}

View File

@ -4,7 +4,6 @@ import (
"net/netip"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"

View File

@ -5,7 +5,6 @@ import (
"slices"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo"
@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
// AutoApproveRoutes approves any route that can be autoapproved from
// the nodes perspective according to the given policy.
// It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
if pm == nil {
return false

View File

@ -7,9 +7,8 @@ import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
}
}
}
func TestReduceRoutes(t *testing.T) {
type args struct {
node *types.Node

View File

@ -13,9 +13,7 @@ import (
"tailscale.com/types/views"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
var ErrInvalidAction = errors.New("invalid action")
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

View File

@ -4,19 +4,17 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"slices"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views"
"tailscale.com/util/deephash"
)
type PolicyManager struct {
@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter, pm.matchers
}
@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap {
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

View File

@ -1,10 +1,10 @@
package v2
import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"

View File

@ -6,9 +6,9 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
"slices"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Check if it's the wildcard port range
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
return json.Marshal(fmt.Sprintf("%s:*", alias))
return json.Marshal(alias + ":*")
}
// Otherwise, format as "alias:ports"
var ports []string
for _, port := range a.Ports {
if port.First == port.Last {
ports = append(ports, fmt.Sprintf("%d", port.First))
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
} else {
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
}
@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
if err := u.Validate(); err != nil {
return err
}
return nil
}
@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
return buildIPSetMultiErr(&ips, errs)
}
// Group is a special string which is always prefixed with `group:`
// Group is a special string which is always prefixed with `group:`.
type Group string
func (g Group) Validate() error {
@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
if err := g.Validate(); err != nil {
return err
}
return nil
}
@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
return buildIPSetMultiErr(&ips, errs)
}
// Tag is a special string which is always prefixed with `tag:`
// Tag is a special string which is always prefixed with `tag:`.
type Tag string
func (t Tag) Validate() error {
@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
if err := t.Validate(); err != nil {
return err
}
return nil
}
@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
if err := h.Validate(); err != nil {
return err
}
return nil
}
@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
}
*p = Prefix(addrPref)
return nil
}
@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
return err
}
*p = Prefix(pref)
return nil
}
@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err := p.Validate(); err != nil {
return err
}
return nil
}
@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
}
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
// AutoGroup is a special string which is always prefixed with `autogroup:`.
type AutoGroup string
const (
@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
if err := ag.Validate(); err != nil {
return err
}
return nil
}
@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := ve.Alias.Validate(); err != nil {
if err := ve.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Alias = ptr
return nil
}
@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
return ips, multierr.New(append(errs, err)...)
}
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
func unmarshalPointer[T any](
b []byte,
parseFunc func(string) (T, error),
@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
for i, autoApprover := range autoApprovers {
(*aa)[i] = autoApprover.AutoApprover
}
return nil
}
@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.AutoApprover = ptr
return nil
}
@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Owner = ptr
return nil
}
@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
case isGroup(s):
return ptr.To(Group(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
- user (containing an "@")
- group (starting with "group:")
@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
(*g)[group] = usernames
}
return nil
}
@ -1252,7 +1269,7 @@ type Policy struct {
// We use the default JSON marshalling behavior provided by the Go runtime.
var (
// TODO(kradalby): Add these checks for tagOwners and autoApprovers
// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSrc, *src) {
@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHSrc, *src) {
@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
}
if dst.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHDst, *dst) {
@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
for _, acl := range p.ACLs {
for _, src := range acl.Sources {
switch src.(type) {
switch src := src.(type) {
case *Host:
h := src.(*Host)
h := src
if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
}
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
}
for _, src := range ssh.Sources {
switch src.(type) {
switch src := src.(type) {
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
}
}
for _, dst := range ssh.Destinations {
switch dst.(type) {
switch dst := dst.(type) {
case *AutoGroup:
ag := dst.(*AutoGroup)
ag := dst
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
continue
@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
continue
}
case *Tag:
tagOwner := dst.(*Tag)
tagOwner := dst
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
for _, tagOwners := range p.TagOwners {
for _, tagOwner := range tagOwners {
switch tagOwner.(type) {
switch tagOwner := tagOwner.(type) {
case *Group:
g := tagOwner.(*Group)
g := tagOwner
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
}
for _, approver := range p.AutoApprovers.ExitNode {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
}
p.validated = true
return nil
}
@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}
@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}

View File

@ -5,13 +5,13 @@ import (
"net/netip"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
// Marshal the policy to JSON
marshalled, err := json.MarshalIndent(policy, "", " ")
require.NoError(t, err)
// Make sure all expected fields are present in the JSON
jsonString := string(marshalled)
assert.Contains(t, jsonString, "group:example")
@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
assert.Contains(t, jsonString, "accept")
assert.Contains(t, jsonString, "tcp")
assert.Contains(t, jsonString, "80")
// Unmarshal back to verify round trip
var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err)
// Compare the original and round-tripped policies
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(),
)
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
}
@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
)
// For round-trip testing, we'll normalize the policies before comparing
for _, tt := range tests {
@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
}
return // Skip the rest of the test if we expected an error
}
@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
if err != nil {
t.Fatalf("round-trip unmarshalling: %v", err)
}
// Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps,
roundTripCmps := append(cmps,
cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}),
)
@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
builder.AddPrefix(mp(p))
}
ipSet, _ := builder.IPSet()
return ipSet
}

View File

@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
expected []tailcfg.PortRange
err string
}{
{"80", []tailcfg.PortRange{{80, 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},

View File

@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
tsaddr.SortPrefixes(routes)
return routes
}

View File

@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
if err != nil {
return types.NodeView{}, err
}
return node.View(), nil
}
@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).

View File

@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/http"
"os"
@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
certDomains := tsNode.CertDomains()
if len(certDomains) == 0 {
return fmt.Errorf("no cert domains available for HTTPS")
return errors.New("no cert domains available for HTTPS")
}
base := "https://" + certDomains[0]
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
logf("TailSQL started")
<-ctx.Done()
logf("TailSQL shutting down...")
return tsNode.Close()
}

View File

@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)),
elem.Text("tailscale login --login-server "+url),
),
),
headerTwo("GUI"),
@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macos ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macos ControlURL "+url,
),
),
),
@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
elem.Code(
nil,
elem.Text(
fmt.Sprintf(
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
url,
),
"defaults write io.tailscale.ipn.macsys ControlURL "+url,
),
),
),

View File

@ -1,8 +1,6 @@
package templates
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
)
@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
),
elem.Pre(nil,
elem.Code(nil,
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)),
elem.Text("tailscale login --login-server "+url),
),
),
),

View File

@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
if err != nil {
panic(err)
}
return rid
}

View File

@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
log.Warn().Msg("No config file found, using defaults")
return nil
}
return fmt.Errorf("fatal error reading config file: %w", err)
}
@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
}
if prefix4 == nil && prefix6 == nil {
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
}
allocStr := viper.GetString("prefixes.allocation")
@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range len(baseDomainParts) {
for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil
}

View File

@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
assert.Equal(t, "trace", viper.GetString("log.level"))
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
return nil, nil
},
want: nil,

View File

@ -28,8 +28,10 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name")
)
type NodeID uint64
type NodeIDs []NodeID
type (
NodeID uint64
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
return true
}
}
return false
}
@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
}
@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
}
return sb.String()
}
@ -590,6 +594,7 @@ func (node Node) DebugString() string {
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n")
return sb.String()
}
@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool {
if !v.Valid() {
return false
@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
return "", errors.New("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
}
@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}

View File

@ -2,7 +2,6 @@ package types
import (
"fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip"
"strings"
"testing"
@ -10,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/key"

View File

@ -11,7 +11,7 @@ import (
type PAKError string
func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {

View File

@ -1,6 +1,7 @@
package types
import (
"errors"
"testing"
"time"
@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(PAKError)
var httpErr PAKError
ok := errors.As(err, &httpErr)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {

View File

@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
// - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase
// - Normalize URL schemes to lowercase.
func CleanIdentifier(identifier string) string {
if identifier == "" {
return identifier
@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, part)
}
}
if len(cleanParts) == 0 {
u.Path = ""
} else {
@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
}
// Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme)
return u.String()
}
@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
if len(cleanParts) == 0 {
return ""
}
return strings.Join(cleanParts, "/")
}

View File

@ -1,4 +1,6 @@
package types
var Version = "dev"
var GitCommitHash = "dev"
var (
Version = "dev"
GitCommitHash = "dev"
)

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/netip"
"regexp"
"strconv"
"strings"
"unicode"
@ -21,8 +22,10 @@ const (
LabelHostnameLength = 63
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var ErrInvalidUserName = errors.New("invalid user name")
@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
return dnsname.ToFQDN(prefix + ".ip6.arpa")
}
var fqdns []dnsname.FQDN

View File

@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
"rowsAffected": rowsAffected,
}
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) {
if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
l.Logger.Error().Err(err).Fields(fields).Msgf("")
return
}

View File

@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
})

View File

@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
}
type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hop is the current jump in the total traceroute.
Hop int
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// IP is the IP address of the jump
IP netip.Addr
// IP is the IP address of the jump
IP netip.Addr
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}
type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// IP is the IP address of the target
IP netip.Addr
// IP is the IP address of the target
IP netip.Addr
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Success indicates if the traceroute was successful.
Success bool
// Success indicates if the traceroute was successful.
Success bool
// Err contains an error if the traceroute was not successful.
Err error
// Err contains an error if the traceroute was not successful.
Err error
}
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])

View File

@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
func TestACLAutogroupMember(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t,
&policyv2.Policy{
@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) {
func TestACLAutogroupTagged(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := aclScenario(t,
&policyv2.Policy{

View File

@ -3,12 +3,11 @@ package integration
import (
"fmt"
"net/netip"
"slices"
"strconv"
"testing"
"time"
"slices"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
@ -19,7 +18,6 @@ import (
func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -87,7 +85,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
t.Logf("all clients logged out")
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
@ -99,26 +97,48 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
time.Sleep(5 * time.Minute)
}
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
// Create auth keys once outside the retry loop
userKeys := make(map[string]string)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
assertNoErr(t, err)
userKeys[userName] = key.GetKey()
}
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
// Wait for the 2-minute noise dial memory to expire
// The Tailscale commit shows clients remember noise dials for 2 minutes
t.Logf("Waiting 2.5 minutes for Tailscale noise dial memory to expire...")
time.Sleep(2*time.Minute + 30*time.Second)
// Wait for clients to be ready to reconnect over HTTP after HTTPS
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userKeys[userName])
assert.NoError(ct, err, "Client should be able to reconnect over HTTP")
}
}, 6*time.Minute, 30*time.Second)
} else {
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
}
}
}
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
@ -155,18 +175,17 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
})
}
}
func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.LastSeen)
assert.NotNil(t, node.GetLastSeen())
}
// This test will first log in two sets of nodes to two sets of users, then
@ -175,7 +194,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
// still has nodes, but they are not connected.
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -204,7 +222,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -259,7 +277,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -303,7 +320,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -325,32 +342,62 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
time.Sleep(5 * time.Minute)
}
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}
// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.Key,
})
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
assert.ErrorContains(t, err, "authkey expired")
// Create and expire auth keys once outside the retry loop
userExpiredKeys := make(map[string]string)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
assertNoErr(t, err)
// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.GetKey(),
})
assertNoErr(t, err)
userExpiredKeys[userName] = key.GetKey()
}
// Wait for clients to be ready to reconnect over HTTP after HTTPS
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userExpiredKeys[userName])
assert.Error(ct, err, "Should get error when using expired key")
assert.Contains(ct, err.Error(), "authkey expired")
}
}, 6*time.Minute, 30*time.Second)
} else {
userMap, err := headscale.MapUsers()
assertNoErr(t, err)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}
// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.GetKey(),
})
assertNoErr(t, err)
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
assert.ErrorContains(t, err, "authkey expired")
}
}
})
}

View File

@ -1,14 +1,12 @@
package integration
import (
"fmt"
"maps"
"net/netip"
"sort"
"testing"
"time"
"maps"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -21,7 +19,6 @@ import (
func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins
@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
// This test is really flaky.
func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
shortAccessTTL := 5 * time.Minute
@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
// of safety reasons) before checking if the clients have logged out.
// The Wait function can't do it itself as it has an upper bound of 1
// min.
time.Sleep(shortAccessTTL + 10*time.Second)
assertTailscaleNodesLogout(t, allClients)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}
}, shortAccessTTL+10*time.Second, 5*time.Second)
}
func TestOIDC024UserCreation(t *testing.T) {
@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) {
spec := ScenarioSpec{
NodesPerUser: 1,
}
for _, user := range tt.cliUsers {
spec.Users = append(spec.Users, user)
}
spec.Users = append(spec.Users, tt.cliUsers...)
for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) {
func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Single user with one node for testing PKCE flow
spec := ScenarioSpec{
@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Create no nodes and no users
scenario, err := NewScenario(ScenarioSpec{
@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 0)
assert.Empty(t, listUsers)
ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)
@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())
// Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id)
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id)
assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())
// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey)
assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())
// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey)
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}
func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{
Subject: username,
PreferredUsername: username,
Email: fmt.Sprintf("%s@headscale.net", username),
Email: username + "@headscale.net",
EmailVerified: emailVerified,
}
}

View File

@ -2,9 +2,8 @@ package integration
import (
"net/netip"
"testing"
"slices"
"testing"
"github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo"
@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
assertNoErrGetHeadscale(t, err)
listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
for _, client := range allClients {

View File

@ -18,8 +18,8 @@ import (
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
)
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul
err = json.Unmarshal([]byte(str), result)
if err != nil {
return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str)
return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str)
}
return nil
@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int {
func TestUserCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) {
"--identifier=1",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")
var listAfterIDDelete []*v1.User
@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) {
"--name=newname",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")
var listAfterNameDelete []v1.User
@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) {
)
assertNoErr(t, err)
require.Len(t, listAfterNameDelete, 0)
require.Empty(t, listAfterNameDelete)
}
func TestPreAuthKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "preauthkeyspace"
count := 3
@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
continue
}
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"})
assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
}
// Test key expiry
@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) {
func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{
@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{
@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user1 := "user1"
user2 := "user2"
@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assertNoErr(t, err)
listNodes, err := headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}
listNodes, err = headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, user2, listNodes[1].GetUser().GetName())
@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
count := 5
@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotEmpty(t, apiResult)
keys[idx] = apiResult
@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listedAPIKeys, 5)
@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) {
listedAPIKeys[idx].GetPrefix(),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
}
@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAfterExpireAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)
for index := range listedAfterExpireAPIKeys {
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) {
"--prefix",
listedAPIKeys[0].GetPrefix(),
})
assert.Nil(t, err)
assert.NoError(t, err)
var listedAPIKeysAfterDelete []v1.ApiKey
err = executeAndUnmarshal(headscale,
@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeysAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listedAPIKeysAfterDelete, 4)
}
func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1"},
@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if node.GetForcedTags() != nil {
@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) {
}
}
}
assert.Equal(
assert.True(
t,
true,
found,
"should find a node with the tag 'tag:test' in the list of nodes",
)
}
func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
tests := []struct {
name string
@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
func TestNodeCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-user", "other-user"},
@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range otherUserRegIDs {
_, err := headscale.Execute(
@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
otherUserMachines[index] = &node
}
@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAllWithotherUser,
)
assert.Nil(t, err)
assert.NoError(t, err)
// All nodes, nodes + otherUser
assert.Len(t, listAllWithotherUser, 7)
@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyotherUserMachineUser,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listOnlyotherUserMachineUser, 2)
@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) {
"--force",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
// Test: list main user after node is deleted
var listOnlyMachineUserAfterDelete []v1.Node
@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyMachineUserAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listOnlyMachineUserAfterDelete, 4)
}
func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-expire-user"},
@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
nodes[index] = &node
}
@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) {
"nodes",
"expire",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
}
var listAllAfterExpiry []v1.Node
@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAllAfterExpiry,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterExpiry, 5)
@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) {
func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"node-rename-command"},
@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAll, 5)
@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
fmt.Sprintf("newnode-%d", idx+1),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, res, "Node renamed")
}
@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRename,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterRename, 5)
@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[4].GetId()),
strconv.FormatUint(listAll[4].GetId(), 10),
strings.Repeat("t", 64),
},
)
@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRenameAttempt,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, listAllAfterRenameAttempt, 5)
@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) {
func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"old-user", "new-user"},
@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
var node v1.Node
err = executeAndUnmarshal(
@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, uint64(1), node.GetId())
assert.Equal(t, "nomad-node", node.GetName())
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
nodeID := fmt.Sprintf("%d", node.GetId())
nodeID := strconv.FormatUint(node.GetId(), 10)
err = executeAndUnmarshal(
headscale,
@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())
var allNodes []v1.Node
err = executeAndUnmarshal(
@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&allNodes,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Len(t, allNodes, 1)
assert.Equal(t, allNodes[0].GetId(), node.GetId())
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())
_, err = headscale.Execute(
[]string{
@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) {
err,
"user not found",
)
assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())
err = executeAndUnmarshal(
headscale,
@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
err = executeAndUnmarshal(
headscale,
@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
}
func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
Users: []string{"user1"},
@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) {
func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,

View File

@ -1,7 +1,6 @@
package integration
import (
"context"
"fmt"
"net"
"strconv"
@ -104,7 +103,7 @@ func DERPVerify(
defer c.Close()
var result error
if err := c.Connect(context.Background()); err != nil {
if err := c.Connect(t.Context()); err != nil {
result = fmt.Errorf("client Connect: %w", err)
}
if m, err := c.Recv(); err != nil {

View File

@ -15,7 +15,6 @@ import (
func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) {
// It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN()
assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN)
assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN)
command := []string{
"tailscale",
@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) {
func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
_, err = hs.Execute([]string{"rm", erPath})
assertNoErr(t, err)
time.Sleep(2 * time.Second)
// The same paths should still be available as it is not cleared on delete.
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9")
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"})
assert.NoError(ct, err)
assert.Contains(ct, result, "9.9.9.9")
}
}, 10*time.Second, 1*time.Second)
// Write a new file, the backoff mechanism should make the filewatcher pick it up
// again.

View File

@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
}
// GenerateRunID creates a unique run identifier with timestamp and random hash.
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3)
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3).
func GenerateRunID() string {
now := time.Now()
timestamp := now.Format("20060102-150405")
// Add a short random hash to ensure uniqueness
randomHash := util.MustGenerateRandomStringDNSSafe(6)
return fmt.Sprintf("%s-%s", timestamp, randomHash)
}
// ExtractRunIDFromContainerName extracts the run ID from container name.
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH"
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH".
func ExtractRunIDFromContainerName(containerName string) string {
parts := strings.Split(containerName, "-")
if len(parts) >= 3 {
// Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH)
return strings.Join(parts[len(parts)-3:], "-")
}
panic(fmt.Sprintf("unexpected container name format: %s", containerName))
panic("unexpected container name format: " + containerName)
}
// IsRunningInContainer checks if the current process is running inside a Docker container.
@ -62,4 +63,4 @@ func IsRunningInContainer() bool {
// This could be improved with more robust detection if needed
_, err := os.Stat("/.dockerenv")
return err == nil
}
}

View File

@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption {
})
}
// buffer is a goroutine safe bytes.buffer
// buffer is a goroutine safe bytes.buffer.
type buffer struct {
store bytes.Buffer
mutex sync.Mutex
@ -58,8 +58,8 @@ func ExecuteCommand(
env []string,
options ...ExecuteCommandOption,
) (string, string, error) {
var stdout = buffer{}
var stderr = buffer{}
stdout := buffer{}
stderr := buffer{}
execConfig := ExecuteCommandConfig{
timeout: dockerExecuteTimeout,

View File

@ -159,7 +159,6 @@ func New(
},
}
if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
@ -192,7 +191,7 @@ func New(
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "derp")
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,

View File

@ -2,13 +2,13 @@ package integration
import (
"strings"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"testing"
"time"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
type ClientsSpec struct {
@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{
"usernet1": []string{"user1"},
"usernet2": []string{"user2"},
"usernet3": []string{"user3"},
"usernet1": {"user1"},
"usernet2": {"user2"},
"usernet3": {"user3"},
},
}
@ -106,7 +106,6 @@ func derpServerScenario(
furtherAssertions ...func(*Scenario),
) {
IntegrationSkip(t)
// t.Parallel()
scenario, err := NewScenario(spec)
assertNoErr(t, err)

View File

@ -26,7 +26,6 @@ import (
func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) {
func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) {
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
// deleted by accident if they are still online and active.
func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// Wait a bit and bring up the clients again before the expiry
// time of the ephemeral nodes.
// Nodes should be able to reconnect and work fine.
time.Sleep(30 * time.Second)
for _, client := range allClients {
err := client.Up()
if err != nil {
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success = pingAllHelper(t, allClients, allAddrs)
// Wait for clients to sync and be able to ping each other after reconnection
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
success = pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping")
}, 60*time.Second, 2*time.Second)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
// Take down all clients, this should start an expiry timer for each.
@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
// This time wait for all of the nodes to expire and check that they are no longer
// registered.
time.Sleep(3 * time.Minute)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
assert.NoError(ct, err)
assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName)
}
}, 4*time.Minute, 10*time.Second)
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) {
// nolint:tparallel
func TestTaildrop(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
retry := func(times int, sleepInterval time.Duration, doWork func() error) error {
var err error
for range times {
err = doWork()
if err == nil {
return nil
}
time.Sleep(sleepInterval)
}
return err
}
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) {
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, _, err := client.Execute(curlCommand)
if err != nil {
return err
}
assert.NoError(ct, err)
var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
if err != nil {
return err
}
assert.NoError(ct, err)
if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
assert.Failf(ct, "client %s does not have all its peers as FileTargets",
"got %d, want: %d\n%s",
len(fts),
len(allClients)-1,
ftStr,
)
}
return err
})
if err != nil {
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
}, 10*time.Second, 1*time.Second)
}
for _, client := range allClients {
@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) {
fmt.Sprintf("%s:", peerFQDN),
}
err := retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf(
"Sending file from %s to %s\n",
client.Hostname(),
peer.Hostname(),
)
_, _, err := client.Execute(command)
return err
})
if err != nil {
t.Fatalf(
"failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(),
strings.Join(command, " "),
err,
)
}
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
})
}
}
@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) {
func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
hostnames := map[string]string{
"1": "user1-host",
@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) {
assertNoErr(t, err)
}
time.Sleep(5 * time.Second)
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Build a map of expected DNSNames by node ID
expectedDNSNames := make(map[string]string)
for _, node := range nodes {
nodeID := strconv.FormatUint(node.GetId(), 10)
expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId())
}
// Verify from each client's perspective
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
// Check self node
selfID := string(status.Self.ID)
expectedDNS := expectedDNSNames[selfID]
assert.Equal(ct, expectedDNS, status.Self.DNSName,
"Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID)
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[selfID]
assert.Equal(ct, originalHostname, status.Self.HostName,
"Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID)
// Check peers
for _, peer := range status.Peer {
peerID := string(peer.ID)
if expectedDNS, ok := expectedDNSNames[peerID]; ok {
assert.Equal(ct, expectedDNS, peer.DNSName,
"Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname())
// HostName should remain as the original client-reported hostname
originalHostname := hostnames[peerID]
assert.Equal(ct, originalHostname, peer.HostName,
"Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname())
}
}
}
}, 60*time.Second, 2*time.Second)
// Verify that the clients can see the new hostname, but no givenName
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) {
func TestExpireNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) {
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
time.Sleep(2 * time.Minute)
// Verify that the expired node has been marked in all peers list.
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
if client.Hostname() != node.GetName() {
// Check if the expired node appears as expired in this client's peer list
for key, peer := range status.Peer {
if key == expiredNodeKey {
assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname())
break
}
}
}
}
}, 3*time.Minute, 10*time.Second)
now := time.Now()
@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) {
func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) {
// five times ensuring they are able to restablish connectivity.
func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}
time.Sleep(5 * time.Second)
for _, client := range allClients {
c := client
wg.Go(func() error {
@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}
time.Sleep(5 * time.Second)
// Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success := pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
)
require.NoError(t, err)
time.Sleep(2 * time.Second)
// Ensure that the node has been deleted, this did not occur due to a panic.
var nodeListAfter []v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfter,
)
assert.NoError(ct, err)
assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list")
}, 10*time.Second, 1*time.Second)
err = executeAndUnmarshal(
headscale,
[]string{

View File

@ -191,7 +191,7 @@ func WithPostgres() Option {
}
}
// WithPolicy sets the policy mode for headscale
// WithPolicy sets the policy mode for headscale.
func WithPolicyMode(mode types.PolicyMode) Option {
return func(hsic *HeadscaleInContainer) {
hsic.policyMode = mode
@ -279,7 +279,7 @@ func New(
return nil, err
}
hostname := fmt.Sprintf("hs-%s", hash)
hostname := "hs-" + hash
hsic := &HeadscaleInContainer{
hostname: hostname,
@ -308,14 +308,14 @@ func New(
if hsic.postgres {
hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres"
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash)
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash
hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")
pgRunOptions := &dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash),
Name: "postgres-" + hash,
Repository: "postgres",
Tag: "latest",
Networks: networks,
@ -328,7 +328,7 @@ func New(
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres")
pg, err := pool.RunWithOptions(pgRunOptions)
if err != nil {
return nil, fmt.Errorf("starting postgres container: %w", err)
@ -373,7 +373,6 @@ func New(
Env: env,
}
if len(hsic.hostPortBindings) > 0 {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
for port, hostPorts := range hsic.hostPortBindings {
@ -396,7 +395,7 @@ func New(
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")
container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
// extractTarToDirectory extracts a tar archive to a directory.
func extractTarToDirectory(tarData []byte, targetDir string) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}
@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
}
targetDir := path.Join(savePath, t.hostname+"-pprof")
return extractTarToDirectory(tarFile, targetDir)
}
@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
}
targetDir := path.Join(savePath, t.hostname+"-mapresponses")
return extractTarToDirectory(tarFile, targetDir)
}
@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
if err != nil {
return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err)
}
if strings.TrimSpace(schemaCheck) == "" {
return fmt.Errorf("database file exists but has no schema (empty database)")
return errors.New("database file exists but has no schema (empty database)")
}
// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
log.Printf("Database schema preview:\n%s", schemaPreview)
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
}
}
return fmt.Errorf("no regular file found in database tar archive")
return errors.New("no regular file found in database tar archive")
}
// Execute runs a command inside the Headscale container and returns the
@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute(
// GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port)
return strconv.Itoa(t.port)
}
// GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer
// instance.
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
return fmt.Sprintf("%s/health", t.GetEndpoint())
return t.GetEndpoint() + "/health"
}
// GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer.
@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
t.port)
if t.hasTLS() {
return fmt.Sprintf("https://%s", hostEndpoint)
return "https://" + hostEndpoint
}
return fmt.Sprintf("http://%s", hostEndpoint)
return "http://" + hostEndpoint
}
// GetCert returns the public certificate of the HeadscaleInContainer.
@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes(
}
ret = append(ret, nodes...)
return nil
}
@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes(
sort.Slice(ret, func(i, j int) bool {
return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1
})
return ret, nil
}
@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {
var userMap map[string][]*v1.Node
for _, node := range nodes {
if _, ok := userMap[node.User.Name]; !ok {
mak.Set(&userMap, node.User.Name, []*v1.Node{node})
if _, ok := userMap[node.GetUser().GetName()]; !ok {
mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node})
} else {
userMap[node.User.Name] = append(userMap[node.User.Name], node)
userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node)
}
}
@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) {
var userMap map[string]*v1.User
for _, user := range users {
mak.Set(&userMap, user.Name, user)
mak.Set(&userMap, user.GetName(), user)
}
return userMap, nil
@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
case 1:
return pids[0], nil
default:
return 0, fmt.Errorf("multiple headscale processes running")
return 0, errors.New("multiple headscale processes running")
}
}
@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
"headscale", "nodes", "approve-routes",
"--output", "json",
"--identifier", strconv.FormatUint(id, 10),
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")),
"--routes=" + strings.Join(util.PrefixesToString(routes), ","),
}
result, _, err := dockertestutil.ExecuteCommand(

View File

@ -4,13 +4,12 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"sort"
"strings"
"testing"
"time"
"slices"
cmpdiff "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff}
// routes.
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 3,
@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) {
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
if peerStatus.ID == "1" {
switch peerStatus.ID {
case "1":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else if peerStatus.ID == "2" {
case "2":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else {
default:
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")})
}
}
@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) {
func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 3,
@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) {
// https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "user4"
@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) {
// set during login instead of set.
func TestEnablingExitRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "user2"
@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) {
// subnet router is working as expected.
func TestSubnetRouterMultiNetwork(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
// Enable route
_, err = headscale.ApproveRoutes(
nodes[0].Id,
nodes[0].GetId(),
[]netip.Prefix{*pref},
)
require.NoError(t, err)
@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
spec := ScenarioSpec{
NodesPerUser: 1,
@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
}
// Enable route
_, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()})
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err)
time.Sleep(5 * time.Second)
@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
assertNoErr(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
assertNoErr(t, err)
}
// extra creation end.
@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
// Use router and node users for better clarity
routerUser := "router"
@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)
// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
aclPolicyStr := fmt.Sprintf(`{
aclPolicyStr := `{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
]
}
]
}`)
}`
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)

View File

@ -123,7 +123,7 @@ type ScenarioSpec struct {
// NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int
// Networks, if set, is the seperate Docker networks that should be
// Networks, if set, is the separate Docker networks that should be
// created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be
// added there.
@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
hostname := "hs-oidcmock-" + hash
usersJSON, err := json.Marshal(users)
if err != nil {
@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
},
Networks: s.Networks(),
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
"MOCKOIDC_ADDR=" + hostname,
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
"MOCKOIDC_ACCESS_TTL=" + accessTTL.String(),
"MOCKOIDC_USERS=" + string(usersJSON),
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc")
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := fmt.Sprintf("hs-webservice-%s", hash)
hostname := "hs-webservice-" + hash
network, ok := s.networks[s.prefixedNetworkName(networkName)]
if !ok {

View File

@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) {
// nolint:tparallel
func TestHeadscale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
var err error
@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) {
// nolint:tparallel
func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
var err error

View File

@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool {
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
}
var retry = func(times int, sleepInterval time.Duration,
doWork func() (string, string, error),
) (string, string, error) {
var result string
var stderr string
var err error
for range times {
tempResult, tempStderr, err := doWork()
result += tempResult
stderr += tempStderr
if err == nil {
return result, stderr, nil
}
// If we get a permission denied error, we can fail immediately
// since that is something we won-t recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return result, stderr, err
}
time.Sleep(sleepInterval)
}
return result, stderr, err
}
func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario {
t.Helper()
@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
func TestSSHOneUserToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) {
func TestSSHMultipleUsersAllToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
func TestSSHNoSSHConfigured(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) {
func TestSSHIsBlockedInACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) {
func TestSSHUserOnlyIsolation(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario := sshScenario(t,
&policyv2.Policy{
@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, true)
}
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, false)
}
func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) {
t.Helper()
peerFQDN, _ := peer.FQDN()
@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string,
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
log.Printf("Command: %s", strings.Join(command, " "))
return retry(10, 1*time.Second, func() (string, string, error) {
return client.Execute(command)
})
var result, stderr string
var err error
if retry {
// Use assert.EventuallyWithT to retry SSH connections for success cases
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, stderr, err = client.Execute(command)
// If we get a permission denied error, we can fail immediately
// since that is something we won't recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return // Don't retry permission denied errors
}
// For all other errors, assert no error to trigger retry
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
} else {
// For failure cases, just execute once
result, stderr, err = client.Execute(command)
}
return result, stderr, err
}
func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) {
@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()
result, stderr, err := doSSH(t, client, peer)
result, stderr, err := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result)
@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc
func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()
result, stderr, _ := doSSH(t, client, peer)
result, stderr, _ := doSSHWithoutRetry(t, client, peer)
assert.Empty(t, result)

View File

@ -251,7 +251,6 @@ func New(
Env: []string{},
}
if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand(
if len(t.withTags) > 0 {
command = append(command,
fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")),
"--advertise-tags="+strings.Join(t.withTags, ","),
)
}
@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID {
// Panics if version is lower then minimum.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
panic("tsic.Netmap() called with unsupported version: " + t.version)
}
command := []string{
@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
"tailscale", "ping",
fmt.Sprintf("--timeout=%s", args.timeout),
fmt.Sprintf("--c=%d", args.count),
fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)),
"--until-direct=" + strconv.FormatBool(args.direct),
}
command = append(command, hostnameOrIP)
@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
command := []string{
"curl",
"--silent",
"--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())),
"--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())),
"--retry", fmt.Sprintf("%d", args.retry),
"--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())),
"--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())),
"--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())),
"--max-time", strconv.Itoa(int(args.maxTime.Seconds())),
"--retry", strconv.Itoa(args.retry),
"--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())),
"--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())),
url,
}
@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
}
if out.Len() == 0 {
return nil, fmt.Errorf("file is empty")
return nil, errors.New("file is empty")
}
return out.Bytes(), nil
@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err = json.Unmarshal(currentProfile, &p); err != nil {
return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
}
return &p.Persist.PrivateNodeKey, nil
}

View File

@ -3,7 +3,6 @@ package integration
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/netip"
@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {
// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())
for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover {
func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name))
}
//
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned.