mirror of https://github.com/juanfont/headscale.git
synced 2025-09-25 17:51:11 +02:00
integration: replace time.Sleep with Eventually
Sleeping in tests is a big no-no; it's time to stop. Sleeping only works reliably on the same machine under the same conditions. We should instead wait for the condition we care about, since things take different amounts of time on different machines.
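
For illustration, a minimal sketch of the pattern this commit moves toward, assuming testify's assert.Eventually (which the headscale tests use); nodeIsOnline is a hypothetical stand-in for a real status check, not a helper from this repository:

    package integration

    import (
        "testing"
        "time"

        "github.com/stretchr/testify/assert"
    )

    // nodeIsOnline is a hypothetical status check used only for this sketch.
    func nodeIsOnline() bool { return true }

    func TestNodeComesOnline(t *testing.T) {
        // Instead of time.Sleep(10 * time.Second) and hoping the node is ready,
        // poll the condition until it holds or the deadline expires.
        assert.Eventually(t, func() bool {
            return nodeIsOnline()
        }, 30*time.Second, 500*time.Millisecond, "node never came online")
    }

assert.Eventually re-checks the condition every tick and fails only when the deadline passes, so the test waits exactly as long as a slower machine needs and no longer.
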
This commit is contained in:
parent b904276f2b
commit 9b47f71f37
.github/workflows/docs-deploy.yml (vendored, 3 changes)

@@ -48,5 +48,4 @@ jobs:
       - name: Deploy stable docs from tag
         if: startsWith(github.ref, 'refs/tags/v')
         # This assumes that only newer tags are pushed
-        run:
-          mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest
+        run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest

@@ -75,7 +75,7 @@ jobs:
           # Some of the jobs might still require manual restart as they are really
           # slow and this will cause them to eventually be killed by Github actions.
           attempt_delay: 300000 # 5 min
-          attempt_limit: 3
+          attempt_limit: 2
           command: |
             nix develop --command -- hi run "^${{ inputs.test }}$" \
               --timeout=120m \

.github/workflows/lint.yml (vendored, 6 changes)

@@ -36,8 +36,7 @@ jobs:
 
       - name: golangci-lint
         if: steps.changed-files.outputs.files == 'true'
-        run:
-          nix develop --command -- golangci-lint run
+        run: nix develop --command -- golangci-lint run
           --new-from-rev=${{github.event.pull_request.base.sha}}
           --format=colored-line-number

@@ -75,8 +74,7 @@ jobs:
 
       - name: Prettify code
         if: steps.changed-files.outputs.files == 'true'
-        run:
-          nix develop --command -- prettier --no-error-on-unmatched-pattern
+        run: nix develop --command -- prettier --no-error-on-unmatched-pattern
           --ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
 
   proto-lint:
@@ -117,7 +117,7 @@ var createNodeCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()),
+            "Cannot create node: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -2,6 +2,7 @@ package cli
 
 import (
     "encoding/json"
+    "errors"
     "fmt"
     "net"
     "net/http"

@@ -68,7 +69,7 @@ func mockOIDC() error {
 
     userStr := os.Getenv("MOCKOIDC_USERS")
     if userStr == "" {
-        return fmt.Errorf("MOCKOIDC_USERS not defined")
+        return errors.New("MOCKOIDC_USERS not defined")
     }
 
     var users []mockoidc.MockUser

@@ -184,7 +184,7 @@ var listNodesCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
+            "Cannot get nodes: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -398,10 +398,7 @@ var deleteNodeCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Error getting node node: %s",
-                status.Convert(err).Message(),
-            ),
+            "Error getting node node: "+status.Convert(err).Message(),
             output,
         )
 

@@ -437,10 +434,7 @@ var deleteNodeCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Error deleting node: %s",
-                status.Convert(err).Message(),
-            ),
+            "Error deleting node: "+status.Convert(err).Message(),
             output,
         )
 

@@ -498,10 +492,7 @@ var moveNodeCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Error getting node: %s",
-                status.Convert(err).Message(),
-            ),
+            "Error getting node: "+status.Convert(err).Message(),
             output,
         )
 

@@ -517,10 +508,7 @@ var moveNodeCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Error moving node: %s",
-                status.Convert(err).Message(),
-            ),
+            "Error moving node: "+status.Convert(err).Message(),
             output,
         )
 

@@ -567,10 +555,7 @@ be assigned to nodes.`,
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Error backfilling IPs: %s",
-                status.Convert(err).Message(),
-            ),
+            "Error backfilling IPs: "+status.Convert(err).Message(),
             output,
         )
 

@@ -4,6 +4,7 @@ import (
     "errors"
     "fmt"
     "net/url"
+    "strconv"
 
     survey "github.com/AlecAivazis/survey/v2"
     v1 "github.com/juanfont/headscale/gen/go/headscale/v1"

@@ -27,10 +28,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
     err := errors.New("--name or --identifier flag is required")
     ErrorOutput(
         err,
-        fmt.Sprintf(
-            "Cannot rename user: %s",
-            status.Convert(err).Message(),
-        ),
+        "Cannot rename user: "+status.Convert(err).Message(),
         "",
     )
 }

@@ -114,10 +112,7 @@ var createUserCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Cannot create user: %s",
-                status.Convert(err).Message(),
-            ),
+            "Cannot create user: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -147,16 +142,16 @@ var destroyUserCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf("Error: %s", status.Convert(err).Message()),
+            "Error: "+status.Convert(err).Message(),
             output,
         )
     }
 
     if len(users.GetUsers()) != 1 {
-        err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
+        err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
         ErrorOutput(
             err,
-            fmt.Sprintf("Error: %s", status.Convert(err).Message()),
+            "Error: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -185,10 +180,7 @@ var destroyUserCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Cannot destroy user: %s",
-                status.Convert(err).Message(),
-            ),
+            "Cannot destroy user: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -233,7 +225,7 @@ var listUsersCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()),
+            "Cannot get users: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -247,7 +239,7 @@ var listUsersCmd = &cobra.Command{
         tableData = append(
             tableData,
             []string{
-                fmt.Sprintf("%d", user.GetId()),
+                strconv.FormatUint(user.GetId(), 10),
                 user.GetDisplayName(),
                 user.GetName(),
                 user.GetEmail(),

@@ -287,16 +279,16 @@ var renameUserCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf("Error: %s", status.Convert(err).Message()),
+            "Error: "+status.Convert(err).Message(),
             output,
         )
     }
 
     if len(users.GetUsers()) != 1 {
-        err := fmt.Errorf("Unable to determine user to delete, query returned multiple users, use ID")
+        err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
         ErrorOutput(
             err,
-            fmt.Sprintf("Error: %s", status.Convert(err).Message()),
+            "Error: "+status.Convert(err).Message(),
             output,
         )
     }

@@ -312,10 +304,7 @@ var renameUserCmd = &cobra.Command{
     if err != nil {
         ErrorOutput(
             err,
-            fmt.Sprintf(
-                "Cannot rename user: %s",
-                status.Convert(err).Message(),
-            ),
+            "Cannot rename user: "+status.Convert(err).Message(),
             output,
         )
     }
@@ -66,7 +66,7 @@ func killTestContainers(ctx context.Context) error {
         if cont.State == "running" {
             _ = cli.ContainerKill(ctx, cont.ID, "KILL")
         }
 
         // Then remove the container with retry logic
         if removeContainerWithRetry(ctx, cli, cont.ID) {
             removed++

@@ -87,25 +87,25 @@ func killTestContainers(ctx context.Context) error {
 func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool {
     maxRetries := 3
     baseDelay := 100 * time.Millisecond
 
-    for attempt := 0; attempt < maxRetries; attempt++ {
+    for attempt := range maxRetries {
         err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
             Force: true,
         })
         if err == nil {
             return true
         }
 
         // If this is the last attempt, don't wait
         if attempt == maxRetries-1 {
             break
         }
 
         // Wait with exponential backoff
         delay := baseDelay * time.Duration(1<<attempt)
         time.Sleep(delay)
     }
 
     return false
 }
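
As a reference, a minimal, self-contained sketch of the retry pattern this hunk keeps, including the Go 1.22+ range-over-int form it switches to; retry and op are illustrative names for this sketch, not helpers from the repository:

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    // retry runs op up to maxRetries times, sleeping 100ms, 200ms, 400ms, ...
    // between failed attempts (exponential backoff via bit shift, as above).
    func retry(maxRetries int, op func() error) error {
        baseDelay := 100 * time.Millisecond
        var err error
        for attempt := range maxRetries { // Go 1.22+: iterates 0..maxRetries-1
            if err = op(); err == nil {
                return nil
            }
            if attempt == maxRetries-1 {
                break // last attempt: no point sleeping afterwards
            }
            time.Sleep(baseDelay * time.Duration(1<<attempt))
        }
        return err
    }

    func main() {
        calls := 0
        err := retry(3, func() error {
            calls++
            if calls < 3 {
                return errors.New("not yet")
            }
            return nil
        })
        fmt.Println(calls, err) // 3 <nil>
    }

Note the backoff here is still a fixed sleep between attempts; the commit's point is that test assertions should poll for a condition instead, while cleanup-style retries like this one remain a reasonable place for bounded sleeping.
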

@@ -156,10 +156,10 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
     projectRoot := findProjectRoot(pwd)
 
     runID := dockertestutil.ExtractRunIDFromContainerName(containerName)
 
     env := []string{
         fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)),
-        fmt.Sprintf("HEADSCALE_INTEGRATION_RUN_ID=%s", runID),
+        "HEADSCALE_INTEGRATION_RUN_ID=" + runID,
     }
     containerConfig := &container.Config{
         Image: "golang:" + config.GoVersion,

@@ -175,7 +175,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
 
     // Get the correct Docker socket path from the current context
     dockerSocketPath := getDockerSocketPath()
 
     if config.Verbose {
         log.Printf("Using Docker socket: %s", dockerSocketPath)
     }

@@ -184,7 +184,7 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
         AutoRemove: false, // We'll remove manually for better control
         Binds: []string{
             fmt.Sprintf("%s:%s", projectRoot, projectRoot),
-            fmt.Sprintf("%s:/var/run/docker.sock", dockerSocketPath),
+            dockerSocketPath + ":/var/run/docker.sock",
             logsDir + ":/tmp/control",
         },
         Mounts: []mount.Mount{

@@ -237,7 +237,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
     }
 
     testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
 
     // Wait for all test containers to reach a final state
     maxWaitTime := 10 * time.Second
     checkInterval := 500 * time.Millisecond

@@ -254,7 +254,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
             return nil
         case <-ticker.C:
             allFinalized := true
 
             for _, testCont := range testContainers {
                 inspect, err := cli.ContainerInspect(ctx, testCont.ID)
                 if err != nil {

@@ -263,17 +263,18 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
                     }
                     continue
                 }
 
                 // Check if container is in a final state
                 if !isContainerFinalized(inspect.State) {
                     allFinalized = false
                     if verbose {
                         log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
                     }
+
                     break
                 }
             }
 
             if allFinalized {
                 if verbose {
                     log.Printf("All test containers finalized, ready for artifact extraction")

@@ -290,7 +291,6 @@ func isContainerFinalized(state *container.State) bool {
     return !state.Running && state.FinishedAt != ""
 }
 
-
 // findProjectRoot locates the project root by finding the directory containing go.mod.
 func findProjectRoot(startPath string) string {
     current := startPath

@@ -427,7 +427,7 @@ func listControlFiles(logsDir string) {
         }
 
         if entry.IsDir() {
-            // Include directories (pprof, mapresponses)
+            // Include directories (pprof, mapresponses)
             if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") {
                 dataDirs = append(dataDirs, name)
             }

@@ -510,7 +510,7 @@ type testContainer struct {
 // getCurrentTestContainers filters containers to only include those from the current test run.
 func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer {
     var testRunContainers []testContainer
 
     // Find the test container to get its run ID label
     var runID string
     for _, cont := range containers {

@@ -521,16 +521,16 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
             break
         }
     }
 
     if runID == "" {
         log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12])
         return testRunContainers
     }
 
     if verbose {
         log.Printf("Looking for containers with run ID: %s", runID)
     }
 
     // Find all containers with the same run ID
     for _, cont := range containers {
         for _, name := range cont.Names {

@@ -546,18 +546,19 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
                     log.Printf("Including container %s (run ID: %s)", containerName, runID)
                 }
             }
+
             break
         }
     }
 
     return testRunContainers
 }
 
 // extractContainerArtifacts saves logs and tar files from a container.
 func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
     // Ensure the logs directory exists
-    if err := os.MkdirAll(logsDir, 0755); err != nil {
+    if err := os.MkdirAll(logsDir, 0o755); err != nil {
         return fmt.Errorf("failed to create logs directory: %w", err)
     }

@@ -608,12 +609,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
     }
 
     // Write stdout logs
-    if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0644); err != nil {
+    if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
         return fmt.Errorf("failed to write stdout log: %w", err)
     }
 
     // Write stderr logs
-    if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0644); err != nil {
+    if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
         return fmt.Errorf("failed to write stderr log: %w", err)
     }

@@ -626,7 +627,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
 
 // extractContainerFiles extracts database file and directories from headscale containers.
 // Note: The actual file extraction is now handled by the integration tests themselves
-// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go
+// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go.
 func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
     // Files are now extracted directly by the integration tests
     // This function is kept for potential future use or other file types

@@ -677,7 +678,7 @@ func extractDirectory(ctx context.Context, cli *client.Client, containerID, sour
 
     // Create target directory
     targetDir := filepath.Join(logsDir, dirName)
-    if err := os.MkdirAll(targetDir, 0755); err != nil {
+    if err := os.MkdirAll(targetDir, 0o755); err != nil {
         return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
     }
 

@@ -10,10 +10,8 @@ import (
     "strings"
 )
 
-var (
-    // ErrFileNotFoundInTar indicates a file was not found in the tar archive.
-    ErrFileNotFoundInTar = errors.New("file not found in tar")
-)
+// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
+var ErrFileNotFoundInTar = errors.New("file not found in tar")
 
 // extractFileFromTar extracts a single file from a tar reader.
 func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error {

@@ -42,6 +40,7 @@ func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error
             if _, err := io.Copy(outFile, tr); err != nil {
                 return fmt.Errorf("failed to copy file contents: %w", err)
             }
+
             return nil
         }
     }

@@ -98,4 +97,4 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
         }
     }
 
     return nil
 }

@@ -143,6 +143,7 @@
     yq-go
     ripgrep
     postgresql
+    traceroute
 
     # 'dot' is needed for pprof graphs
     # go tool pprof -http=: <source>

@@ -98,7 +98,6 @@ func (h *Headscale) handleExistingNode(
 
             return nil, nil
         }
-
     }
 
     n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)

@@ -169,7 +168,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
     regReq tailcfg.RegisterRequest,
     machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-
     node, changed, err := h.state.HandleNodeFromPreAuthKey(
         regReq,
         machineKey,

@@ -178,9 +176,11 @@ func (h *Headscale) handleRegisterWithAuthKey(
         if errors.Is(err, gorm.ErrRecordNotFound) {
             return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
         }
-        if perr, ok := err.(types.PAKError); ok {
+
+        var perr types.PAKError
+        if errors.As(err, &perr) {
             return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
         }
 
         return nil, err
     }
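
The hunk above swaps a direct type assertion for errors.As, which also matches when the typed error is buried inside a wrapped chain. A minimal sketch of the difference, with PAKError mirroring the string-based error type used here:

    package main

    import (
        "errors"
        "fmt"
    )

    // PAKError mirrors the string-based error type from the diff above.
    type PAKError string

    func (e PAKError) Error() string { return string(e) }

    func main() {
        // Wrapping with %w keeps the typed error discoverable via errors.As.
        err := fmt.Errorf("registration failed: %w", PAKError("key expired"))

        var perr PAKError
        if errors.As(err, &perr) { // a direct assertion err.(PAKError) would fail here
            fmt.Println("pre-auth key problem:", perr.Error())
        }
    }
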

@@ -1,11 +1,10 @@
 package capver
 
 import (
+    "slices"
     "sort"
     "strings"
 
-    "slices"
-
     xmaps "golang.org/x/exp/maps"
     "tailscale.com/tailcfg"
     "tailscale.com/util/set"

@@ -1,6 +1,6 @@
 package capver
 
-//Generated DO NOT EDIT
+// Generated DO NOT EDIT
 
 import "tailscale.com/tailcfg"
 

@@ -38,17 +38,16 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
     "v1.82.5": 115,
 }
 
-
 var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
-    87:  "v1.60.0",
-    88:  "v1.62.0",
-    90:  "v1.64.0",
-    95:  "v1.66.0",
-    97:  "v1.68.0",
-    102: "v1.70.0",
-    104: "v1.72.0",
-    106: "v1.74.0",
-    109: "v1.78.0",
-    113: "v1.80.0",
-    115: "v1.82.0",
+    87:  "v1.60.0",
+    88:  "v1.62.0",
+    90:  "v1.64.0",
+    95:  "v1.66.0",
+    97:  "v1.68.0",
+    102: "v1.70.0",
+    104: "v1.72.0",
+    106: "v1.74.0",
+    109: "v1.78.0",
+    113: "v1.80.0",
+    115: "v1.82.0",
 }

@@ -764,13 +764,13 @@ AND auth_key_id NOT IN (
     // Drop all indexes first to avoid conflicts
     indexesToDrop := []string{
         "idx_users_deleted_at",
-        "idx_provider_identifier",
+        "idx_provider_identifier",
         "idx_name_provider_identifier",
         "idx_name_no_provider_identifier",
         "idx_api_keys_prefix",
         "idx_policies_deleted_at",
     }
 
-
     for _, index := range indexesToDrop {
         _ = tx.Exec("DROP INDEX IF EXISTS " + index).Error
     }

@@ -927,6 +927,7 @@ AND auth_key_id NOT IN (
     }
 
     log.Info().Msg("Schema recreation completed successfully")
+
     return nil
 },
 Rollback: func(db *gorm.DB) error { return nil },

@@ -93,7 +93,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
     Avoid: false,
     Nodes: []*tailcfg.DERPNode{
         {
-            Name:     fmt.Sprintf("%d", d.cfg.ServerRegionID),
+            Name:     strconv.Itoa(d.cfg.ServerRegionID),
             RegionID: d.cfg.ServerRegionID,
             HostName: host,
             DERPPort: port,

@@ -103,7 +103,6 @@ func (e *ExtraRecordsMan) Run() {
 
                 return struct{}{}, nil
             }, backoff.WithBackOff(backoff.NewExponentialBackOff()))
-
             if err != nil {
                 log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
                 continue

@@ -475,7 +475,10 @@ func (api headscaleV1APIServer) RenameNode(
         api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
     }
 
-    ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
+    ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
     api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
 
+    ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
+    api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
+
     log.Trace().

@@ -32,7 +32,7 @@ const (
     reservedResponseHeaderSize = 4
 )
 
-// httpError logs an error and sends an HTTP error response with the given
+// httpError logs an error and sends an HTTP error response with the given.
 func httpError(w http.ResponseWriter, err error) {
     var herr HTTPError
     if errors.As(err, &herr) {

@@ -102,6 +102,7 @@ func (h *Headscale) handleVerifyRequest(
     resp := &tailcfg.DERPAdmitClientResponse{
         Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
     }
+
     return json.NewEncoder(writer).Encode(resp)
 }
 

@@ -500,7 +500,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
 }
 
 // ListNodes queries the database for either all nodes if no parameters are given
-// or for the given nodes if at least one node ID is given as parameter
+// or for the given nodes if at least one node ID is given as parameter.
 func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
     nodes, err := m.state.ListNodes(nodeIDs...)
     if err != nil {

@@ -80,7 +80,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
     }
 }
 
-// mockState is a mock implementation that provides the required methods
+// mockState is a mock implementation that provides the required methods.
 type mockState struct {
     polMan  policy.PolicyManager
     derpMap *tailcfg.DERPMap

@@ -133,6 +133,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
             }
         }
     }
+
     return filtered, nil
 }
 // Return all peers except the node itself

@@ -142,6+143,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
         filtered = append(filtered, peer)
     }
 }
+
 return filtered, nil
 }
 

@@ -157,8 +159,10 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
             }
         }
     }
+
     return filtered, nil
 }
+
 return m.nodes, nil
 }
 

@@ -11,7 +11,7 @@ import (
     "tailscale.com/types/views"
 )
 
-// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
+// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
 type NodeCanHaveTagChecker interface {
     NodeCanHaveTag(node types.NodeView, tag string) bool
 }

@@ -111,5 +111,6 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
     }
     n, err := r.ResponseWriter.Write(b)
     r.written += int64(n)
+
     return n, err
 }

@@ -50,6 +50,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
     n.b = b
 
     go b.doWork()
+
     return n
 }
 

@@ -72,7 +73,7 @@ func (n *Notifier) Close() {
     n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
 }
 
-// safeCloseChannel closes a channel and panic recovers if already closed
+// safeCloseChannel closes a channel and panic recovers if already closed.
 func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
     defer func() {
         if r := recover(); r != nil {

@@ -170,6 +171,7 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
     if val, ok := n.connected.Load(nodeID); ok {
         return val
     }
+
     return false
 }
 

@@ -182,7 +184,7 @@ func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
     return false
 }
 
-// LikelyConnectedMap returns a thread safe map of connected nodes
+// LikelyConnectedMap returns a thread safe map of connected nodes.
 func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
     return n.connected
 }

@@ -1,17 +1,15 @@
 package notifier
 
 import (
     "context"
     "fmt"
     "math/rand"
     "net/netip"
+    "slices"
     "sort"
     "sync"
     "testing"
     "time"
 
-    "slices"
-
     "github.com/google/go-cmp/cmp"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"

@@ -241,7 +239,7 @@ func TestBatcher(t *testing.T) {
             defer n.RemoveNode(1, ch)
 
             for _, u := range tt.updates {
-                n.NotifyAll(context.Background(), u)
+                n.NotifyAll(t.Context(), u)
             }
 
             n.b.flush()
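
The t.Context() call above is the testing package's per-test context, added in Go 1.24; it is canceled shortly before the test returns, so goroutines started with it cannot leak past the test. A small sketch under that assumption (worker is a hypothetical helper, not from this repository):

    package notifier_test

    import (
        "context"
        "testing"
        "time"
    )

    // worker is a hypothetical helper that runs until its context is canceled.
    func worker(ctx context.Context, done chan<- struct{}) {
        <-ctx.Done()
        close(done)
    }

    func TestWorkerStopsWithTest(t *testing.T) {
        done := make(chan struct{})
        // t.Context() (Go 1.24+) is canceled automatically when the test ends.
        go worker(t.Context(), done)

        select {
        case <-done:
            t.Fatal("worker stopped too early")
        case <-time.After(10 * time.Millisecond):
            // still running while the test runs, as expected
        }
    }
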

@@ -270,7 +268,7 @@ func TestBatcher(t *testing.T) {
 // TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
 // Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
 // close a channel that was already closed, which can happen when a node changes
-// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting
+// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
 func TestIsLikelyConnectedRaceCondition(t *testing.T) {
     // mock config for the notifier
     cfg := &types.Config{

@@ -308,16 +306,17 @@ func TestIsLikelyConnectedRaceCondition(t *testing.T) {
             for range iterations {
                 // Simulate race by having some goroutines check IsLikelyConnected
                 // while others add/remove the node
-                if routineID%3 == 0 {
+                switch routineID % 3 {
+                case 0:
                     // This goroutine checks connection status
                     isConnected := notifier.IsLikelyConnected(nodeID)
                     if isConnected != true && isConnected != false {
                         errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
                     }
-                } else if routineID%3 == 1 {
+                case 1:
                     // This goroutine removes the node
                     notifier.RemoveNode(nodeID, updateChan)
-                } else {
+                default:
                     // This goroutine adds the node back
                     notifier.AddNode(nodeID, updateChan)
                 }

@@ -84,11 +84,8 @@ func NewAuthProviderOIDC(
     ClientID:     cfg.ClientID,
     ClientSecret: cfg.ClientSecret,
     Endpoint:     oidcProvider.Endpoint(),
-    RedirectURL: fmt.Sprintf(
-        "%s/oidc/callback",
-        strings.TrimSuffix(serverURL, "/"),
-    ),
-    Scopes: cfg.Scope,
+    RedirectURL:  strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
+    Scopes:       cfg.Scope,
 }
 
 registrationCache := zcache.New[string, RegistrationInfo](

@@ -131,7 +128,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
     req *http.Request,
 ) {
     vars := mux.Vars(req)
-    registrationIdStr, _ := vars["registration_id"]
+    registrationIdStr := vars["registration_id"]
 
     // We need to make sure we dont open for XSS style injections, if the parameter that
     // is passed as a key is not parsable/validated as a NodePublic key, then fail to render

@@ -232,7 +229,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
     }
 
     oauth2Token, err := a.getOauth2Token(req.Context(), code, state)
-
     if err != nil {
         httpError(writer, err)
         return

@@ -364,6 +360,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
     // Neither node nor machine key was found in the state cache meaning
     // that we could not reauth nor register the node.
     httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
+
     return
 }
 

@@ -402,6 +399,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
     if err != nil {
         return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
     }
+
     return oauth2Token, err
 }
 

@@ -2,9 +2,8 @@ package matcher
 
 import (
     "net/netip"
+    "slices"
     "strings"
 
-    "slices"
-
     "github.com/juanfont/headscale/hscontrol/util"
     "go4.org/netipx"

@@ -28,6 +27,7 @@ func (m Match) DebugString() string {
     for _, prefix := range m.dests.Prefixes() {
         sb.WriteString(" " + prefix.String() + "\n")
     }
+
     return sb.String()
 }
 

@@ -36,6 +36,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
     for _, rule := range rules {
         matches = append(matches, MatchFromFilterRule(rule))
     }
+
     return matches
 }
 

@@ -4,7 +4,6 @@ import (
     "net/netip"
 
     "github.com/juanfont/headscale/hscontrol/policy/matcher"
-
     policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
     "github.com/juanfont/headscale/hscontrol/types"
     "tailscale.com/tailcfg"

@@ -5,7 +5,6 @@ import (
     "slices"
 
     "github.com/juanfont/headscale/hscontrol/policy/matcher"
-
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/samber/lo"

@@ -131,7 +130,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 // AutoApproveRoutes approves any route that can be autoapproved from
 // the nodes perspective according to the given policy.
 // It reports true if any routes were approved.
-// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
+// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
 func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
     if pm == nil {
         return false

@@ -7,9 +7,8 @@ import (
     "testing"
     "time"
 
-    "github.com/juanfont/headscale/hscontrol/policy/matcher"
-
     "github.com/google/go-cmp/cmp"
+    "github.com/juanfont/headscale/hscontrol/policy/matcher"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
     "github.com/rs/zerolog/log"

@@ -1974,6 +1973,7 @@ func TestSSHPolicyRules(t *testing.T) {
         }
     }
 }
+
 func TestReduceRoutes(t *testing.T) {
     type args struct {
         node *types.Node

@@ -13,9 +13,7 @@ import (
     "tailscale.com/types/views"
 )
 
-var (
-    ErrInvalidAction = errors.New("invalid action")
-)
+var ErrInvalidAction = errors.New("invalid action")
 
 // compileFilterRules takes a set of nodes and an ACLPolicy and generates a
 // set of Tailscale compatible FilterRules used to allow traffic on clients.

@@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
 
     var destPorts []tailcfg.NetPortRange
     for _, dest := range acl.Destinations {
-        ips, err := dest.Alias.Resolve(pol, users, nodes)
+        ips, err := dest.Resolve(pol, users, nodes)
         if err != nil {
             log.Trace().Err(err).Msgf("resolving destination ips")
         }

@@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
     for _, pref := range ips.Prefixes() {
         out = append(out, pref.String())
     }
+
     return out
 }

@@ -4,19 +4,17 @@ import (
     "encoding/json"
     "fmt"
     "net/netip"
+    "slices"
     "strings"
     "sync"
 
     "github.com/juanfont/headscale/hscontrol/policy/matcher"
-
-    "slices"
-
     "github.com/juanfont/headscale/hscontrol/types"
     "go4.org/netipx"
     "tailscale.com/net/tsaddr"
     "tailscale.com/tailcfg"
-    "tailscale.com/util/deephash"
+    "tailscale.com/types/views"
+    "tailscale.com/util/deephash"
 )
 
 type PolicyManager struct {

@@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
 
     pm.mu.Lock()
     defer pm.mu.Unlock()
+
     return pm.filter, pm.matchers
 }
 

@@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
     pm.mu.Lock()
     defer pm.mu.Unlock()
     pm.users = users
+
     return pm.updateLocked()
 }
 

@@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
     pm.mu.Lock()
     defer pm.mu.Unlock()
     pm.nodes = nodes
+
     return pm.updateLocked()
 }
 

@@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
     // cannot just lookup in the prefix map and have to check
     // if there is a "parent" prefix available.
     for prefix, approveAddrs := range pm.autoApproveMap {
-
         // Check if prefix is larger (so containing) and then overlaps
         // the route to see if the node can approve a subset of an autoapprover
         if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

@@ -1,10 +1,10 @@
 package v2
 
 import (
-    "github.com/juanfont/headscale/hscontrol/policy/matcher"
     "testing"
 
     "github.com/google/go-cmp/cmp"
+    "github.com/juanfont/headscale/hscontrol/policy/matcher"
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/stretchr/testify/require"
     "gorm.io/gorm"

@@ -6,9 +6,9 @@ import (
     "errors"
     "fmt"
     "net/netip"
-    "strings"
-
-    "slices"
+    "slices"
+    "strconv"
+    "strings"
 
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"

@@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
 
     // Check if it's the wildcard port range
     if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
-        return json.Marshal(fmt.Sprintf("%s:*", alias))
+        return json.Marshal(alias + ":*")
     }
 
     // Otherwise, format as "alias:ports"
     var ports []string
     for _, port := range a.Ports {
         if port.First == port.Last {
-            ports = append(ports, fmt.Sprintf("%d", port.First))
+            ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
         } else {
             ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
         }

@@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
     if err := u.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
     return buildIPSetMultiErr(&ips, errs)
 }
 
-// Group is a special string which is always prefixed with `group:`
+// Group is a special string which is always prefixed with `group:`.
 type Group string
 
 func (g Group) Validate() error {

@@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
     if err := g.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
     return buildIPSetMultiErr(&ips, errs)
 }
 
-// Tag is a special string which is always prefixed with `tag:`
+// Tag is a special string which is always prefixed with `tag:`.
 type Tag string
 
 func (t Tag) Validate() error {

@@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
     if err := t.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
     if err := h.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
     }
 
     *p = Prefix(addrPref)
+
     return nil
 }
 

@@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
         return err
     }
     *p = Prefix(pref)
+
     return nil
 }
 

@@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
     if err := p.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
     }
 }
 
-// AutoGroup is a special string which is always prefixed with `autogroup:`
+// AutoGroup is a special string which is always prefixed with `autogroup:`.
 type AutoGroup string
 
 const (

@@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
     if err := ag.Validate(); err != nil {
         return err
     }
+
     return nil
 }
 

@@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
         if err != nil {
             return err
         }
-        if err := ve.Alias.Validate(); err != nil {
+
+        if err := ve.Validate(); err != nil {
             return err
         }
 
     default:
         return fmt.Errorf("type %T not supported", vs)
     }
+
     return nil
 }

@@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
         return err
     }
     ve.Alias = ptr
+
     return nil
 }
 

@@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
     for i, alias := range aliases {
         (*a)[i] = alias.Alias
     }
+
     return nil
 }
 

@@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
     return ips, multierr.New(append(errs, err)...)
 }
 
-// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
+// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
 func unmarshalPointer[T any](
     b []byte,
     parseFunc func(string) (T, error),

@@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
     for i, autoApprover := range autoApprovers {
         (*aa)[i] = autoApprover.AutoApprover
     }
+
     return nil
 }
 

@@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
         return err
     }
     ve.AutoApprover = ptr
+
     return nil
 }
 

@@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
         return err
     }
     ve.Owner = ptr
+
     return nil
 }
 

@@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
     for i, owner := range owners {
         (*o)[i] = owner.Owner
     }
+
     return nil
 }
 

@@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
     case isGroup(s):
         return ptr.To(Group(s)), nil
     }
+
     return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
 - user (containing an "@")
 - group (starting with "group:")

@@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
 
         (*g)[group] = usernames
     }
+
     return nil
 }
 

@@ -1252,7 +1269,7 @@ type Policy struct {
 // We use the default JSON marshalling behavior provided by the Go runtime.
 
 var (
-    // TODO(kradalby): Add these checks for tagOwners and autoApprovers
+    // TODO(kradalby): Add these checks for tagOwners and autoApprovers.
     autogroupForSrc    = []AutoGroup{AutoGroupMember, AutoGroupTagged}
     autogroupForDst    = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
     autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}

@@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
     }
 
     if src.Is(AutoGroupInternet) {
-        return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
+        return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
     }
 
     if !slices.Contains(autogroupForSrc, *src) {

@@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
     }
 
     if src.Is(AutoGroupInternet) {
-        return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
+        return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
     }
 
     if !slices.Contains(autogroupForSSHSrc, *src) {

@@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
     }
 
     if dst.Is(AutoGroupInternet) {
-        return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
+        return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
     }
 
     if !slices.Contains(autogroupForSSHDst, *dst) {

@@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
 
     for _, acl := range p.ACLs {
         for _, src := range acl.Sources {
-            switch src.(type) {
+            switch src := src.(type) {
             case *Host:
-                h := src.(*Host)
+                h := src
                 if !p.Hosts.exist(*h) {
                     errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
                 }
             case *AutoGroup:
-                ag := src.(*AutoGroup)
+                ag := src
 
                 if err := validateAutogroupSupported(ag); err != nil {
                     errs = append(errs, err)

@@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
                 continue
             }
         case *Group:
-            g := src.(*Group)
+            g := src
             if err := p.Groups.Contains(g); err != nil {
                 errs = append(errs, err)
             }
         case *Tag:
-            tagOwner := src.(*Tag)
+            tagOwner := src
             if err := p.TagOwners.Contains(tagOwner); err != nil {
                 errs = append(errs, err)
             }

@@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
         }
 
         for _, src := range ssh.Sources {
-            switch src.(type) {
+            switch src := src.(type) {
             case *AutoGroup:
-                ag := src.(*AutoGroup)
+                ag := src
 
                 if err := validateAutogroupSupported(ag); err != nil {
                     errs = append(errs, err)

@@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
                 continue
             }
         case *Group:
-            g := src.(*Group)
+            g := src
             if err := p.Groups.Contains(g); err != nil {
                 errs = append(errs, err)
             }
         case *Tag:
-            tagOwner := src.(*Tag)
+            tagOwner := src
             if err := p.TagOwners.Contains(tagOwner); err != nil {
                 errs = append(errs, err)
             }
         }
     }
     for _, dst := range ssh.Destinations {
-        switch dst.(type) {
+        switch dst := dst.(type) {
         case *AutoGroup:
-            ag := dst.(*AutoGroup)
+            ag := dst
             if err := validateAutogroupSupported(ag); err != nil {
                 errs = append(errs, err)
                 continue

@@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
                 continue
             }
         case *Tag:
-            tagOwner := dst.(*Tag)
+            tagOwner := dst
             if err := p.TagOwners.Contains(tagOwner); err != nil {
                 errs = append(errs, err)
             }

@@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
 
     for _, tagOwners := range p.TagOwners {
         for _, tagOwner := range tagOwners {
-            switch tagOwner.(type) {
+            switch tagOwner := tagOwner.(type) {
             case *Group:
-                g := tagOwner.(*Group)
+                g := tagOwner
                 if err := p.Groups.Contains(g); err != nil {
                     errs = append(errs, err)
                 }

@@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
 
     for _, approvers := range p.AutoApprovers.Routes {
         for _, approver := range approvers {
-            switch approver.(type) {
+            switch approver := approver.(type) {
             case *Group:
-                g := approver.(*Group)
+                g := approver
                 if err := p.Groups.Contains(g); err != nil {
                     errs = append(errs, err)
                 }
             case *Tag:
-                tagOwner := approver.(*Tag)
+                tagOwner := approver
                 if err := p.TagOwners.Contains(tagOwner); err != nil {
                     errs = append(errs, err)
                 }

@@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
     }
 
     for _, approver := range p.AutoApprovers.ExitNode {
-        switch approver.(type) {
+        switch approver := approver.(type) {
         case *Group:
-            g := approver.(*Group)
+            g := approver
             if err := p.Groups.Contains(g); err != nil {
                 errs = append(errs, err)
             }
         case *Tag:
-            tagOwner := approver.(*Tag)
+            tagOwner := approver
             if err := p.TagOwners.Contains(tagOwner); err != nil {
                 errs = append(errs, err)
             }

@@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
     }
 
     p.validated = true
+
     return nil
 }
||||
|
||||
@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -5,13 +5,13 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/prometheus/common/model"
|
||||
"time"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go4.org/netipx"
|
||||
@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
|
||||
// Marshal the policy to JSON
|
||||
marshalled, err := json.MarshalIndent(policy, "", " ")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Make sure all expected fields are present in the JSON
|
||||
jsonString := string(marshalled)
|
||||
assert.Contains(t, jsonString, "group:example")
|
||||
@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
|
||||
assert.Contains(t, jsonString, "accept")
|
||||
assert.Contains(t, jsonString, "tcp")
|
||||
assert.Contains(t, jsonString, "80")
|
||||
|
||||
|
||||
// Unmarshal back to verify round trip
|
||||
var roundTripped Policy
|
||||
err = json.Unmarshal(marshalled, &roundTripped)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// Compare the original and round-tripped policies
|
||||
cmps := append(util.Comparers,
|
||||
cmps := append(util.Comparers,
|
||||
cmp.Comparer(func(x, y Prefix) bool {
|
||||
return x == y
|
||||
}),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
cmpopts.EquateEmpty(),
|
||||
)
|
||||
|
||||
|
||||
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
|
||||
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
|
||||
}
|
||||
@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
cmps := append(util.Comparers,
|
||||
cmps := append(util.Comparers,
|
||||
cmp.Comparer(func(x, y Prefix) bool {
|
||||
return x == y
|
||||
}),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
)
|
||||
|
||||
|
||||
// For round-trip testing, we'll normalize the policies before comparing
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
} else if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
|
||||
}
|
||||
|
||||
return // Skip the rest of the test if we expected an error
|
||||
}
|
||||
|
||||
@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("round-trip unmarshalling: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// Add EquateEmpty to handle nil vs empty maps/slices
|
||||
roundTripCmps := append(cmps,
|
||||
roundTripCmps := append(cmps,
|
||||
cmpopts.EquateEmpty(),
|
||||
cmpopts.IgnoreUnexported(Policy{}),
|
||||
)
|
||||
@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
|
||||
builder.AddPrefix(mp(p))
|
||||
}
|
||||
ipSet, _ := builder.IPSet()
|
||||
|
||||
return ipSet
|
||||
}
|
||||
|
||||
|
@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
|
@ -158,6 +158,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
|
||||
}
|
||||
|
||||
tsaddr.SortPrefixes(routes)
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
|
@ -429,6 +429,7 @@ func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
|
||||
if err != nil {
|
||||
return types.NodeView{}, err
|
||||
}
|
||||
|
||||
return node.View(), nil
|
||||
}
|
||||
|
||||
@ -443,6 +444,7 @@ func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, er
|
||||
if err != nil {
|
||||
return types.NodeView{}, err
|
||||
}
|
||||
|
||||
return node.View(), nil
|
||||
}
|
||||
|
||||
@ -701,7 +703,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
} else if !regReq.Expiry.IsZero() {
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// don't set expiry so the node won't be considered expired
|
||||
log.Debug().
|
||||
Time("requested_expiry", regReq.Expiry).
|
||||
|
@ -2,6 +2,7 @@ package hscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -70,7 +71,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
|
||||
certDomains := tsNode.CertDomains()
|
||||
if len(certDomains) == 0 {
|
||||
return fmt.Errorf("no cert domains available for HTTPS")
|
||||
return errors.New("no cert domains available for HTTPS")
|
||||
}
|
||||
base := "https://" + certDomains[0]
|
||||
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -95,5 +96,6 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
||||
logf("TailSQL started")
|
||||
<-ctx.Done()
|
||||
logf("TailSQL shutting down...")
|
||||
|
||||
return tsNode.Close()
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func Apple(url string) *elem.Element {
|
||||
),
|
||||
elem.Pre(nil,
|
||||
elem.Code(nil,
|
||||
elem.Text(fmt.Sprintf("tailscale login --login-server %s", url)),
|
||||
elem.Text("tailscale login --login-server "+url),
|
||||
),
|
||||
),
|
||||
headerTwo("GUI"),
|
||||
@ -143,10 +143,7 @@ func Apple(url string) *elem.Element {
|
||||
elem.Code(
|
||||
nil,
|
||||
elem.Text(
|
||||
fmt.Sprintf(
|
||||
`defaults write io.tailscale.ipn.macos ControlURL %s`,
|
||||
url,
|
||||
),
|
||||
"defaults write io.tailscale.ipn.macos ControlURL "+url,
|
||||
),
|
||||
),
|
||||
),
|
||||
@ -155,10 +152,7 @@ func Apple(url string) *elem.Element {
|
||||
elem.Code(
|
||||
nil,
|
||||
elem.Text(
|
||||
fmt.Sprintf(
|
||||
`defaults write io.tailscale.ipn.macsys ControlURL %s`,
|
||||
url,
|
||||
),
|
||||
"defaults write io.tailscale.ipn.macsys ControlURL "+url,
|
||||
),
|
||||
),
|
||||
),
|
||||
|
@ -1,8 +1,6 @@
|
||||
package templates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/chasefleming/elem-go"
|
||||
"github.com/chasefleming/elem-go/attrs"
|
||||
)
|
||||
@ -31,7 +29,7 @@ func Windows(url string) *elem.Element {
|
||||
),
|
||||
elem.Pre(nil,
|
||||
elem.Code(nil,
|
||||
elem.Text(fmt.Sprintf(`tailscale login --login-server %s`, url)),
|
||||
elem.Text("tailscale login --login-server "+url),
|
||||
),
|
||||
),
|
||||
),
|
||||
|
@ -180,6 +180,7 @@ func MustRegistrationID() RegistrationID {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return rid
|
||||
}
|
||||
|
||||
|
@ -339,6 +339,7 @@ func LoadConfig(path string, isFile bool) error {
|
||||
log.Warn().Msg("No config file found, using defaults")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("fatal error reading config file: %w", err)
|
||||
}
|
||||
|
||||
@ -843,7 +844,7 @@ func LoadServerConfig() (*Config, error) {
|
||||
}
|
||||
|
||||
if prefix4 == nil && prefix6 == nil {
|
||||
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
}
|
||||
|
||||
allocStr := viper.GetString("prefixes.allocation")
|
||||
@ -1020,7 +1021,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
|
||||
|
||||
s := len(serverDomainParts)
|
||||
b := len(baseDomainParts)
|
||||
for i := range len(baseDomainParts) {
|
||||
for i := range baseDomainParts {
|
||||
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||
return nil
|
||||
}
|
||||
|
@ -282,6 +282,7 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||
assert.Equal(t, "trace", viper.GetString("log.level"))
|
||||
assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4"))
|
||||
assert.False(t, viper.GetBool("database.sqlite.write_ahead_log"))
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
want: nil,
|
||||
|
@ -28,8 +28,10 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name")
)

type NodeID uint64
type NodeIDs []NodeID
type (
NodeID uint64
NodeIDs []NodeID
)

func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
@ -169,6 +171,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
return true
}
}

return false
}

@ -176,7 +179,7 @@ func (node *Node) HasIP(i netip.Addr) bool {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
@ -199,7 +202,7 @@ func (node *Node) IsTagged() bool {

// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
}
@ -577,6 +580,7 @@ func (nodes Nodes) DebugString() string {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
}

return sb.String()
}

@ -590,6 +594,7 @@ func (node Node) DebugString() string {
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes())
sb.WriteString("\n")

return sb.String()
}

@ -689,7 +694,7 @@ func (v NodeView) Tags() []string {
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
// via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool {
if !v.Valid() {
return false
@ -727,7 +732,7 @@ func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
return "", fmt.Errorf("failed to create valid FQDN: node view is invalid")
return "", errors.New("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
}
@ -773,4 +778,3 @@ func (v NodeView) IPsAsString() []string {
}
return v.ж.IPsAsString()
}
@ -2,7 +2,6 @@ package types

import (
"fmt"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"net/netip"
"strings"
"testing"
@ -10,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -11,7 +11,7 @@ import (
type PAKError string

func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }

// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
@ -1,6 +1,7 @@
package types

import (
"errors"
"testing"
"time"

@ -109,7 +110,8 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(PAKError)
var httpErr PAKError
ok := errors.As(err, &httpErr)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {
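
A note on the errors.As change above: a plain type assertion only matches the
outermost error, while errors.As walks the Unwrap chain, so the check keeps
working even if a caller wraps the PAKError. A minimal sketch of the difference
(the wrapping call is hypothetical, not from this commit; assumes the "errors"
and "fmt" imports):

wrapped := fmt.Errorf("checking key: %w", PAKError("key expired"))

_, ok := wrapped.(PAKError) // false: the outer type is *fmt.wrapError
var pakErr PAKError
ok = errors.As(wrapped, &pakErr) // true: found inside the unwrap chain
_ = ok
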
@ -249,7 +249,7 @@ func (c *OIDCClaims) Identifier() string {
// - Remove empty path segments
// - For non-URL identifiers, it joins non-empty segments with a single slash
// - Returns empty string for identifiers with only slashes
// - Normalize URL schemes to lowercase
// - Normalize URL schemes to lowercase.
func CleanIdentifier(identifier string) string {
if identifier == "" {
return identifier
@ -273,7 +273,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, part)
}
}


if len(cleanParts) == 0 {
u.Path = ""
} else {
@ -281,6 +281,7 @@ func CleanIdentifier(identifier string) string {
}
// Ensure scheme is lowercase
u.Scheme = strings.ToLower(u.Scheme)

return u.String()
}

@ -297,6 +298,7 @@ func CleanIdentifier(identifier string) string {
if len(cleanParts) == 0 {
return ""
}

return strings.Join(cleanParts, "/")
}
@ -1,4 +1,6 @@
package types

var Version = "dev"
var GitCommitHash = "dev"
var (
Version = "dev"
GitCommitHash = "dev"
)
@ -5,6 +5,7 @@ import (
"fmt"
"net/netip"
"regexp"
"strconv"
"strings"
"unicode"

@ -21,8 +22,10 @@ const (
LabelHostnameLength = 63
)

var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)

var ErrInvalidUserName = errors.New("invalid user name")

@ -141,7 +144,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
@ -205,7 +208,7 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")

return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
return dnsname.ToFQDN(prefix + ".ip6.arpa")
}

var fqdns []dnsname.FQDN
@ -70,7 +70,7 @@ func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sq
"rowsAffected": rowsAffected,
}

if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.SkipErrRecordNotFound) {
if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) {
l.Logger.Error().Err(err).Fields(fields).Msgf("")
return
}
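
The rewritten guard above is a pure De Morgan's-law transformation, so the
logging behaviour is unchanged; with A = errors.Is(err, gorm.ErrRecordNotFound)
and B = l.SkipErrRecordNotFound (names taken from the surrounding code):

// Old: err != nil && !(A && B)
// New: err != nil && (!A || !B)
// Both forms log every error except a record-not-found error
// when skipping record-not-found errors is enabled.
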
@ -58,5 +58,6 @@ var TheInternet = sync.OnceValue(func() *netipx.IPSet {
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))

theInternetSet, _ := internetBuilder.IPSet()

return theInternetSet
})
@ -53,37 +53,37 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
}

type TraceroutePath struct {
// Hop is the current jump in the total traceroute.
Hop int
// Hop is the current jump in the total traceroute.
Hop int

// Hostname is the resolved hostname or IP address identifying the jump
Hostname string
// Hostname is the resolved hostname or IP address identifying the jump
Hostname string

// IP is the IP address of the jump
IP netip.Addr
// IP is the IP address of the jump
IP netip.Addr

// Latencies is a list of the latencies for this jump
Latencies []time.Duration
// Latencies is a list of the latencies for this jump
Latencies []time.Duration
}

type Traceroute struct {
// Hostname is the resolved hostname or IP address identifying the target
Hostname string
// Hostname is the resolved hostname or IP address identifying the target
Hostname string

// IP is the IP address of the target
IP netip.Addr
// IP is the IP address of the target
IP netip.Addr

// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath
// Route is the path taken to reach the target if successful. The list is ordered by the path taken.
Route []TraceroutePath

// Success indicates if the traceroute was successful.
Success bool
// Success indicates if the traceroute was successful.
Success bool

// Err contains an error if the traceroute was not successful.
Err error
// Err contains an error if the traceroute was not successful.
Err error
}

// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct
// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct.
func ParseTraceroute(output string) (Traceroute, error) {
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) < 1 {
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}

// Parse each hop line
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")

for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])
@ -1077,7 +1077,6 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {

func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -1213,7 +1212,6 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {

func TestACLAutogroupMember(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := aclScenario(t,
&policyv2.Policy{
@ -1271,7 +1269,6 @@ func TestACLAutogroupMember(t *testing.T) {

func TestACLAutogroupTagged(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := aclScenario(t,
&policyv2.Policy{
@ -3,12 +3,11 @@ package integration
import (
"fmt"
"net/netip"
"slices"
"strconv"
"testing"
"time"

"slices"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
@ -19,7 +18,6 @@ import (

func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -66,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)

listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)

@ -87,7 +85,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
t.Logf("all clients logged out")

listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)

for _, node := range listNodes {
assertLastSeenSet(t, node)
@ -99,26 +97,48 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
time.Sleep(5 * time.Minute)
}
userMap, err := headscale.MapUsers()
assertNoErr(t, err)

userMap, err := headscale.MapUsers()
assertNoErr(t, err)

for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
// Create auth keys once outside the retry loop
userKeys := make(map[string]string)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
assertNoErr(t, err)
userKeys[userName] = key.GetKey()
}

err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
// Wait for the 2-minute noise dial memory to expire
// The Tailscale commit shows clients remember noise dials for 2 minutes
t.Logf("Waiting 2.5 minutes for Tailscale noise dial memory to expire...")
time.Sleep(2*time.Minute + 30*time.Second)

// Wait for clients to be ready to reconnect over HTTP after HTTPS
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userKeys[userName])
assert.NoError(ct, err, "Client should be able to reconnect over HTTP")
}
}, 6*time.Minute, 30*time.Second)
} else {
userMap, err := headscale.MapUsers()
assertNoErr(t, err)

for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}

err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
if err != nil {
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
}
}
}

listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)

for _, node := range listNodes {
assertLastSeenSet(t, node)
@ -155,18 +175,17 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}

listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
})
}

}

func assertLastSeenSet(t *testing.T, node *v1.Node) {
assert.NotNil(t, node)
assert.NotNil(t, node.LastSeen)
assert.NotNil(t, node.GetLastSeen())
}

// This test will first log in two sets of nodes to two sets of users, then
@ -175,7 +194,6 @@ func assertLastSeenSet(t *testing.T, node *v1.Node) {
// still has nodes, but they are not connected.
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -204,7 +222,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
assertNoErrGetHeadscale(t, err)

listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)

@ -259,7 +277,6 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {

func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

for _, https := range []bool{true, false} {
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
@ -303,7 +320,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
assertNoErrGetHeadscale(t, err)

listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)

@ -325,32 +342,62 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
time.Sleep(5 * time.Minute)
}

userMap, err := headscale.MapUsers()
assertNoErr(t, err)

for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}

// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.Key,
})
userMap, err := headscale.MapUsers()
assertNoErr(t, err)

err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
assert.ErrorContains(t, err, "authkey expired")
// Create and expire auth keys once outside the retry loop
userExpiredKeys := make(map[string]string)
for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
assertNoErr(t, err)

// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.GetKey(),
})
assertNoErr(t, err)
userExpiredKeys[userName] = key.GetKey()
}

// Wait for clients to be ready to reconnect over HTTP after HTTPS
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), userExpiredKeys[userName])
assert.Error(ct, err, "Should get error when using expired key")
assert.Contains(ct, err.Error(), "authkey expired")
}
}, 6*time.Minute, 30*time.Second)
} else {
userMap, err := headscale.MapUsers()
assertNoErr(t, err)

for _, userName := range spec.Users {
key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
if err != nil {
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
}

// Expire the key so it can't be used
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.GetKey(),
})
assertNoErr(t, err)

err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
assert.ErrorContains(t, err, "authkey expired")
}
}
})
}
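
The hunks above show the commit's core pattern: a fixed time.Sleep is replaced
by testify's assert.EventuallyWithT, which re-runs a closure until every
assertion recorded against the *assert.CollectT passes or the overall timeout
expires. A self-contained sketch of the shape (durations and condition are
placeholders; assumes "testing", "time" and
"github.com/stretchr/testify/assert"):

func TestEventuallyPattern(t *testing.T) {
	start := time.Now()
	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		// Failures recorded on ct only fail this attempt; the closure
		// is retried every tick until it passes or waitFor elapses.
		assert.Greater(ct, time.Since(start), 2*time.Second)
	}, 10*time.Second, 500*time.Millisecond) // waitFor, tick
}
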
@ -1,14 +1,12 @@
package integration

import (
"fmt"
"maps"
"net/netip"
"sort"
"testing"
"time"

"maps"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -21,7 +19,6 @@ import (

func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

// Logins to MockOIDC is served by a queue with a strict order,
// if we use more than one node per user, the order of the logins
@ -119,7 +116,6 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
// This test is really flaky.
func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

shortAccessTTL := 5 * time.Minute

@ -174,9 +170,13 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
// of safety reasons) before checking if the clients have logged out.
// The Wait function can't do it itself as it has an upper bound of 1
// min.
time.Sleep(shortAccessTTL + 10*time.Second)

assertTailscaleNodesLogout(t, allClients)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}
}, shortAccessTTL+10*time.Second, 5*time.Second)
}

func TestOIDC024UserCreation(t *testing.T) {
@ -295,9 +295,7 @@ func TestOIDC024UserCreation(t *testing.T) {
spec := ScenarioSpec{
NodesPerUser: 1,
}
for _, user := range tt.cliUsers {
spec.Users = append(spec.Users, user)
}
spec.Users = append(spec.Users, tt.cliUsers...)

for _, user := range tt.oidcUsers {
spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified))
@ -350,7 +348,6 @@ func TestOIDC024UserCreation(t *testing.T) {

func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

// Single user with one node for testing PKCE flow
spec := ScenarioSpec{
@ -402,7 +399,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {

func TestOIDCReloginSameNodeNewUser(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

// Create no nodes and no users
scenario, err := NewScenario(ScenarioSpec{
@ -440,7 +436,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {

listUsers, err := headscale.ListUsers()
assertNoErr(t, err)
assert.Len(t, listUsers, 0)
assert.Empty(t, listUsers)

ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]))
assertNoErr(t, err)
@ -482,7 +478,13 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
err = ts.Logout()
assertNoErr(t, err)

time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)

// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -530,16 +532,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {

// Machine key is the same as the "machine" has not changed,
// but Node key is not as it is a new node
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey())

// Log out user2, and log into user1, no new node should be created,
// the node should now "become" node1 again
err = ts.Logout()
assertNoErr(t, err)

time.Sleep(5 * time.Second)
// Wait for logout to complete and then do second logout
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check that the first logout completed
status, err := ts.Status()
assert.NoError(ct, err)
assert.Equal(ct, "NeedsLogin", status.BackendState)
}, 5*time.Second, 1*time.Second)

// TODO(kradalby): Not sure why we need to logout twice, but it fails and
// logs in immediately after the first logout and I cannot reproduce it
@ -588,24 +596,24 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Validate that the machine we had when we logged in the first time, has the same
// machine key, but a different ID than the newly logged in version of the same
// machine.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id)
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id)
assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey())
assert.Equal(t, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId())
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId())
assert.NotEqual(t, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId())

// Even tho we are logging in again with the same user, the previous key has been expired
// and a new one has been generated. The node entry in the database should be the same
// as the user + machinekey still matches.
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey)
assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey)
assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id)
assert.Equal(t, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey())
assert.NotEqual(t, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey())
assert.Equal(t, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId())

// The "logged back in" machine should have the same machinekey but a different nodekey
// than the version logged in with a different user.
assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey)
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
assert.Equal(t, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey())
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey())
}

func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
@ -623,7 +631,7 @@ func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
return mockoidc.MockUser{
Subject: username,
PreferredUsername: username,
Email: fmt.Sprintf("%s@headscale.net", username),
Email: username + "@headscale.net",
EmailVerified: emailVerified,
}
}
@ -2,9 +2,8 @@ package integration

import (
"net/netip"
"testing"

"slices"
"testing"

"github.com/juanfont/headscale/integration/hsic"
"github.com/samber/lo"
@ -55,7 +54,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {

func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -95,7 +93,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
assertNoErrGetHeadscale(t, err)

listNodes, err := headscale.ListNodes()
assert.Equal(t, len(listNodes), len(allClients))
assert.Len(t, allClients, len(listNodes))
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)

@ -140,7 +138,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))

listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))

for _, client := range allClients {
@ -18,8 +18,8 @@ import (
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
)

func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -30,7 +30,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul

err = json.Unmarshal([]byte(str), result)
if err != nil {
return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str)
return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str)
}

return nil
@ -48,7 +48,6 @@ func sortWithID[T GRPCSortable](a, b T) int {

func TestUserCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"user1", "user2"},
@ -184,7 +183,7 @@ func TestUserCommand(t *testing.T) {
"--identifier=1",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")

var listAfterIDDelete []*v1.User
@ -222,7 +221,7 @@ func TestUserCommand(t *testing.T) {
"--name=newname",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Contains(t, deleteResult, "User destroyed")

var listAfterNameDelete []v1.User
@ -238,12 +237,11 @@ func TestUserCommand(t *testing.T) {
)
assertNoErr(t, err)

require.Len(t, listAfterNameDelete, 0)
require.Empty(t, listAfterNameDelete)
}

func TestPreAuthKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "preauthkeyspace"
count := 3
@ -347,7 +345,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
continue
}

assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"})
assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags())
}

// Test key expiry
@ -386,7 +384,6 @@ func TestPreAuthKeyCommand(t *testing.T) {

func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "pre-auth-key-without-exp-user"
spec := ScenarioSpec{
@ -448,7 +445,6 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {

func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "pre-auth-key-reus-ephm-user"
spec := ScenarioSpec{
@ -524,7 +520,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {

func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user1 := "user1"
user2 := "user2"
@ -575,7 +570,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assertNoErr(t, err)

listNodes, err := headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 1)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())

@ -613,7 +608,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
}

listNodes, err = headscale.ListNodes()
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, listNodes, 2)
assert.Equal(t, user1, listNodes[0].GetUser().GetName())
assert.Equal(t, user2, listNodes[1].GetUser().GetName())
@ -621,7 +616,6 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {

func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

count := 5

@ -653,7 +647,7 @@ func TestApiKeyCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotEmpty(t, apiResult)

keys[idx] = apiResult
@ -672,7 +666,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listedAPIKeys, 5)

@ -728,7 +722,7 @@ func TestApiKeyCommand(t *testing.T) {
listedAPIKeys[idx].GetPrefix(),
},
)
assert.Nil(t, err)
assert.NoError(t, err)

expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
}
@ -744,7 +738,7 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAfterExpireAPIKeys,
)
assert.Nil(t, err)
assert.NoError(t, err)

for index := range listedAfterExpireAPIKeys {
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
@ -770,7 +764,7 @@ func TestApiKeyCommand(t *testing.T) {
"--prefix",
listedAPIKeys[0].GetPrefix(),
})
assert.Nil(t, err)
assert.NoError(t, err)

var listedAPIKeysAfterDelete []v1.ApiKey
err = executeAndUnmarshal(headscale,
@ -783,14 +777,13 @@ func TestApiKeyCommand(t *testing.T) {
},
&listedAPIKeysAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listedAPIKeysAfterDelete, 4)
}

func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"user1"},
@ -811,7 +804,7 @@ func TestNodeTagCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)

for index, regID := range regIDs {
_, err := headscale.Execute(
@ -829,7 +822,7 @@ func TestNodeTagCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

var node v1.Node
err = executeAndUnmarshal(
@ -847,7 +840,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

nodes[index] = &node
}
@ -866,7 +859,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())

@ -894,7 +887,7 @@ func TestNodeTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if node.GetForcedTags() != nil {
@ -905,19 +898,15 @@ func TestNodeTagCommand(t *testing.T) {
}
}
}
assert.Equal(
assert.True(
t,
true,
found,
"should find a node with the tag 'tag:test' in the list of nodes",
)
}


func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

tests := []struct {
name string
@ -1024,7 +1013,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
},
&resultMachines,
)
assert.Nil(t, err)
assert.NoError(t, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
@ -1043,7 +1032,6 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {

func TestNodeCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"node-user", "other-user"},
@ -1067,7 +1055,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)

for index, regID := range regIDs {
_, err := headscale.Execute(
@ -1085,7 +1073,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

var node v1.Node
err = executeAndUnmarshal(
@ -1103,7 +1091,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

nodes[index] = &node
}
@ -1123,7 +1111,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAll, 5)

@ -1144,7 +1132,7 @@ func TestNodeCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
assert.Nil(t, err)
assert.NoError(t, err)

for index, regID := range otherUserRegIDs {
_, err := headscale.Execute(
@ -1162,7 +1150,7 @@ func TestNodeCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

var node v1.Node
err = executeAndUnmarshal(
@ -1180,7 +1168,7 @@ func TestNodeCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

otherUserMachines[index] = &node
}
@ -1200,7 +1188,7 @@ func TestNodeCommand(t *testing.T) {
},
&listAllWithotherUser,
)
assert.Nil(t, err)
assert.NoError(t, err)

// All nodes, nodes + otherUser
assert.Len(t, listAllWithotherUser, 7)
@ -1226,7 +1214,7 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyotherUserMachineUser,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listOnlyotherUserMachineUser, 2)

@ -1258,7 +1246,7 @@ func TestNodeCommand(t *testing.T) {
"--force",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

// Test: list main user after node is deleted
var listOnlyMachineUserAfterDelete []v1.Node
@ -1275,14 +1263,13 @@ func TestNodeCommand(t *testing.T) {
},
&listOnlyMachineUserAfterDelete,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listOnlyMachineUserAfterDelete, 4)
}

func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"node-expire-user"},
@ -1323,7 +1310,7 @@ func TestNodeExpireCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

var node v1.Node
err = executeAndUnmarshal(
@ -1341,7 +1328,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

nodes[index] = &node
}
@ -1360,7 +1347,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAll, 5)

@ -1377,10 +1364,10 @@ func TestNodeExpireCommand(t *testing.T) {
"nodes",
"expire",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
},
)
assert.Nil(t, err)
assert.NoError(t, err)
}

var listAllAfterExpiry []v1.Node
@ -1395,7 +1382,7 @@ func TestNodeExpireCommand(t *testing.T) {
},
&listAllAfterExpiry,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAllAfterExpiry, 5)

@ -1408,7 +1395,6 @@ func TestNodeExpireCommand(t *testing.T) {

func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"node-rename-command"},
@ -1432,7 +1418,7 @@ func TestNodeRenameCommand(t *testing.T) {
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
assert.Nil(t, err)
assert.NoError(t, err)

for index, regID := range regIDs {
_, err := headscale.Execute(
@ -1487,7 +1473,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAll,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAll, 5)

@ -1504,11 +1490,11 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[idx].GetId()),
strconv.FormatUint(listAll[idx].GetId(), 10),
fmt.Sprintf("newnode-%d", idx+1),
},
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Contains(t, res, "Node renamed")
}
@ -1525,7 +1511,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRename,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAllAfterRename, 5)

@ -1542,7 +1528,7 @@ func TestNodeRenameCommand(t *testing.T) {
"nodes",
"rename",
"--identifier",
fmt.Sprintf("%d", listAll[4].GetId()),
strconv.FormatUint(listAll[4].GetId(), 10),
strings.Repeat("t", 64),
},
)
@ -1560,7 +1546,7 @@ func TestNodeRenameCommand(t *testing.T) {
},
&listAllAfterRenameAttempt,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, listAllAfterRenameAttempt, 5)

@ -1573,7 +1559,6 @@ func TestNodeRenameCommand(t *testing.T) {

func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"old-user", "new-user"},
@ -1610,7 +1595,7 @@ func TestNodeMoveCommand(t *testing.T) {
"json",
},
)
assert.Nil(t, err)
assert.NoError(t, err)

var node v1.Node
err = executeAndUnmarshal(
@ -1628,13 +1613,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Equal(t, uint64(1), node.GetId())
assert.Equal(t, "nomad-node", node.GetName())
assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())

nodeID := fmt.Sprintf("%d", node.GetId())
nodeID := strconv.FormatUint(node.GetId(), 10)

err = executeAndUnmarshal(
headscale,
@ -1651,9 +1636,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())

var allNodes []v1.Node
err = executeAndUnmarshal(
@ -1667,13 +1652,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&allNodes,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Len(t, allNodes, 1)

assert.Equal(t, allNodes[0].GetId(), node.GetId())
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())

_, err = headscale.Execute(
[]string{
@ -1693,7 +1678,7 @@ func TestNodeMoveCommand(t *testing.T) {
err,
"user not found",
)
assert.Equal(t, node.GetUser().GetName(), "new-user")
assert.Equal(t, "new-user", node.GetUser().GetName())

err = executeAndUnmarshal(
headscale,
@ -1710,9 +1695,9 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())

err = executeAndUnmarshal(
headscale,
@ -1729,14 +1714,13 @@ func TestNodeMoveCommand(t *testing.T) {
},
&node,
)
assert.Nil(t, err)
assert.NoError(t, err)

assert.Equal(t, node.GetUser().GetName(), "old-user")
assert.Equal(t, "old-user", node.GetUser().GetName())
}

func TestPolicyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
Users: []string{"user1"},
@ -1817,7 +1801,6 @@ func TestPolicyCommand(t *testing.T) {

func TestPolicyBrokenConfigCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -1,7 +1,6 @@
package integration

import (
"context"
"fmt"
"net"
"strconv"
@ -104,7 +103,7 @@ func DERPVerify(
defer c.Close()

var result error
if err := c.Connect(context.Background()); err != nil {
if err := c.Connect(t.Context()); err != nil {
result = fmt.Errorf("client Connect: %w", err)
}
if m, err := c.Recv(); err != nil {
@ -15,7 +15,6 @@ import (

func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -49,7 +48,7 @@ func TestResolveMagicDNS(t *testing.T) {
// It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN()

assert.Equal(t, fmt.Sprintf("%s.headscale.net.", peer.Hostname()), peerFQDN)
assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN)

command := []string{
"tailscale",
@ -85,7 +84,6 @@ func TestResolveMagicDNS(t *testing.T) {

func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -222,12 +220,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
_, err = hs.Execute([]string{"rm", erPath})
assertNoErr(t, err)

time.Sleep(2 * time.Second)

// The same paths should still be available as it is not cleared on delete.
for _, client := range allClients {
assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9")
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"})
assert.NoError(ct, err)
assert.Contains(ct, result, "9.9.9.9")
}
}, 10*time.Second, 1*time.Second)

// Write a new file, the backoff mechanism should make the filewatcher pick it up
// again.
@ -33,26 +33,27 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
}

// GenerateRunID creates a unique run identifier with timestamp and random hash.
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3)
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3).
func GenerateRunID() string {
now := time.Now()
timestamp := now.Format("20060102-150405")


// Add a short random hash to ensure uniqueness
randomHash := util.MustGenerateRandomStringDNSSafe(6)

return fmt.Sprintf("%s-%s", timestamp, randomHash)
}

// ExtractRunIDFromContainerName extracts the run ID from container name.
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH"
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH".
func ExtractRunIDFromContainerName(containerName string) string {
parts := strings.Split(containerName, "-")
if len(parts) >= 3 {
// Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH)
return strings.Join(parts[len(parts)-3:], "-")
}

panic(fmt.Sprintf("unexpected container name format: %s", containerName))

panic("unexpected container name format: " + containerName)
}

// IsRunningInContainer checks if the current process is running inside a Docker container.
@ -62,4 +63,4 @@ func IsRunningInContainer() bool {
// This could be improved with more robust detection if needed
_, err := os.Stat("/.dockerenv")
return err == nil
}
}
@ -30,7 +30,7 @@ func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption {
})
}

// buffer is a goroutine safe bytes.buffer
// buffer is a goroutine safe bytes.buffer.
type buffer struct {
store bytes.Buffer
mutex sync.Mutex
@ -58,8 +58,8 @@ func ExecuteCommand(
env []string,
options ...ExecuteCommandOption,
) (string, string, error) {
var stdout = buffer{}
var stderr = buffer{}
stdout := buffer{}
stderr := buffer{}

execConfig := ExecuteCommandConfig{
timeout: dockerExecuteTimeout,
@ -159,7 +159,6 @@ func New(
},
}


if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
@ -192,7 +191,7 @@ func New(
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "derp")


container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,
@ -2,13 +2,13 @@ package integration

import (
"strings"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"testing"
"time"

"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)

type ClientsSpec struct {
@ -71,9 +71,9 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
NodesPerUser: 1,
Users: []string{"user1", "user2", "user3"},
Networks: map[string][]string{
"usernet1": []string{"user1"},
"usernet2": []string{"user2"},
"usernet3": []string{"user3"},
"usernet1": {"user1"},
"usernet2": {"user2"},
"usernet3": {"user3"},
},
}

@ -106,7 +106,6 @@ func derpServerScenario(
furtherAssertions ...func(*Scenario),
) {
IntegrationSkip(t)
// t.Parallel()

scenario, err := NewScenario(spec)
assertNoErr(t, err)
@ -26,7 +26,6 @@ import (
|
||||
|
||||
func TestPingAllByIP(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
@ -68,7 +67,6 @@ func TestPingAllByIP(t *testing.T) {
|
||||
|
||||
func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
@ -118,7 +116,6 @@ func TestEphemeralInAlternateTimezone(t *testing.T) {
|
||||
|
||||
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
@ -191,7 +188,6 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||
// deleted by accident if they are still online and active.
|
||||
func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
spec := ScenarioSpec{
|
||||
NodesPerUser: len(MustTestVersions),
|
||||
@ -260,18 +256,21 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||
// Wait a bit and bring up the clients again before the expiry
|
||||
// time of the ephemeral nodes.
|
||||
// Nodes should be able to reconnect and work fine.
|
||||
time.Sleep(30 * time.Second)
|
||||
|
||||
for _, client := range allClients {
|
||||
err := client.Up()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to take down client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
}
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
success = pingAllHelper(t, allClients, allAddrs)
|
||||
// Wait for clients to sync and be able to ping each other after reconnection
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assert.NoError(ct, err)
|
||||
|
||||
success = pingAllHelper(t, allClients, allAddrs)
|
||||
assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping")
|
||||
}, 60*time.Second, 2*time.Second)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
// Take down all clients, this should start an expiry timer for each.
@ -284,7 +283,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {

// This time wait for all of the nodes to expire and check that they are no longer
// registered.
time.Sleep(3 * time.Minute)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
assert.NoError(ct, err)
assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName)
}
}, 4*time.Minute, 10*time.Second)

for _, userName := range spec.Users {
nodes, err := headscale.ListNodes(userName)
@ -305,7 +310,6 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {

func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -341,20 +345,6 @@ func TestPingAllByHostname(t *testing.T) {
// nolint:tparallel
func TestTaildrop(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

retry := func(times int, sleepInterval time.Duration, doWork func() error) error {
var err error
for range times {
err = doWork()
if err == nil {
return nil
}
time.Sleep(sleepInterval)
}

return err
}

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -396,40 +386,27 @@ func TestTaildrop(t *testing.T) {
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
err = retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, _, err := client.Execute(curlCommand)
if err != nil {
return err
}
assert.NoError(ct, err)

var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
if err != nil {
return err
}
assert.NoError(ct, err)

if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf(
"client %s does not have all its peers as FileTargets, got %d, want: %d\n%s",
client.Hostname(),
assert.Failf(ct, "client %s does not have all its peers as FileTargets",
"got %d, want: %d\n%s",
len(fts),
len(allClients)-1,
ftStr,
)
}

return err
})
if err != nil {
t.Errorf(
"failed to query localapi for filetarget on %s, err: %s",
client.Hostname(),
err,
)
}
}, 10*time.Second, 1*time.Second)
}

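The file-targets hunk applies the same pattern to a JSON endpoint: fetch and unmarshal inside the callback, and let any failed assertion trigger the next attempt. A rough sketch under the same testify assumption (fetchJSON and the fileTarget struct are hypothetical stand-ins for curling tailscaled's localapi):

package example

import (
	"encoding/json"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

type fileTarget struct {
	Name string `json:"name"`
}

// fetchJSON is a hypothetical stand-in for curling tailscaled's localapi
// file-targets endpoint over the unix socket.
func fetchJSON() (string, error) { return `[{"name":"peer1"}]`, nil }

func TestAllPeersBecomeFileTargets(t *testing.T) {
	wantPeers := 1
	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		raw, err := fetchJSON()
		assert.NoError(ct, err)

		var fts []fileTarget
		// An unmarshal error just fails this attempt; the next tick retries.
		assert.NoError(ct, json.Unmarshal([]byte(raw), &fts))
		assert.Len(ct, fts, wantPeers, "not every peer is a file target yet")
	}, 10*time.Second, 1*time.Second)
}
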
for _, client := range allClients {
@ -454,24 +431,15 @@ func TestTaildrop(t *testing.T) {
fmt.Sprintf("%s:", peerFQDN),
}

err := retry(10, 1*time.Second, func() error {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf(
"Sending file from %s to %s\n",
client.Hostname(),
peer.Hostname(),
)
_, _, err := client.Execute(command)

return err
})
if err != nil {
t.Fatalf(
"failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(),
strings.Join(command, " "),
err,
)
}
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
})
}
}
@ -520,7 +488,6 @@ func TestTaildrop(t *testing.T) {

func TestUpdateHostnameFromClient(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

hostnames := map[string]string{
"1": "user1-host",
@ -603,9 +570,47 @@ func TestUpdateHostnameFromClient(t *testing.T) {
assertNoErr(t, err)
}

time.Sleep(5 * time.Second)
// Verify that the server-side rename is reflected in DNSName while HostName remains unchanged
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Build a map of expected DNSNames by node ID
expectedDNSNames := make(map[string]string)
for _, node := range nodes {
nodeID := strconv.FormatUint(node.GetId(), 10)
expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId())
}

// Verify from each client's perspective
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)

// Check self node
selfID := string(status.Self.ID)
expectedDNS := expectedDNSNames[selfID]
assert.Equal(ct, expectedDNS, status.Self.DNSName,
"Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID)

// HostName should remain as the original client-reported hostname
originalHostname := hostnames[selfID]
assert.Equal(ct, originalHostname, status.Self.HostName,
"Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID)

// Check peers
for _, peer := range status.Peer {
peerID := string(peer.ID)
if expectedDNS, ok := expectedDNSNames[peerID]; ok {
assert.Equal(ct, expectedDNS, peer.DNSName,
"Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname())

// HostName should remain as the original client-reported hostname
originalHostname := hostnames[peerID]
assert.Equal(ct, originalHostname, peer.HostName,
"Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname())
}
}
}
}, 60*time.Second, 2*time.Second)

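The hostname hunk separates two fields that are easy to conflate: DNSName carries the name headscale assigns (the given name), while HostName stays whatever the client itself reported via Hostinfo. A hedged sketch of that check against tailscale's ipnstate.PeerStatus (the helper and its inputs are illustrative, not part of the diff):

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"tailscale.com/ipn/ipnstate"
)

// assertRenamedButHostnameKept checks that a server-side rename is visible in
// DNSName while the client-reported HostName is untouched. expectedDNS and
// originalHost are whatever the test expects for this node.
func assertRenamedButHostnameKept(t *testing.T, ps *ipnstate.PeerStatus, expectedDNS, originalHost string) {
	t.Helper()

	// DNSName carries headscale's given name, e.g. "1-givenname.headscale.net."
	assert.Equal(t, expectedDNS, ps.DNSName, "rename should be reflected in DNSName")
	// HostName is what the tailscale client reported via Hostinfo.
	assert.Equal(t, originalHost, ps.HostName, "HostName should be unchanged by the rename")
}
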
// Verify that the clients can see the new hostname, but no givenName
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
@ -647,7 +652,6 @@ func TestUpdateHostnameFromClient(t *testing.T) {

func TestExpireNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -707,7 +711,23 @@ func TestExpireNode(t *testing.T) {

t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())

time.Sleep(2 * time.Minute)
// Verify that the expired node has been marked in all peers list.
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
for _, client := range allClients {
status, err := client.Status()
assert.NoError(ct, err)

if client.Hostname() != node.GetName() {
// Check if the expired node appears as expired in this client's peer list
for key, peer := range status.Peer {
if key == expiredNodeKey {
assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname())
break
}
}
}
}
}, 3*time.Minute, 10*time.Second)

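Waiting for expiry to propagate has the same shape: keep re-reading status until the expired peer shows Expired. A self-contained sketch using tailscale's ipnstate types (getStatus fakes a client's Status() call so the example runs standalone):

package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"tailscale.com/ipn/ipnstate"
	"tailscale.com/types/key"
)

var expiredKey = key.NewNode().Public()

// getStatus fakes client.Status() so the example runs standalone; in the
// real test the status comes from each tailscale container.
func getStatus() (*ipnstate.Status, error) {
	return &ipnstate.Status{
		Peer: map[key.NodePublic]*ipnstate.PeerStatus{
			expiredKey: {Expired: true},
		},
	}, nil
}

func TestPeerEventuallyExpired(t *testing.T) {
	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		status, err := getStatus()
		assert.NoError(ct, err)

		peer, ok := status.Peer[expiredKey]
		// Both must hold within a single pass for the poll to stop.
		assert.True(ct, ok, "expired node should still appear in the peer list")
		if ok {
			assert.True(ct, peer.Expired, "peer should be marked expired")
		}
	}, 3*time.Minute, 10*time.Second)
}
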
now := time.Now()

@ -774,7 +794,6 @@ func TestExpireNode(t *testing.T) {

func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -890,7 +909,6 @@ func TestNodeOnlineStatus(t *testing.T) {
// five times ensuring they are able to re-establish connectivity.
func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
@ -944,8 +962,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}

time.Sleep(5 * time.Second)

for _, client := range allClients {
c := client
wg.Go(func() error {
@ -958,10 +974,14 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
t.Fatalf("failed to take down all nodes: %s", err)
}

time.Sleep(5 * time.Second)
// Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)

err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success := pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)

success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
@ -970,7 +990,6 @@ func TestPingAllByIPManyUpDown(t *testing.T) {

func Test2118DeletingOnlineNodePanics(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -1042,10 +1061,24 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
)
require.NoError(t, err)

time.Sleep(2 * time.Second)

// Ensure that the node has been deleted, this did not occur due to a panic.
var nodeListAfter []v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&nodeListAfter,
)
assert.NoError(ct, err)
assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list")
}, 10*time.Second, 1*time.Second)

err = executeAndUnmarshal(
headscale,
[]string{

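The Test2118 hunk polls a CLI's JSON output instead of sleeping before a single check; executeAndUnmarshal in the diff wraps the same idea for commands run inside the container. A sketch of that shape with plain os/exec (the headscale invocation is illustrative and assumes the binary is on PATH):

package example

import (
	"encoding/json"
	"os/exec"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

type node struct {
	ID   uint64 `json:"id"`
	Name string `json:"name"`
}

func TestNodeDeletedFromCLIList(t *testing.T) {
	var nodes []node
	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		// Re-run the command on every attempt so each pass sees fresh output.
		out, err := exec.Command("headscale", "nodes", "list", "--output", "json").Output()
		assert.NoError(ct, err)
		assert.NoError(ct, json.Unmarshal(out, &nodes))
		assert.Len(ct, nodes, 1, "exactly one node should remain after the delete")
	}, 10*time.Second, 1*time.Second)
}
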
@ -191,7 +191,7 @@ func WithPostgres() Option {
}
}

// WithPolicy sets the policy mode for headscale
// WithPolicy sets the policy mode for headscale.
func WithPolicyMode(mode types.PolicyMode) Option {
return func(hsic *HeadscaleInContainer) {
hsic.policyMode = mode
@ -279,7 +279,7 @@ func New(
return nil, err
}

hostname := fmt.Sprintf("hs-%s", hash)
hostname := "hs-" + hash

hsic := &HeadscaleInContainer{
hostname: hostname,
@ -308,14 +308,14 @@ func New(

if hsic.postgres {
hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres"
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = fmt.Sprintf("postgres-%s", hash)
hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash
hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale"
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")

pgRunOptions := &dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash),
Name: "postgres-" + hash,
Repository: "postgres",
Tag: "latest",
Networks: networks,
@ -328,7 +328,7 @@ func New(

// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres")

pg, err := pool.RunWithOptions(pgRunOptions)
if err != nil {
return nil, fmt.Errorf("starting postgres container: %w", err)
@ -373,7 +373,6 @@ func New(
Env: env,
}

if len(hsic.hostPortBindings) > 0 {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
for port, hostPorts := range hsic.hostPortBindings {
@ -396,7 +395,7 @@ func New(

// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")

container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
@ -566,7 +565,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {

// extractTarToDirectory extracts a tar archive to a directory.
func extractTarToDirectory(tarData []byte, targetDir string) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}

@ -624,6 +623,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
}

targetDir := path.Join(savePath, t.hostname+"-pprof")

return extractTarToDirectory(tarFile, targetDir)
}

@ -634,6 +634,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
}

targetDir := path.Join(savePath, t.hostname+"-mapresponses")

return extractTarToDirectory(tarFile, targetDir)
}

@ -672,17 +673,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
if err != nil {
return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err)
}

if strings.TrimSpace(schemaCheck) == "" {
return fmt.Errorf("database file exists but has no schema (empty database)")
return errors.New("database file exists but has no schema (empty database)")
}

// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
log.Printf("Database schema preview:\n%s", schemaPreview)

tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
@ -727,7 +727,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
}
}

return fmt.Errorf("no regular file found in database tar archive")
return errors.New("no regular file found in database tar archive")
}

// Execute runs a command inside the Headscale container and returns the
@ -756,13 +756,13 @@ func (t *HeadscaleInContainer) Execute(

// GetPort returns the docker container port as a string.
func (t *HeadscaleInContainer) GetPort() string {
return fmt.Sprintf("%d", t.port)
return strconv.Itoa(t.port)
}

// GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer
// instance.
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
return fmt.Sprintf("%s/health", t.GetEndpoint())
return t.GetEndpoint() + "/health"
}

// GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer.
@ -772,10 +772,10 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
t.port)

if t.hasTLS() {
return fmt.Sprintf("https://%s", hostEndpoint)
return "https://" + hostEndpoint
}

return fmt.Sprintf("http://%s", hostEndpoint)
return "http://" + hostEndpoint
}

// GetCert returns the public certificate of the HeadscaleInContainer.
@ -910,6 +910,7 @@ func (t *HeadscaleInContainer) ListNodes(
}

ret = append(ret, nodes...)

return nil
}

@ -932,6 +933,7 @@ func (t *HeadscaleInContainer) ListNodes(
sort.Slice(ret, func(i, j int) bool {
return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1
})

return ret, nil
}

@ -943,10 +945,10 @@ func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) {

var userMap map[string][]*v1.Node
for _, node := range nodes {
if _, ok := userMap[node.User.Name]; !ok {
mak.Set(&userMap, node.User.Name, []*v1.Node{node})
if _, ok := userMap[node.GetUser().GetName()]; !ok {
mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node})
} else {
userMap[node.User.Name] = append(userMap[node.User.Name], node)
userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node)
}
}

@ -999,7 +1001,7 @@ func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) {

var userMap map[string]*v1.User
for _, user := range users {
mak.Set(&userMap, user.Name, user)
mak.Set(&userMap, user.GetName(), user)
}

return userMap, nil
@ -1095,7 +1097,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
case 1:
return pids[0], nil
default:
return 0, fmt.Errorf("multiple headscale processes running")
return 0, errors.New("multiple headscale processes running")
}
}

@ -1121,7 +1123,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
"headscale", "nodes", "approve-routes",
"--output", "json",
"--identifier", strconv.FormatUint(id, 10),
fmt.Sprintf("--routes=%s", strings.Join(util.PrefixesToString(routes), ",")),
"--routes=" + strings.Join(util.PrefixesToString(routes), ","),
}

result, _, err := dockertestutil.ExecuteCommand(

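Most of the hsic.go churn is mechanical lint cleanup of the kind golangci-lint's perfsprint and style checks push toward: plain concatenation where fmt.Sprintf only glued strings together, strconv for number formatting, 0o-prefixed octal literals, and errors.New for constant error messages. A small sketch of the before/after flavor (the function and its values are illustrative):

package example

import (
	"errors"
	"fmt"
	"os"
	"strconv"
)

func endpoint(host string, port int, tls bool) (string, error) {
	if host == "" {
		// errors.New for a constant message; fmt.Errorf is only needed for
		// formatting verbs or error wrapping with %w.
		return "", errors.New("host must not be empty")
	}

	addr := host + ":" + strconv.Itoa(port) // was: fmt.Sprintf("%s:%d", host, port)
	if tls {
		return "https://" + addr, nil // was: fmt.Sprintf("https://%s", addr)
	}

	return "http://" + addr, nil
}

func main() {
	// 0o755 is the modern octal literal spelling of 0755.
	if err := os.MkdirAll("/tmp/example-dir", 0o755); err != nil {
		panic(err)
	}
	fmt.Println(endpoint("hs-abc123", 8080, true))
}
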
@ -4,13 +4,12 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"sort"
"strings"
"testing"
"time"

"slices"

cmpdiff "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -37,7 +36,6 @@ var allPorts = filter.PortRange{First: 0, Last: 0xffff}
// routes.
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 3,
@ -182,11 +180,12 @@ func TestEnablingRoutes(t *testing.T) {
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]

if peerStatus.ID == "1" {
switch peerStatus.ID {
case "1":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else if peerStatus.ID == "2" {
case "2":
requirePeerSubnetRoutes(t, peerStatus, nil)
} else {
default:
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")})
}
}
@ -195,7 +194,6 @@ func TestEnablingRoutes(t *testing.T) {

func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 3,
@ -779,7 +777,6 @@ func TestHASubnetRouterFailover(t *testing.T) {
// https://github.com/juanfont/headscale/issues/1604
func TestSubnetRouteACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "user4"

@ -1003,7 +1000,6 @@ func TestSubnetRouteACL(t *testing.T) {
// set during login instead of set.
func TestEnablingExitRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

user := "user2"

@ -1097,7 +1093,6 @@ func TestEnablingExitRoutes(t *testing.T) {
// subnet router is working as expected.
func TestSubnetRouterMultiNetwork(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -1177,7 +1172,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {

// Enable route
_, err = headscale.ApproveRoutes(
nodes[0].Id,
nodes[0].GetId(),
[]netip.Prefix{*pref},
)
require.NoError(t, err)
@ -1224,7 +1219,6 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {

func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

spec := ScenarioSpec{
NodesPerUser: 1,
@ -1300,7 +1294,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
}

// Enable route
_, err = headscale.ApproveRoutes(nodes[0].Id, []netip.Prefix{tsaddr.AllIPv4()})
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
require.NoError(t, err)

time.Sleep(5 * time.Second)
@ -1719,7 +1713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
assertNoErr(t, err)

err = routerUsernet1.Login(headscale.GetEndpoint(), pak.Key)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
assertNoErr(t, err)
}
// extra creation end.
@ -2065,7 +2059,6 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

// Use router and node users for better clarity
routerUser := "router"
@ -2090,7 +2083,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
defer scenario.ShutdownAssertNoPanics(t)

// Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24)
aclPolicyStr := fmt.Sprintf(`{
aclPolicyStr := `{
"hosts": {
"router": "100.64.0.1/32",
"node": "100.64.0.2/32"
@ -2115,7 +2108,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
]
}
]
}`)
}`

route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)

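Two recurring cleanups in route_test.go: an if/else-if chain over one value becomes a switch, and fields of generated protobuf messages are read through their nil-safe getters (nodes[0].GetId(), pak.GetKey()) rather than direct field access. A compact sketch of both; the node type here only mimics the shape of a generated message:

package example

import "fmt"

// node mimics the shape of a generated protobuf message. Get* methods on
// generated types are nil-safe, which is why the diff prefers them over
// direct field access.
type node struct{ id uint64 }

func (n *node) GetId() uint64 {
	if n == nil {
		return 0 // safe on a nil receiver, unlike reading n.id directly
	}
	return n.id
}

// routesForPeer mirrors the if/else-if chain rewritten as a switch.
func routesForPeer(peerID string) []string {
	switch peerID {
	case "1":
		return nil // plain peer, no subnet routes expected
	case "2":
		return nil
	default:
		return []string{"10.0.2.0/24"}
	}
}

func main() {
	var n *node // nil on purpose: GetId still returns the zero value
	fmt.Println(n.GetId(), routesForPeer("3"))
}
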
@ -123,7 +123,7 @@ type ScenarioSpec struct {
// NodesPerUser is how many nodes should be attached to each user.
NodesPerUser int

// Networks, if set, is the seperate Docker networks that should be
// Networks, if set, is the separate Docker networks that should be
// created and a list of the users that should be placed in those networks.
// If not set, a single network will be created and all users+nodes will be
// added there.
@ -1077,7 +1077,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse

hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)

hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
hostname := "hs-oidcmock-" + hash

usersJSON, err := json.Marshal(users)
if err != nil {
@ -1093,16 +1093,15 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
},
Networks: s.Networks(),
Env: []string{
fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname),
"MOCKOIDC_ADDR=" + hostname,
fmt.Sprintf("MOCKOIDC_PORT=%d", port),
"MOCKOIDC_CLIENT_ID=superclient",
"MOCKOIDC_CLIENT_SECRET=supersecret",
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
"MOCKOIDC_ACCESS_TTL=" + accessTTL.String(),
"MOCKOIDC_USERS=" + string(usersJSON),
},
}

headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
@ -1117,7 +1116,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse

// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc")

if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
@ -1184,7 +1183,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {

hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)

hostname := fmt.Sprintf("hs-webservice-%s", hash)
hostname := "hs-webservice-" + hash

network, ok := s.networks[s.prefixedNetworkName(networkName)]
if !ok {

@ -28,7 +28,6 @@ func IntegrationSkip(t *testing.T) {
// nolint:tparallel
func TestHeadscale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

var err error

@ -75,7 +74,6 @@ func TestHeadscale(t *testing.T) {
// nolint:tparallel
func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

var err error

@ -22,35 +22,6 @@ func isSSHNoAccessStdError(stderr string) bool {
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
}

var retry = func(times int, sleepInterval time.Duration,
doWork func() (string, string, error),
) (string, string, error) {
var result string
var stderr string
var err error

for range times {
tempResult, tempStderr, err := doWork()

result += tempResult
stderr += tempStderr

if err == nil {
return result, stderr, nil
}

// If we get a permission denied error, we can fail immediately
// since that is something we won-t recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return result, stderr, err
}

time.Sleep(sleepInterval)
}

return result, stderr, err
}

func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario {
t.Helper()

@ -92,7 +63,6 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce

func TestSSHOneUserToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := sshScenario(t,
&policyv2.Policy{
@ -160,7 +130,6 @@ func TestSSHOneUserToAll(t *testing.T) {

func TestSSHMultipleUsersAllToAll(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := sshScenario(t,
&policyv2.Policy{
@ -216,7 +185,6 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {

func TestSSHNoSSHConfigured(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := sshScenario(t,
&policyv2.Policy{
@ -261,7 +229,6 @@ func TestSSHNoSSHConfigured(t *testing.T) {

func TestSSHIsBlockedInACL(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := sshScenario(t,
&policyv2.Policy{
@ -313,7 +280,6 @@ func TestSSHIsBlockedInACL(t *testing.T) {

func TestSSHUserOnlyIsolation(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

scenario := sshScenario(t,
&policyv2.Policy{
@ -404,6 +370,14 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}

func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, true)
}

func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
return doSSHWithRetry(t, client, peer, false)
}

func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) {
t.Helper()

peerFQDN, _ := peer.FQDN()
@ -417,9 +391,29 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string,
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
log.Printf("Command: %s", strings.Join(command, " "))

return retry(10, 1*time.Second, func() (string, string, error) {
return client.Execute(command)
})
var result, stderr string
var err error

if retry {
// Use assert.EventuallyWithT to retry SSH connections for success cases
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, stderr, err = client.Execute(command)

// If we get a permission denied error, we can fail immediately
// since that is something we won't recover from by retrying.
if err != nil && isSSHNoAccessStdError(stderr) {
return // Don't retry permission denied errors
}

// For all other errors, assert no error to trigger retry
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
} else {
// For failure cases, just execute once
result, stderr, err = client.Execute(command)
}

return result, stderr, err
}

func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) {
@ -434,7 +428,7 @@ func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClien
func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()

result, stderr, err := doSSH(t, client, peer)
result, stderr, err := doSSHWithoutRetry(t, client, peer)

assert.Empty(t, result)

@ -444,7 +438,7 @@ func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer Tailsc
func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) {
t.Helper()

result, stderr, _ := doSSH(t, client, peer)
result, stderr, _ := doSSHWithoutRetry(t, client, peer)

assert.Empty(t, result)

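doSSHWithRetry encodes a useful subtlety: returning from the EventuallyWithT callback without recording a failed assertion counts as success and stops the polling, so a non-retryable error like permission denied can bail out early while transient errors keep retrying. A stripped-down sketch of that control flow (runCommand and isFatal are hypothetical stand-ins for client.Execute and isSSHNoAccessStdError):

package example

import (
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

var errDenied = errors.New("permission denied")

// runCommand and isFatal are hypothetical stand-ins for client.Execute and
// isSSHNoAccessStdError.
func runCommand() error { return errDenied }

func isFatal(err error) bool { return errors.Is(err, errDenied) }

func TestRetryWithEarlyExit(t *testing.T) {
	var lastErr error
	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		lastErr = runCommand()
		if lastErr != nil && isFatal(lastErr) {
			// No failed assertion recorded: this attempt counts as success,
			// polling stops, and the caller inspects lastErr itself.
			return
		}
		assert.NoError(ct, lastErr) // transient errors trigger the next tick
	}, 10*time.Second, 1*time.Second)

	if lastErr != nil {
		t.Logf("command ended with a non-retryable error: %v", lastErr)
	}
}
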
@ -251,7 +251,6 @@ func New(
Env: []string{},
}

if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
@ -463,7 +462,7 @@ func (t *TailscaleInContainer) buildLoginCommand(

if len(t.withTags) > 0 {
command = append(command,
fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")),
"--advertise-tags="+strings.Join(t.withTags, ","),
)
}

@ -685,7 +684,7 @@ func (t *TailscaleInContainer) MustID() types.NodeID {
// Panics if version is lower than minimum.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
panic("tsic.Netmap() called with unsupported version: " + t.version)
}

command := []string{
@ -1026,7 +1025,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err
"tailscale", "ping",
fmt.Sprintf("--timeout=%s", args.timeout),
fmt.Sprintf("--c=%d", args.count),
fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)),
"--until-direct=" + strconv.FormatBool(args.direct),
}

command = append(command, hostnameOrIP)
@ -1131,11 +1130,11 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
command := []string{
"curl",
"--silent",
"--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())),
"--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())),
"--retry", fmt.Sprintf("%d", args.retry),
"--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())),
"--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())),
"--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())),
"--max-time", strconv.Itoa(int(args.maxTime.Seconds())),
"--retry", strconv.Itoa(args.retry),
"--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())),
"--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())),
url,
}

@ -1230,7 +1229,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
}

if out.Len() == 0 {
return nil, fmt.Errorf("file is empty")
return nil, errors.New("file is empty")
}

return out.Bytes(), nil
@ -1259,5 +1258,6 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err = json.Unmarshal(currentProfile, &p); err != nil {
return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
}

return &p.Persist.PrivateNodeKey, nil
}

@ -3,7 +3,6 @@ package integration
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/netip"
@ -267,7 +266,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {

// This isn't really relevant for Self as it won't be in its own socket/wireguard.
// assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in in wireguard engine", client.Hostname())
// assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname())

for _, peer := range status.Peer {
assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname())
@ -311,7 +310,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()

_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
_, err := backoff.Retry(t.Context(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
@ -492,6 +491,7 @@ func groupApprover(name string) policyv2.AutoApprover {
func tagApprover(name string) policyv2.AutoApprover {
return ptr.To(policyv2.Tag(name))
}

//
// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus
// // if there is a peer with the given hostname. If no peer is found, nil is returned.
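The last hunk swaps context.Background() for t.Context(), added in Go 1.24 and cancelled automatically when the test finishes, so the retry loop needs no hand-rolled deadline. A sketch assuming cenkalti/backoff/v5, whose generic Retry takes the context first, matching the call in the diff:

package example

import (
	"errors"
	"testing"

	"github.com/cenkalti/backoff/v5"
)

func TestCommandEventuallySucceeds(t *testing.T) {
	attempts := 0
	// t.Context() is cancelled when the test ends, bounding the retries.
	out, err := backoff.Retry(t.Context(), func() (string, error) {
		attempts++
		if attempts < 3 {
			return "", errors.New("not ready yet") // retried with backoff
		}
		return "ok", nil // stand-in for a successful container command
	})
	if err != nil {
		t.Fatalf("command never succeeded: %v", err)
	}
	t.Logf("succeeded after %d attempts with output %q", attempts, out)
}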