From 8e25f7f9dd12421a805f82f09676b592a39c61b9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 27 Aug 2025 17:09:13 +0200 Subject: [PATCH] bunch of qol (#2748) --- Dockerfile.tailscale-HEAD | 2 +- cmd/hi/tar_utils.go | 7 +- hscontrol/debug.go | 23 ++++++ hscontrol/mapper/batcher.go | 1 + hscontrol/mapper/batcher_lockfree.go | 4 + hscontrol/mapper/builder.go | 9 +- hscontrol/mapper/builder_test.go | 118 +++++++++++++-------------- hscontrol/mapper/mapper.go | 91 +++++++++++++++------ integration/control.go | 3 + integration/general_test.go | 80 +++++++++++++++++- integration/hsic/hsic.go | 64 ++++++++++++++- 11 files changed, 307 insertions(+), 95 deletions(-) diff --git a/Dockerfile.tailscale-HEAD b/Dockerfile.tailscale-HEAD index 0ee93eb4..43e68992 100644 --- a/Dockerfile.tailscale-HEAD +++ b/Dockerfile.tailscale-HEAD @@ -4,7 +4,7 @@ # This Dockerfile is more or less lifted from tailscale/tailscale # to ensure a similar build process when testing the HEAD of tailscale. -FROM golang:1.24-alpine AS build-env +FROM golang:1.25-alpine AS build-env WORKDIR /go/src diff --git a/cmd/hi/tar_utils.go b/cmd/hi/tar_utils.go index 060b3cf4..f0e1e86b 100644 --- a/cmd/hi/tar_utils.go +++ b/cmd/hi/tar_utils.go @@ -68,7 +68,7 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error { continue // Skip potentially dangerous paths } - targetPath := filepath.Join(targetDir, filepath.Base(cleanName)) + targetPath := filepath.Join(targetDir, cleanName) switch header.Typeflag { case tar.TypeDir: @@ -77,6 +77,11 @@ func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) } case tar.TypeReg: + // Ensure parent directories exist + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) + } + // Create file outFile, err := os.Create(targetPath) if err != nil { diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 481ce589..60676a1d 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -121,6 +121,29 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Write([]byte(h.state.PolicyDebugString())) })) + debug.Handle("mapresponses", "Map responses for all nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + res, err := h.mapBatcher.DebugMapResponses() + if err != nil { + httpError(w, err) + return + } + + if res == nil { + w.WriteHeader(http.StatusOK) + w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return + } + + resJSON, err := json.MarshalIndent(res, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resJSON) + })) + err := statsviz.Register(debugMux) if err == nil { debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)") diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 21b2209f..bb69eac2 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -24,6 +24,7 @@ type Batcher interface { ConnectedMap() *xsync.Map[types.NodeID, bool] AddWork(c change.ChangeSet) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) + DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) } func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher { diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index aeafa001..e733e29a 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -489,3 +489,7 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error { nc.updateCount.Add(1) return nil } + +func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + return b.mapper.debugMapResponses() +} diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index 111724bc..dfe9d68d 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -237,7 +237,6 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) // WithPeersRemoved adds removed peer IDs func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { - var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) @@ -247,12 +246,16 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe } // Build finalizes the response and returns marshaled bytes -func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) { +func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) } if debugDumpMapResponsePath != "" { - writeDebugMapResponse(b.resp, b.nodeID) + node, err := b.mapper.state.GetNodeByID(b.nodeID) + if err != nil { + return nil, err + } + writeDebugMapResponse(b.resp, node) } return b.resp, nil diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go index c8ff59ec..978b2c0e 100644 --- a/hscontrol/mapper/builder_test.go +++ b/hscontrol/mapper/builder_test.go @@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) { Enabled: true, }, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID) - + // Test basic builder creation assert.NotNil(t, builder) assert.Equal(t, nodeID, builder.nodeID) @@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) capVer := tailcfg.CapabilityVersion(42) - + builder := m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer) - + assert.Equal(t, capVer, builder.capVer) assert.False(t, builder.hasErrors()) } @@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) { ServerURL: "https://test.example.com", BaseDomain: domain, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithDomain() - + assert.Equal(t, domain, builder.resp.Domain) assert.False(t, builder.hasErrors()) } @@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithCollectServicesDisabled() - + value, isSet := builder.resp.CollectServices.Get() assert.True(t, isSet) assert.False(t, value) @@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { tests := []struct { - name string + name string logTailEnabled bool - expected bool + expected bool }{ { - name: "LogTail enabled", + name: "LogTail enabled", logTailEnabled: true, - expected: false, // DisableLogTail should be false when LogTail is enabled + expected: false, // DisableLogTail should be false when LogTail is enabled }, { - name: "LogTail disabled", + name: "LogTail disabled", logTailEnabled: false, - expected: true, // DisableLogTail should be true when LogTail is disabled + expected: true, // DisableLogTail should be true when LogTail is disabled }, } - + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &types.Config{ @@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithDebugConfig() - + require.NotNil(t, builder.resp.Debug) assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail) assert.False(t, builder.hasErrors()) @@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) changes := []*tailcfg.PeerChange{ { - NodeID: 123, + NodeID: 123, DERPRegion: 1, }, { - NodeID: 456, + NodeID: 456, DERPRegion: 2, }, } - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch(changes) - + assert.Equal(t, changes, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) removedID1 := types.NodeID(123) removedID2 := types.NodeID(456) - + builder := m.NewMapResponseBuilder(nodeID). WithPeersRemoved(removedID1, removedID2) - + expected := []tailcfg.NodeID{ removedID1.NodeID(), removedID2.NodeID(), @@ -197,25 +197,25 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + // Simulate an error in the builder builder := m.NewMapResponseBuilder(nodeID) builder.addError(assert.AnError) - + // All subsequent calls should continue to work and accumulate errors result := builder. WithDomain(). WithCollectServicesDisabled(). WithDebugConfig() - + assert.True(t, result.hasErrors()) assert.Len(t, result.errs, 1) assert.Equal(t, assert.AnError, result.errs[0]) - + // Build should return the error - data, err := result.Build("none") + data, err := result.Build() assert.Nil(t, data) assert.Error(t, err) } @@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) { Enabled: false, }, } - + mockState := &state.State{} m := &mapper{ cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) capVer := tailcfg.CapabilityVersion(99) - + builder := m.NewMapResponseBuilder(nodeID). WithCapabilityVersion(capVer). WithDomain(). WithCollectServicesDisabled(). WithDebugConfig() - + // Verify all fields are set correctly assert.Equal(t, capVer, builder.capVer) assert.Equal(t, domain, builder.resp.Domain) @@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) removedID1 := types.NodeID(100) removedID2 := types.NodeID(200) - + // Test calling WithPeersRemoved multiple times builder := m.NewMapResponseBuilder(nodeID). WithPeersRemoved(removedID1). WithPeersRemoved(removedID2) - + // Second call should overwrite the first expected := []tailcfg.NodeID{removedID2.NodeID()} assert.Equal(t, expected, builder.resp.PeersRemoved) @@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch([]*tailcfg.PeerChange{}) - + assert.Empty(t, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + builder := m.NewMapResponseBuilder(nodeID). WithPeerChangedPatch(nil) - + assert.Nil(t, builder.resp.PeersChangedPatch) assert.False(t, builder.hasErrors()) } @@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) { cfg: cfg, state: mockState, } - + nodeID := types.NodeID(1) - + // Create a builder and add multiple errors builder := m.NewMapResponseBuilder(nodeID) builder.addError(assert.AnError) builder.addError(assert.AnError) builder.addError(nil) // This should be ignored - + // All subsequent calls should continue to work result := builder. WithDomain(). WithCollectServicesDisabled() - + assert.True(t, result.hasErrors()) assert.Len(t, result.errs, 2) // nil error should be ignored - + // Build should return a multierr - data, err := result.Build("none") + data, err := result.Build() assert.Nil(t, data) assert.Error(t, err) - + // The error should contain information about multiple errors assert.Contains(t, err.Error(), "multiple errors") -} \ No newline at end of file +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 43764457..59c92e24 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -9,6 +9,7 @@ import ( "os" "path" "slices" + "strconv" "strings" "time" @@ -154,7 +155,7 @@ func (m *mapper) fullMapResponse( WithUserProfiles(peers). WithPacketFilters(). WithPeers(peers). - Build(messages...) + Build() } func (m *mapper) derpMapResponse( @@ -207,36 +208,15 @@ func (m *mapper) peerRemovedResponse( func writeDebugMapResponse( resp *tailcfg.MapResponse, - nodeID types.NodeID, - messages ...string, + node *types.Node, ) { - data := map[string]any{ - "Messages": messages, - "MapResponse": resp, - } - - responseType := "keepalive" - - switch { - case len(resp.Peers) > 0: - responseType = "full" - case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive: - responseType = "self" - case len(resp.PeersChanged) > 0: - responseType = "changed" - case len(resp.PeersChangedPatch) > 0: - responseType = "patch" - case len(resp.PeersRemoved) > 0: - responseType = "removed" - } - - body, err := json.MarshalIndent(data, "", " ") + body, err := json.MarshalIndent(resp, "", " ") if err != nil { panic(err) } perms := fs.FileMode(debugMapResponsePerm) - mPath := path.Join(debugDumpMapResponsePath, nodeID.String()) + mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", node.ID)) err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -246,7 +226,7 @@ func writeDebugMapResponse( mapResponsePath := path.Join( mPath, - fmt.Sprintf("%s-%s.json", now, responseType), + fmt.Sprintf("%s.json", now), ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) @@ -279,3 +259,62 @@ func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types. // netip.Prefixes that are allowed for that node. It is used to filter routes // from the primary route manager to the node. type routeFilterFunc func(id types.NodeID) []netip.Prefix + +func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + if debugDumpMapResponsePath == "" { + return nil, nil + } + + nodes, err := os.ReadDir(debugDumpMapResponsePath) + if err != nil { + return nil, err + } + + result := make(map[types.NodeID][]tailcfg.MapResponse) + for _, node := range nodes { + if !node.IsDir() { + continue + } + + nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64) + if err != nil { + log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name()) + continue + } + + nodeID := types.NodeID(nodeIDu) + + files, err := os.ReadDir(path.Join(debugDumpMapResponsePath, node.Name())) + if err != nil { + log.Error().Err(err).Msgf("Reading dir %s", node.Name()) + continue + } + + slices.SortStableFunc(files, func(a, b fs.DirEntry) int { + return strings.Compare(a.Name(), b.Name()) + }) + + for _, file := range files { + if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") { + continue + } + + body, err := os.ReadFile(path.Join(debugDumpMapResponsePath, node.Name(), file.Name())) + if err != nil { + log.Error().Err(err).Msgf("Reading file %s", file.Name()) + continue + } + + var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) + if err != nil { + log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name()) + continue + } + + result[nodeID] = append(result[nodeID], resp) + } + } + + return result, nil +} diff --git a/integration/control.go b/integration/control.go index df1d5d13..e3cb17bd 100644 --- a/integration/control.go +++ b/integration/control.go @@ -5,7 +5,9 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" "github.com/ory/dockertest/v3" + "tailscale.com/tailcfg" ) type ControlServer interface { @@ -29,4 +31,5 @@ type ControlServer interface { GetCert() []byte GetHostname() string SetPolicy(*policyv2.Policy) error + GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) } diff --git a/integration/general_test.go b/integration/general_test.go index 4e250854..4bf36567 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -55,6 +56,17 @@ func TestPingAllByIP(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) + hs, err := scenario.Headscale() + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + all, err := hs.GetAllMapReponses() + assert.NoError(ct, err) + + onlineMap := buildExpectedOnlineMap(all) + assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap) + }, 30*time.Second, 2*time.Second) + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { @@ -940,6 +952,9 @@ func TestPingAllByIPManyUpDown(t *testing.T) { ) assertNoErrHeadscaleEnv(t, err) + hs, err := scenario.Headscale() + require.NoError(t, err) + allClients, err := scenario.ListTailscaleClients() assertNoErrListClients(t, err) @@ -961,7 +976,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { wg, _ := errgroup.WithContext(context.Background()) for run := range 3 { - t.Logf("Starting DownUpPing run %d", run+1) + t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format("2006-01-02T15-04-05.999999999")) for _, client := range allClients { c := client @@ -974,6 +989,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { if err := wg.Wait(); err != nil { t.Fatalf("failed to take down all nodes: %s", err) } + t.Logf("All nodes taken down at %s", time.Now().Format("2006-01-02T15-04-05.999999999")) for _, client := range allClients { c := client @@ -984,13 +1000,24 @@ func TestPingAllByIPManyUpDown(t *testing.T) { } if err := wg.Wait(); err != nil { - t.Fatalf("failed to take down all nodes: %s", err) + t.Fatalf("failed to bring up all nodes: %s", err) } + t.Logf("All nodes brought up at %s", time.Now().Format("2006-01-02T15-04-05.999999999")) // Wait for sync and successful pings after nodes come back up err = scenario.WaitForTailscaleSync() assert.NoError(t, err) + t.Logf("All nodes synced up %s", time.Now().Format("2006-01-02T15-04-05.999999999")) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + all, err := hs.GetAllMapReponses() + assert.NoError(ct, err) + + onlineMap := buildExpectedOnlineMap(all) + assertExpectedOnlineMapAllOnline(ct, len(allClients)-1, onlineMap) + }, 60*time.Second, 2*time.Second) + success := pingAllHelper(t, allClients, allAddrs) assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps)) } @@ -1103,3 +1130,52 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { assert.True(t, nodeListAfter[0].GetOnline()) assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) } + +func buildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool { + res := make(map[types.NodeID]map[types.NodeID]bool) + for nid, mrs := range all { + res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { + for _, peer := range mr.Peers { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChanged { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChangedPatch { + if peer.Online != nil { + res[nid][types.NodeID(peer.NodeID)] = *peer.Online + } + } + } + } + return res +} + +func assertExpectedOnlineMapAllOnline(t *assert.CollectT, expectedPeerCount int, onlineMap map[types.NodeID]map[types.NodeID]bool) { + for nid, peers := range onlineMap { + onlineCount := 0 + for _, online := range peers { + if online { + onlineCount++ + } + } + assert.Equalf(t, expectedPeerCount, len(peers), "node:%d had an unexpected number of peers in online map", nid) + if expectedPeerCount != onlineCount { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Not all of node:%d peers where online:\n", nid)) + for pid, online := range peers { + sb.WriteString(fmt.Sprintf("\tPeer node:%d online: %t\n", pid, online)) + } + sb.WriteString("timestamp: " + time.Now().Format("2006-01-02T15-04-05.999999999") + "\n") + sb.WriteString("expected all peers to be online.") + t.Errorf("%s", sb.String()) + } + } +} diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index e77d2fbe..22250eb4 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -622,6 +622,27 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { } tarReader := tar.NewReader(bytes.NewReader(tarData)) + + // Find the top-level directory to strip + var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) + for { + header, err := firstPass.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + if header.Typeflag == tar.TypeDir && topLevelDir == "" { + topLevelDir = strings.TrimSuffix(header.Name, "/") + break + } + } + + // Second pass: extract files, stripping the top-level directory + tarReader = tar.NewReader(bytes.NewReader(tarData)) for { header, err := tarReader.Next() if err == io.EOF { @@ -637,7 +658,20 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { continue // Skip potentially dangerous paths } - targetPath := filepath.Join(targetDir, filepath.Base(cleanName)) + // Strip the top-level directory + if topLevelDir != "" && strings.HasPrefix(cleanName, topLevelDir+"/") { + cleanName = strings.TrimPrefix(cleanName, topLevelDir+"/") + } else if cleanName == topLevelDir { + // Skip the top-level directory itself + continue + } + + // Skip empty paths after stripping + if cleanName == "" { + continue + } + + targetPath := filepath.Join(targetDir, cleanName) switch header.Typeflag { case tar.TypeDir: @@ -646,6 +680,11 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) } case tar.TypeReg: + // Ensure parent directories exist + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) + } + // Create file outFile, err := os.Create(targetPath) if err != nil { @@ -674,7 +713,7 @@ func (t *HeadscaleInContainer) SaveProfile(savePath string) error { return err } - targetDir := path.Join(savePath, t.hostname+"-pprof") + targetDir := path.Join(savePath, "pprof") return extractTarToDirectory(tarFile, targetDir) } @@ -685,7 +724,7 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { return err } - targetDir := path.Join(savePath, t.hostname+"-mapresponses") + targetDir := path.Join(savePath, "mapresponses") return extractTarToDirectory(tarFile, targetDir) } @@ -1243,3 +1282,22 @@ func (t *HeadscaleInContainer) SendInterrupt() error { return nil } + +func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/mapresponses", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching mapresponses from debug endpoint: %w", err) + } + + var res map[types.NodeID][]tailcfg.MapResponse + if err := json.Unmarshal([]byte(result), &res); err != nil { + return nil, fmt.Errorf("decoding routes response: %w", err) + } + + return res, nil +}