Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-08-06 08:46:12 +02:00
parent 7eef3cc38c
commit c24b988247
27 changed files with 2073 additions and 755 deletions

View File

@ -136,9 +136,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
// Initialize ephemeral garbage collector
ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
node, err := app.state.GetNodeByID(ni)
if err != nil {
log.Err(err).Uint64("node.id", ni.Uint64()).Msgf("failed to get ephemeral node for deletion")
node, ok := app.state.GetNodeByID(ni)
if !ok {
log.Warn().Uint64("node.id", ni.Uint64()).Msgf("ephemeral node not found for deletion")
return
}
@ -371,7 +371,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
Str("client_address", req.RemoteAddr).
Msg("HTTP authentication invoked")
authHeader := req.Header.Get("authorization")
authHeader := req.Header.Get("Authorization")
if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
@ -487,11 +487,12 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches the HTTP and gRPC servers for Headscale and the API.
func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp()
if profilingEnabled {
if profilingPath != "" {
err := os.MkdirAll(profilingPath, os.ModePerm)
err = os.MkdirAll(profilingPath, os.ModePerm)
if err != nil {
log.Fatal().Err(err).Msg("failed to create profiling directory")
}
@ -543,10 +544,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes, err := h.state.ListEphemeralNodes()
if err != nil {
return fmt.Errorf("failed to list ephemeral nodes: %w", err)
}
ephmNodes := h.state.ListEphemeralNodes()
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
}
@ -778,23 +776,14 @@ func (h *Headscale) Serve() error {
continue
}
changed, err := h.state.ReloadPolicy()
changes, err := h.state.ReloadPolicy()
if err != nil {
log.Error().Err(err).Msgf("reloading policy")
continue
}
if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
h.Change(changes...)
err = h.state.AutoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
h.Change(change.PolicySet)
}
default:
info := func(msg string) { log.Info().Msg(msg) }
log.Info().
@ -1004,6 +993,8 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
// Change is used to send changes to nodes.
// All changes should be enqueued here; empty changes are automatically
// ignored.
func (h *Headscale) Change(c change.ChangeSet) {
h.mapBatcher.AddWork(c)
func (h *Headscale) Change(cs ...change.ChangeSet) {
for _, c := range cs {
h.mapBatcher.AddWork(c)
}
}
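
A hedged usage sketch of the new variadic signature (these call sites and the policyChanges variable are hypothetical): a single ChangeSet and a whole slice go through the same entry point, and empty sets are ignored automatically as the comment above notes.

// Hypothetical call sites for the variadic Change:
h.Change(change.DERPSet)    // a single ChangeSet
h.Change(policyChanges...)  // policyChanges is a []change.ChangeSet, e.g. from ReloadPolicy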

View File

@ -11,7 +11,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -28,27 +27,9 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, err := h.state.GetNodeByNodeKey(regReq.NodeKey)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("looking up node in database: %w", err)
}
if node.Valid() {
// If an existing node is trying to register with an auth key,
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
}
return resp, nil
}
node, ok := h.state.GetNodeByNodeKey(regReq.NodeKey)
if ok {
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
@ -69,6 +50,7 @@ func (h *Headscale) handleRegister(
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key: %w", err)
}
@ -88,13 +70,22 @@ func (h *Headscale) handleExistingNode(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}
expired := node.IsExpired()
// If the node is expired and this is not a re-authentication attempt,
// force the client to re-authenticate
if expired && regReq.Auth == nil {
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
AuthURL: "", // Client will need to re-authenticate
}, nil
}
if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry
@ -117,12 +108,16 @@ func (h *Headscale) handleExistingNode(
}
}
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
updatedNode, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
h.Change(c)
// CRITICAL: Use the updated node view for the response
// The original node object has stale expiry information
node = updatedNode.AsStruct()
}
return nodeToRegisterResponse(node), nil
@ -192,8 +187,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err
}
// If node is nil, it means an ephemeral node was deleted during logout
if node.Valid() {
// If node is not valid, it means an ephemeral node was deleted during logout
if !node.Valid() {
h.Change(changed)
return nil, nil
}
@ -212,6 +207,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
// TODO(kradalby): This might need to be run as part of the batcher,
// now that we don't update the node/pol here anymore.
routeChange := h.state.AutoApproveRoutes(node)
if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
@ -229,6 +225,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
// }
user := node.User()
return &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),

View File

@ -936,7 +936,7 @@ AND auth_key_id NOT IN (
// - NEVER use gorm.AutoMigrate, write the exact migration steps needed
// - AutoMigrate depends on the struct staying exactly the same, which it won't over time.
// - Never write migrations that requires foreign keys to be disabled.
},
},
)
if err := runMigrations(cfg, dbConn, migrations); err != nil {
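
A minimal sketch of what the migration rules above prescribe, assuming the gormigrate list this file builds; the migration ID, column, and rollback are invented for illustration, and the step is written as explicit SQL so it never depends on the current Go struct.

// Hypothetical explicit migration entry (no gorm.AutoMigrate); the exact SQL
// is spelled out, so later changes to the Node struct cannot alter this step.
var exampleMigration = &gormigrate.Migration{
    ID: "202508060000-add-example-column",
    Migrate: func(tx *gorm.DB) error {
        return tx.Exec(
            `ALTER TABLE nodes ADD COLUMN example_flag INTEGER NOT NULL DEFAULT 0`,
        ).Error
    },
    Rollback: func(tx *gorm.DB) error {
        return tx.Exec(`ALTER TABLE nodes DROP COLUMN example_flag`).Error
    },
}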

View File

@ -260,24 +260,18 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
}
// RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it. If the name is not unique, it will return an error.
// and renames it. Validation should be done in the state layer before calling this function.
func RenameNode(tx *gorm.DB,
nodeID types.NodeID, newName string,
) error {
err := util.CheckForFQDNRules(
newName,
)
if err != nil {
return fmt.Errorf("renaming node: %w", err)
// Check if the new name is unique
var count int64
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
return fmt.Errorf("failed to check name uniqueness: %w", err)
}
uniq, err := isUniqueName(tx, newName)
if err != nil {
return fmt.Errorf("checking if name is unique: %w", err)
}
if !uniq {
return fmt.Errorf("name is not unique: %s", newName)
if count > 0 {
return fmt.Errorf("name is not unique")
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@ -409,9 +403,16 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, err
}
// CRITICAL: Reload the node to get the updated expiry
// Without this, we return stale node data to NodeStore
updatedNode, err := GetNodeByID(tx, node.ID)
if err != nil {
return nil, fmt.Errorf("failed to reload node after expiry update: %w", err)
}
nodeChange = change.KeyExpiry(node.ID)
return node, nil
return updatedNode, nil
}
}
@ -445,8 +446,13 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.ID = oldNode.ID
node.GivenName = oldNode.GivenName
node.ApprovedRoutes = oldNode.ApprovedRoutes
ipv4 = oldNode.IPv4
ipv6 = oldNode.IPv6
// Don't overwrite the provided IPs with old ones when they exist
if ipv4 == nil {
ipv4 = oldNode.IPv4
}
if ipv6 == nil {
ipv6 = oldNode.IPv6
}
}
// If the node exists and it already has IP(s), we just save it
@ -781,19 +787,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
node := hsdb.CreateNodeForTest(user, hostname...)
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, *node, nil, nil)
// Allocate IPs for the test node using the database's IP allocator
// This is a simplified allocation for testing - in production this would use State.ipAlloc
ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err))
}
var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error
registeredNode, err = RegisterNode(tx, *node, ipv4, ipv6)
return err
})
if err != nil {
panic(fmt.Sprintf("failed to register test node: %v", err))
}
registeredNode, err := hsdb.GetNodeByID(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to get registered test node: %v", err))
}
return registeredNode
}
@ -842,3 +852,23 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
return nodes
}
// allocateTestIPs allocates sequential test IPs for nodes during testing.
func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) {
if !testing.Testing() {
panic("allocateTestIPs can only be called during tests")
}
// Use simple sequential allocation for tests
// IPv4: 100.64.0.x (where x is nodeID)
// IPv6: fd7a:115c:a1e0::x (where x is nodeID)
if nodeID > 254 {
return nil, nil, fmt.Errorf("test node ID %d too large for simple IP allocation", nodeID)
}
ipv4 := netip.AddrFrom4([4]byte{100, 64, 0, byte(nodeID)})
ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, 0, byte(nodeID)})
return &ipv4, &ipv6, nil
}
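
For illustration, a small package-internal sketch (assumes a *HSDatabase named hsdb and a *testing.T named t, as in the test helpers above) of the addresses the scheme produces:

// Worked example of the sequential scheme: node ID 7 lands in both test ranges.
ipv4, ipv6, err := hsdb.allocateTestIPs(7)
if err != nil {
    t.Fatal(err)
}
t.Logf("node 7 -> %s %s", ipv4, ipv6) // node 7 -> 100.64.0.7 fd7a:115c:a1e0::7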

View File

@ -292,12 +292,57 @@ func TestHeadscale_generateGivenName(t *testing.T) {
func TestAutoApproveRoutes(t *testing.T) {
tests := []struct {
name string
acl string
routes []netip.Prefix
want []netip.Prefix
want2 []netip.Prefix
name string
acl string
routes []netip.Prefix
want []netip.Prefix
want2 []netip.Prefix
expectChange bool // whether to expect route changes
}{
{
name: "no-auto-approvers-empty-policy",
acl: `
{
"groups": {
"group:admins": ["test@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admins"],
"dst": ["group:admins:*"]
}
]
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
want: []netip.Prefix{}, // Should be empty - no auto-approvers
want2: []netip.Prefix{}, // Should be empty - no auto-approvers
expectChange: false, // No changes expected
},
{
name: "no-auto-approvers-explicit-empty",
acl: `
{
"groups": {
"group:admins": ["test@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admins"],
"dst": ["group:admins:*"]
}
],
"autoApprovers": {
"routes": {},
"exitNode": []
}
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers
expectChange: false, // No changes expected
},
{
name: "2068-approve-issue-sub-kube",
acl: `
@ -316,8 +361,9 @@ func TestAutoApproveRoutes(t *testing.T) {
}
}
}`,
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
expectChange: true, // Routes should be approved
},
{
name: "2068-approve-issue-sub-exit-tag",
@ -361,6 +407,7 @@ func TestAutoApproveRoutes(t *testing.T) {
tsaddr.AllIPv4(),
tsaddr.AllIPv6(),
},
expectChange: true, // Routes should be approved
},
}
@ -421,28 +468,40 @@ func TestAutoApproveRoutes(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, pm)
changed1 := policy.AutoApproveRoutes(pm, &node)
assert.True(t, changed1)
newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes)
assert.Equal(t, tt.expectChange, changed1)
err = adb.DB.Save(&node).Error
require.NoError(t, err)
if changed1 {
err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1)
require.NoError(t, err)
}
_ = policy.AutoApproveRoutes(pm, &nodeTagged)
err = adb.DB.Save(&nodeTagged).Error
require.NoError(t, err)
newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes)
if changed2 {
err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2)
require.NoError(t, err)
}
node1ByID, err := adb.GetNodeByID(1)
require.NoError(t, err)
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
// For empty auto-approvers tests, handle nil vs empty slice comparison
expectedRoutes1 := tt.want
if len(expectedRoutes1) == 0 {
expectedRoutes1 = nil
}
if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
node2ByID, err := adb.GetNodeByID(2)
require.NoError(t, err)
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
expectedRoutes2 := tt.want2
if len(expectedRoutes2) == 0 {
expectedRoutes2 = nil
}
if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
})
@ -739,13 +798,13 @@ func TestListPeers(t *testing.T) {
// No parameter means no filter, should return all peers
nodes, err = db.ListPeers(1)
require.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Empty node list should return all peers
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
require.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// No match in IDs should return empty list and no error
@ -756,13 +815,13 @@ func TestListPeers(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs, but node ID is still filtered out
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname)
}
@ -824,14 +883,14 @@ func TestListNodes(t *testing.T) {
// No parameter means no filter, should return all nodes
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Equal(t, 2, len(nodes))
assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err)
assert.Equal(t, 2, len(nodes))
assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
@ -843,13 +902,13 @@ func TestListNodes(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Equal(t, 1, len(nodes))
assert.Len(t, nodes, 1)
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Equal(t, 2, len(nodes))
assert.Len(t, nodes, 2)
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
}

View File

@ -198,19 +198,20 @@ func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
}
// AssignNodeToUser assigns a Node to a user.
// Note: Validation should be done in the state layer before calling this function.
func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error {
node, err := GetNodeByID(tx, nodeID)
if err != nil {
return err
// Check if the user exists
var userExists bool
if err := tx.Model(&types.User{}).Select("count(*) > 0").Where("id = ?", uid).Find(&userExists).Error; err != nil {
return fmt.Errorf("failed to check if user exists: %w", err)
}
user, err := GetUserByID(tx, uid)
if err != nil {
return err
if !userExists {
return ErrUserNotFound
}
node.User = *user
node.UserID = user.ID
if result := tx.Save(&node); result.Error != nil {
return result.Error
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", uid).Error; err != nil {
return fmt.Errorf("failed to assign node to user: %w", err)
}
return nil

View File

@ -288,9 +288,9 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context,
request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !ok {
return nil, status.Errorf(codes.NotFound, "node not found")
}
resp := node.Proto()
@ -334,7 +334,12 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
ctx context.Context,
request *v1.SetApprovedRoutesRequest,
) (*v1.SetApprovedRoutesResponse, error) {
var routes []netip.Prefix
log.Debug().
Uint64("node.id", request.GetNodeId()).
Strs("requestedRoutes", request.GetRoutes()).
Msg("gRPC SetApprovedRoutes called")
var newApproved []netip.Prefix
for _, route := range request.GetRoutes() {
prefix, err := netip.ParsePrefix(route)
if err != nil {
@ -344,31 +349,34 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
// If the prefix is an exit route, add both. The client expects both
// to annotate the node as an exit node.
if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6())
} else {
routes = append(routes, prefix)
newApproved = append(newApproved, prefix)
}
}
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
// Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange)
// If routes changed, propagate those changes too
if !routeChange.Empty() {
api.h.Change(routeChange)
}
proto := node.Proto()
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
// Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the
// routes that are actively served from the node (per architectural requirement in types/node.go)
primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID())
proto.SubnetRoutes = util.PrefixesToString(primaryRoutes)
log.Debug().
Uint64("node.id", node.ID().Uint64()).
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
Strs("finalSubnetRoutes", proto.SubnetRoutes).
Msg("gRPC SetApprovedRoutes completed")
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
}
@ -390,9 +398,9 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context,
request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) {
node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil {
return nil, err
node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !ok {
return nil, status.Errorf(codes.NotFound, "node not found")
}
nodeChange, err := api.h.state.DeleteNode(node)
@ -463,19 +471,13 @@ func (api headscaleV1APIServer) ListNodes(
return nil, err
}
nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
if err != nil {
return nil, err
}
nodes := api.h.state.ListNodesByUser(types.UserID(user.ID))
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, err
}
nodes := api.h.state.ListNodes()
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
@ -499,6 +501,7 @@ func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID,
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp
}
@ -674,11 +677,8 @@ func (api headscaleV1APIServer) SetPolicy(
// a scenario where they might be allowed if the server has no nodes
// yet, but it should help for the general case and for hot reloading
// configurations.
nodes, err := api.h.state.ListNodes()
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
changed, err := api.h.state.SetPolicy([]byte(p))
nodes := api.h.state.ListNodes()
_, err := api.h.state.SetPolicy([]byte(p))
if err != nil {
return nil, fmt.Errorf("setting policy: %w", err)
}
@ -695,16 +695,16 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
// Only send update if the packet filter has changed.
if changed {
err = api.h.state.AutoApproveNodes()
if err != nil {
return nil, err
}
api.h.Change(change.PolicyChange())
// Always reload policy to ensure route re-evaluation, even if policy content hasn't changed.
// This ensures that routes are re-evaluated for auto-approval in cases where routes
// were manually disabled but could now be auto-approved with the current policy.
cs, err := api.h.state.ReloadPolicy()
if err != nil {
return nil, fmt.Errorf("reloading policy: %w", err)
}
api.h.Change(cs...)
response := &v1.SetPolicyResponse{
Policy: updated.Data,
UpdatedAt: timestamppb.New(updated.UpdatedAt),

View File

@ -94,10 +94,7 @@ func (h *Headscale) handleVerifyRequest(
return fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
}
nodes, err := h.state.ListNodes()
if err != nil {
return fmt.Errorf("cannot list nodes: %w", err)
}
nodes := h.state.ListNodes()
// Check if any node has the requested NodeKey
var nodeKeyFound bool

View File

@ -1,6 +1,7 @@
package mapper
import (
"errors"
"fmt"
"time"
@ -18,8 +19,8 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
type Batcher interface {
Start()
Close()
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c change.ChangeSet)
@ -119,7 +120,7 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
if nc == nil {
return fmt.Errorf("nodeConnection is nil")
return errors.New("nodeConnection is nil")
}
nodeID := nc.nodeID()

View File

@ -3,7 +3,6 @@ package mapper
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
@ -21,7 +20,6 @@ type LockFreeBatcher struct {
mapper *mapper
workers int
// Lock-free concurrent maps
nodes *xsync.Map[types.NodeID, *nodeConn]
connected *xsync.Map[types.NodeID, *time.Time]
@ -32,7 +30,6 @@ type LockFreeBatcher struct {
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
batchMutex sync.RWMutex
// Metrics
totalNodes atomic.Int64
@ -46,16 +43,13 @@ type LockFreeBatcher struct {
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
// First validate that we can generate initial map before doing anything else
fullSelfChange := change.FullSelf(id)
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
// This currently means that the goroutine for the node connection will do the processing
// which means that we might have uncontrolled concurrency.
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
// it to be processed in a more controlled manner.
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
initialMap, err := generateMapResponse(id, version, b.mapper, change.FullSelf(id))
if err != nil {
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
@ -73,10 +67,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
conn = newConn
}
// Mark as connected only after validation succeeds
b.connected.Store(id, nil) // nil = connected
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
log.Info().Uint64("node.id", id.Uint64()).Msg("Node connected to batcher")
// Send the validated initial map
if initialMap != nil {
@ -86,9 +79,6 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
b.connected.Delete(id)
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
}
// Notify other nodes that this node came online
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
}
return nil
@ -97,12 +87,14 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
// It reports whether the node was actually removed. It returns false if the channel does not match
// the current connection, meaning the request is ignored rather than disconnecting the node.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
// Check if this is the current connection and mark it as closed
if existing, ok := b.nodes.Load(id); ok {
if !existing.matchesChannel(c) {
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
return // Not the current connection, not an error
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called on a different channel, ignoring")
return false // Not the current connection, not an error
}
// Mark the connection as closed to prevent further sends
@ -111,15 +103,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
}
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
log.Info().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher, marking as offline")
// Remove node and mark disconnected atomically
b.nodes.Delete(id)
b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)
// Notify other nodes that this node went offline
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
return true
}
// AddWork queues a change to be processed by the batcher.
@ -214,6 +205,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
Dur("duration", duration).
Msg("slow synchronous work processing")
}
continue
}
@ -228,6 +220,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("skipping work for closed connection")
continue
}
@ -240,12 +233,6 @@ func (b *LockFreeBatcher) worker(workerID int) {
Str("change", w.c.Change.String()).
Msg("failed to apply change")
}
} else {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("node not found for asynchronous work - node may have disconnected")
}
duration := time.Since(startTime)
@ -276,8 +263,10 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
return true
}
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
return true
})
return
}
@ -285,7 +274,7 @@ func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
b.addToBatch(c)
}
// queueWork safely queues work
// queueWork safely queues work.
func (b *LockFreeBatcher) queueWork(w work) {
b.workQueuedCount.Add(1)
@ -298,7 +287,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
}
}
// shouldProcessImmediately determines if a change should bypass batching
// shouldProcessImmediately determines if a change should bypass batching.
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
// Process these changes immediately to avoid delaying critical functionality
switch c.Change {
@ -309,11 +298,8 @@ func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
}
}
// addToBatch adds a change to the pending batch
// addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if c.SelfUpdateOnly {
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
changes = append(changes, c)
@ -329,15 +315,13 @@ func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(nodeID, changes)
return true
})
}
// processBatchedChanges processes all pending batched changes
// processBatchedChanges processes all pending batched changes.
func (b *LockFreeBatcher) processBatchedChanges() {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if b.pendingChanges == nil {
return
}
@ -355,17 +339,27 @@ func (b *LockFreeBatcher) processBatchedChanges() {
// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)
return true
})
}
// IsConnected is lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if val, ok := b.connected.Load(id); ok {
// nil means connected
return val == nil
val, ok := b.connected.Load(id)
if !ok {
return false
}
return false
// nil means connected
if val == nil {
return true
}
// During grace period, always return true to allow DNS resolution
// for logout HTTP requests to complete successfully
gracePeriod := 45 * time.Second
return time.Since(*val) < gracePeriod
}
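
A self-contained sketch of the same decision in isolation (the 45-second window mirrors the value hard-coded above; the helper name is invented), spelling out the three cases a lookup can hit: no entry, nil timestamp, recent timestamp.

// stillConnected mirrors IsConnected's decision for a single entry:
// missing entry = offline, nil timestamp = online, and a non-nil timestamp
// counts as online only while it is still inside the 45s grace period.
func stillConnected(disconnectedAt *time.Time, ok bool) bool {
    if !ok {
        return false
    }
    if disconnectedAt == nil {
        return true
    }
    return time.Since(*disconnectedAt) < 45*time.Second
}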
// ConnectedMap returns a lock-free map of all connected nodes.
@ -487,5 +481,6 @@ func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
// the channel is still open.
connData.c <- data
nc.updateCount.Add(1)
return nil
}

View File

@ -26,6 +26,43 @@ type batcherTestCase struct {
fn batcherFunc
}
// testBatcherWrapper wraps a real batcher to add online/offline notifications
// that would normally be sent by poll.go in production
type testBatcherWrapper struct {
Batcher
}
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
// First add the node to the real batcher
err := t.Batcher.AddNode(id, c, version)
if err != nil {
return err
}
// Then send the online notification that poll.go would normally send
t.Batcher.AddWork(change.NodeOnline(id))
return nil
}
func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
// First remove from the real batcher
removed := t.Batcher.RemoveNode(id, c)
if !removed {
return false
}
// Then send the offline notification that poll.go would normally send
t.Batcher.AddWork(change.NodeOffline(id))
return true
}
// wrapBatcherForTest wraps a batcher with test-specific behavior
func wrapBatcherForTest(b Batcher) Batcher {
return &testBatcherWrapper{Batcher: b}
}
// allBatcherFunctions contains all batcher implementations to test.
var allBatcherFunctions = []batcherTestCase{
{"LockFree", NewBatcherAndMapper},
@ -176,8 +213,8 @@ func setupBatcherWithTestData(
"acls": [
{
"action": "accept",
"users": ["*"],
"ports": ["*:*"]
"src": ["*"],
"dst": ["*:*"]
}
]
}`
@ -187,8 +224,8 @@ func setupBatcherWithTestData(
t.Fatalf("Failed to set allow-all policy: %v", err)
}
// Create batcher with the state
batcher := bf(cfg, state)
// Create batcher with the state and wrap it for testing
batcher := wrapBatcherForTest(bf(cfg, state))
batcher.Start()
testData := &TestData{
@ -455,7 +492,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start()
// Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
time.Sleep(100 * time.Millisecond) // Let connection settle
// Generate some work
@ -558,7 +595,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes {
node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet)
@ -606,7 +643,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
// Disconnect all nodes
for i := range allNodes {
node := &allNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false)
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for final updates to process
@ -724,7 +761,7 @@ func TestBatcherBasicOperations(t *testing.T) {
tn2 := testData.Nodes[1]
// Test AddNode with real node ID
batcher.AddNode(tn.n.ID, tn.ch, false, 100)
batcher.AddNode(tn.n.ID, tn.ch, 100)
if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode")
}
@ -744,14 +781,14 @@ func TestBatcherBasicOperations(t *testing.T) {
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
// Add the second node and verify update message
batcher.AddNode(tn2.n.ID, tn2.ch, false, 100)
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected.
select {
case data := <-tn.ch:
assertOnlineMapResponse(t, data, true)
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update")
}
@ -765,19 +802,19 @@ func TestBatcherBasicOperations(t *testing.T) {
len(data.Peers) >= 1 || data.Node != nil,
"Should receive initial full map",
)
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
t.Error("Second node should receive its initial full map")
}
// Disconnect the second node
batcher.RemoveNode(tn2.n.ID, tn2.ch, false)
batcher.RemoveNode(tn2.n.ID, tn2.ch)
assert.False(t, batcher.IsConnected(tn2.n.ID))
// First node should get update that second has disconnected.
select {
case data := <-tn.ch:
assertOnlineMapResponse(t, data, false)
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
t.Error("Did not receive expected Online response update")
}
@ -803,7 +840,7 @@ func TestBatcherBasicOperations(t *testing.T) {
// }
// Test RemoveNode
batcher.RemoveNode(tn.n.ID, tn.ch, false)
batcher.RemoveNode(tn.n.ID, tn.ch)
if batcher.IsConnected(tn.n.ID) {
t.Error("Node should be disconnected after RemoveNode")
}
@ -949,7 +986,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10)
batcher.AddNode(testNodes[0].n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
// Track update content for validation
var receivedUpdates []*tailcfg.MapResponse
@ -1045,7 +1082,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
batcher.AddNode(testNode.n.ID, ch1, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
}()
// Add real work during connection chaos
@ -1059,7 +1096,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(1 * time.Microsecond)
batcher.AddNode(testNode.n.ID, ch2, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
}()
// Remove second connection
@ -1067,7 +1104,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(2 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch2, false)
batcher.RemoveNode(testNode.n.ID, ch2)
}()
wg.Wait()
@ -1142,7 +1179,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet)
// Consumer goroutine to validate data and detect channel issues
@ -1184,7 +1221,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Rapid removal creates race between worker and removal
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
batcher.RemoveNode(testNode.n.ID, ch, false)
batcher.RemoveNode(testNode.n.ID, ch)
// Give workers time to process and close channels
time.Sleep(5 * time.Millisecond)
@ -1254,7 +1291,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for _, node := range stableNodes {
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch
batcher.AddNode(node.n.ID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
// Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@ -1312,7 +1349,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking
go func() {
@ -1349,7 +1386,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch, exists := churningChannels[nodeID]
churningChannelsMutex.Unlock()
if exists {
batcher.RemoveNode(nodeID, ch, false)
batcher.RemoveNode(nodeID, ch)
}
}(node.n.ID)
}
@ -1599,7 +1636,7 @@ func XTestBatcherScalability(t *testing.T) {
var connectedNodesMutex sync.RWMutex
for i := range testNodes {
node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true
connectedNodesMutex.Unlock()
@ -1666,7 +1703,7 @@ func XTestBatcherScalability(t *testing.T) {
connectedNodesMutex.RUnlock()
if isConnected {
batcher.RemoveNode(nodeID, channel, false)
batcher.RemoveNode(nodeID, channel)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = false
connectedNodesMutex.Unlock()
@ -1690,7 +1727,6 @@ func XTestBatcherScalability(t *testing.T) {
batcher.AddNode(
nodeID,
channel,
false,
tailcfg.CapabilityVersion(100),
)
connectedNodesMutex.Lock()
@ -1792,7 +1828,7 @@ func XTestBatcherScalability(t *testing.T) {
// Now disconnect all nodes from batcher to stop new updates
for i := range testNodes {
node := &testNodes[i]
batcher.RemoveNode(node.n.ID, node.ch, false)
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for enhanced tracking goroutines to process any remaining data in channels
@ -1924,7 +1960,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Connect nodes one at a time to avoid overwhelming the work queue
for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Small delay between connections to allow NodeCameOnline processing
time.Sleep(50 * time.Millisecond)
@ -1936,12 +1972,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Check how many peers each node should see
for i, node := range allNodes {
peers, err := testData.State.ListPeers(node.n.ID)
if err != nil {
t.Errorf("Error listing peers for node %d: %v", i, err)
} else {
t.Logf("Node %d should see %d peers from state", i, peers.Len())
}
peers := testData.State.ListPeers(node.n.ID)
t.Logf("Node %d should see %d peers from state", i, peers.Len())
}
// Send a full update - this should generate full peer lists
@ -1957,7 +1989,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
foundFullUpdate := false
// Read all available updates for each node
for i := range len(allNodes) {
for i := range allNodes {
nodeUpdates := 0
t.Logf("Reading updates for node %d:", i)
@ -2047,7 +2079,7 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
t.Logf("=== WORK QUEUE TRACING TEST ===")
// Connect first node
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, false, tailcfg.CapabilityVersion(100))
batcher.AddNode(nodes[0].n.ID, nodes[0].ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d", nodes[0].n.ID)
// Wait for initial NodeCameOnline to be processed
@ -2102,14 +2134,10 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
}
// Check if there should be peers available
peers, err := testData.State.ListPeers(nodes[0].n.ID)
if err != nil {
t.Errorf("Error getting peers from state: %v", err)
} else {
t.Logf("State shows %d peers available for this node", peers.Len())
if peers.Len() > 0 && len(data.Peers) == 0 {
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
}
peers := testData.State.ListPeers(nodes[0].n.ID)
t.Logf("State shows %d peers available for this node", peers.Len())
if peers.Len() > 0 && len(data.Peers) == 0 {
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
}
} else {
t.Errorf("Response data is nil")

View File

@ -1,6 +1,7 @@
package mapper
import (
"errors"
"net/netip"
"sort"
"time"
@ -8,11 +9,12 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/multierr"
)
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse.
type MapResponseBuilder struct {
resp *tailcfg.MapResponse
mapper *mapper
@ -21,7 +23,7 @@ type MapResponseBuilder struct {
errs []error
}
// NewMapResponseBuilder creates a new builder with basic fields set
// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
@ -35,37 +37,44 @@ func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder
}
}
// addError adds an error to the builder's error list
// addError adds an error to the builder's error list.
func (b *MapResponseBuilder) addError(err error) {
if err != nil {
b.errs = append(b.errs, err)
}
}
// hasErrors returns true if the builder has accumulated any errors
// hasErrors returns true if the builder has accumulated any errors.
func (b *MapResponseBuilder) hasErrors() bool {
return len(b.errs) > 0
}
// WithCapabilityVersion sets the capability version for the response
// WithCapabilityVersion sets the capability version for the response.
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
b.capVer = capVer
return b
}
// WithSelfNode adds the requesting node to the response
// WithSelfNode adds the requesting node to the response.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
nodeView, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
return b
}
// Always use batcher's view of online status for self node
// The batcher respects grace periods for logout scenarios
node := nodeView.AsStruct()
if b.mapper.batcher != nil {
node.IsOnline = ptr.To(b.mapper.batcher.IsConnected(b.nodeID))
}
_, matchers := b.mapper.state.Filter()
tailnode, err := tailNode(
node, b.capVer, b.mapper.state,
node.View(), b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
@ -74,29 +83,30 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
}
b.resp.Node = tailnode
return b
}
// WithDERPMap adds the DERP map to the response
// WithDERPMap adds the DERP map to the response.
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
b.resp.DERPMap = b.mapper.state.DERPMap()
return b
}
// WithDomain adds the domain configuration
// WithDomain adds the domain configuration.
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
b.resp.Domain = b.mapper.cfg.Domain()
return b
}
// WithCollectServicesDisabled sets the collect services flag to false
// WithCollectServicesDisabled sets the collect services flag to false.
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
b.resp.CollectServices.Set(false)
return b
}
// WithDebugConfig adds debug configuration
// It disables log tailing if the mapper's LogTail is not enabled
// It disables log tailing if the mapper's LogTail is not enabled.
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
@ -104,11 +114,11 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
return b
}
// WithSSHPolicy adds SSH policy configuration for the requesting node
// WithSSHPolicy adds SSH policy configuration for the requesting node.
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
return b
}
@ -119,38 +129,41 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
}
b.resp.SSHPolicy = sshPolicy
return b
}
// WithDNSConfig adds DNS configuration for the requesting node
// WithDNSConfig adds DNS configuration for the requesting node.
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
return b
}
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
return b
}
// WithUserProfiles adds user profiles for the requesting node and given peers
// WithUserProfiles adds user profiles for the requesting node and given peers.
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
return b
}
b.resp.UserProfiles = generateUserProfiles(node, peers)
return b
}
// WithPacketFilters adds packet filter rules based on policy
// WithPacketFilters adds packet filter rules based on policy.
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
return b
}
@ -167,9 +180,8 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
return b
}
// WithPeers adds full peer list with policy filtering (for full map response)
// WithPeers adds full peer list with policy filtering (for full map response).
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
@ -177,12 +189,12 @@ func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapRe
}
b.resp.Peers = tailPeers
return b
}
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
// WithPeerChanges adds changed peers with policy filtering (for incremental updates).
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
@ -190,14 +202,15 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
}
b.resp.PeersChanged = tailPeers
return b
}
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting.
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
return nil, err
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
return nil, errors.New("node not found")
}
filter, matchers := b.mapper.state.Filter()
@ -229,24 +242,24 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
return tailPeers, nil
}
// WithPeerChangedPatch adds peer change patches
// WithPeerChangedPatch adds peer change patches.
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
b.resp.PeersChangedPatch = changes
return b
}
// WithPeersRemoved adds removed peer IDs
// WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
// Build finalizes the response and returns marshaled bytes
// Build finalizes the response and returns marshaled bytes.
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)

View File

@ -18,17 +18,17 @@ func TestMapResponseBuilder_Basic(t *testing.T) {
Enabled: true,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID)
// Test basic builder creation
assert.NotNil(t, builder)
assert.Equal(t, nodeID, builder.nodeID)
@ -45,13 +45,13 @@ func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(42)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer)
assert.Equal(t, capVer, builder.capVer)
assert.False(t, builder.hasErrors())
}
@ -62,18 +62,18 @@ func TestMapResponseBuilder_WithDomain(t *testing.T) {
ServerURL: "https://test.example.com",
BaseDomain: domain,
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDomain()
assert.Equal(t, domain, builder.resp.Domain)
assert.False(t, builder.hasErrors())
}
@ -85,12 +85,12 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithCollectServicesDisabled()
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
@ -99,22 +99,22 @@ func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
tests := []struct {
name string
name string
logTailEnabled bool
expected bool
expected bool
}{
{
name: "LogTail enabled",
name: "LogTail enabled",
logTailEnabled: true,
expected: false, // DisableLogTail should be false when LogTail is enabled
expected: false, // DisableLogTail should be false when LogTail is enabled
},
{
name: "LogTail disabled",
name: "LogTail disabled",
logTailEnabled: false,
expected: true, // DisableLogTail should be true when LogTail is disabled
expected: true, // DisableLogTail should be true when LogTail is disabled
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
@ -127,12 +127,12 @@ func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDebugConfig()
require.NotNil(t, builder.resp.Debug)
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
@ -147,22 +147,22 @@ func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
changes := []*tailcfg.PeerChange{
{
NodeID: 123,
NodeID: 123,
DERPRegion: 1,
},
{
NodeID: 456,
NodeID: 456,
DERPRegion: 2,
},
}
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changes)
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
@ -174,14 +174,14 @@ func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(123)
removedID2 := types.NodeID(456)
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1, removedID2)
expected := []tailcfg.NodeID{
removedID1.NodeID(),
removedID2.NodeID(),
@ -197,23 +197,23 @@ func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Simulate an error in the builder
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
// All subsequent calls should continue to work and accumulate errors
result := builder.
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 1)
assert.Equal(t, assert.AnError, result.errs[0])
// Build should return the error
data, err := result.Build("none")
assert.Nil(t, data)
@ -229,22 +229,22 @@ func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
Enabled: false,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(99)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
// Verify all fields are set correctly
assert.Equal(t, capVer, builder.capVer)
assert.Equal(t, domain, builder.resp.Domain)
@ -263,16 +263,16 @@ func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(100)
removedID2 := types.NodeID(200)
// Test calling WithPeersRemoved multiple times
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1).
WithPeersRemoved(removedID2)
// Second call should overwrite the first
expected := []tailcfg.NodeID{removedID2.NodeID()}
assert.Equal(t, expected, builder.resp.PeersRemoved)
@ -286,12 +286,12 @@ func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch([]*tailcfg.PeerChange{})
assert.Empty(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
@ -303,12 +303,12 @@ func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(nil)
assert.Nil(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
@ -320,28 +320,28 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Create a builder and add multiple errors
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
builder.addError(assert.AnError)
builder.addError(nil) // This should be ignored
// All subsequent calls should continue to work
result := builder.
WithDomain().
WithCollectServicesDisabled()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 2) // nil error should be ignored
// Build should return a multierr
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")
}
}

View File

@ -18,6 +18,7 @@ import (
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
@ -49,6 +50,37 @@ type mapper struct {
created time.Time
}
// addOnlineStatusToPeers adds fresh online status from batcher to peer nodes.
//
// We do a last-minute copy-and-write on the NodeView to inject current online status
// from the batcher's connection map. Online status is not populated upstream in NodeStore
// for consistency reasons - it's runtime connection state that should come from the
// connection manager (batcher) to ensure map responses have the freshest data.
func (m *mapper) addOnlineStatusToPeers(peers views.Slice[types.NodeView]) views.Slice[types.NodeView] {
if peers.Len() == 0 || m.batcher == nil {
return peers
}
result := make([]types.NodeView, 0, peers.Len())
for _, peer := range peers.All() {
if !peer.Valid() {
result = append(result, peer)
continue
}
// Get online status from batcher connection map
// The batcher respects grace periods for logout scenarios
isOnline := m.batcher.IsConnected(peer.ID())
// Create a mutable copy and set online status
peerCopy := peer.AsStruct()
peerCopy.IsOnline = ptr.To(isOnline)
result = append(result, peerCopy.View())
}
return views.SliceOf(result)
}
type patch struct {
timestamp time.Time
change *tailcfg.PeerChange
@ -140,10 +172,10 @@ func (m *mapper) fullMapResponse(
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
peers, err := m.state.ListPeers(nodeID)
if err != nil {
return nil, err
}
peers := m.state.ListPeers(nodeID)
// Add fresh online status to peers from batcher connection state
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
@ -154,9 +186,9 @@ func (m *mapper) fullMapResponse(
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithUserProfiles(peersWithOnlineStatus).
WithPacketFilters().
WithPeers(peers).
WithPeers(peersWithOnlineStatus).
Build(messages...)
}
@ -185,16 +217,16 @@ func (m *mapper) peerChangeResponse(
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.state.ListPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
peers := m.state.ListPeers(nodeID, changedNodeID)
// Add fresh online status to peers from batcher connection state
peersWithOnlineStatus := m.addOnlineStatusToPeers(peers)
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
WithPeerChanges(peers).
WithUserProfiles(peersWithOnlineStatus).
WithPeerChanges(peersWithOnlineStatus).
Build()
}

View File

@ -133,11 +133,15 @@ func tailNode(
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}
if !node.IsOnline().Valid() || !node.IsOnline().Get() {
// LastSeen is only set when node is
// not connected to the control server.
if node.LastSeen().Valid() {
lastSeen := node.LastSeen().Get()
// Consider LastSeen whenever it is valid, regardless of online status.
// This ensures that during logout grace periods (when IsOnline might be true
// for DNS preservation), other nodes can still see when this node disconnected.
if node.LastSeen().Valid() {
lastSeen := node.LastSeen().Get()
// Only set LastSeen if the node is offline OR if LastSeen is recent
// (indicating it disconnected recently but might be in grace period)
if !node.IsOnline().Valid() || !node.IsOnline().Get() ||
time.Since(lastSeen) < 60*time.Second {
tNode.LastSeen = &lastSeen
}
}

View File

@ -13,7 +13,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"gorm.io/gorm"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/tailcfg"
@ -296,12 +295,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if !ok {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.

View File

@ -7,6 +7,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -138,39 +139,61 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
return ret
}
// AutoApproveRoutes approves any route that can be auto-approved from
// the node's perspective according to the given policy.
// It reports true if any routes were approved.
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
// ApproveRoutesWithPolicy checks if the node can approve the announced routes
// and returns the new list of approved routes.
// The approved routes will include:
// 1. ALL previously approved routes (regardless of whether they're still advertised)
// 2. New routes from announcedRoutes that can be auto-approved by policy
// This ensures that:
// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes)
// - New routes can be auto-approved according to policy
// - Routes can only be removed by explicit admin action (not by auto-approval)
func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) {
if pm == nil {
return false
return currentApproved, false
}
nodeView := node.View()
var newApproved []netip.Prefix
for _, route := range nodeView.AnnouncedRoutes() {
if pm.NodeCanApproveRoute(nodeView, route) {
// Start with ALL currently approved routes - we never remove approved routes
newApproved := make([]netip.Prefix, len(currentApproved))
copy(newApproved, currentApproved)
// Then, check for new routes that can be auto-approved
for _, route := range announcedRoutes {
// Skip if already approved
if slices.Contains(newApproved, route) {
continue
}
// Check if this new route can be auto-approved by policy
canApprove := pm.NodeCanApproveRoute(nv, route)
if canApprove {
newApproved = append(newApproved, route)
}
log.Trace().
Uint64("node.id", nv.ID().Uint64()).
Str("node.name", nv.Hostname()).
Str("route", route.String()).
Bool("can_approve", canApprove).
Msg("Evaluating route for auto-approval")
}
// Only modify ApprovedRoutes if we have new routes to approve.
// This prevents clearing existing approved routes when nodes
// temporarily don't have announced routes during policy changes.
if len(newApproved) > 0 {
combined := append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(combined)
combined = slices.Compact(combined)
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
// Only update if the routes actually changed
if !slices.Equal(node.ApprovedRoutes, combined) {
node.ApprovedRoutes = combined
return true
}
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)
// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
return newApproved, true
}
return false
return newApproved, false
}
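
A minimal sketch (not part of this change) of how a caller might exercise ApproveRoutesWithPolicy; the route values and the exampleAutoApprove helper are hypothetical, and the outcome depends on the policy's autoApprovers:

func exampleAutoApprove(pm PolicyManager, nv types.NodeView) ([]netip.Prefix, bool) {
	// A previously approved route (192.168.0.0/24) that is no longer announced
	// is kept; a newly announced route (10.1.0.0/24) is added only if the
	// policy auto-approves it for this node.
	current := []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}
	announced := []netip.Prefix{netip.MustParsePrefix("10.1.0.0/24")}

	return ApproveRoutesWithPolicy(pm, nv, current, announced)
}

When the second return value is false, nothing needs to be written back, since auto-approval never removes previously approved routes.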

View File

@ -0,0 +1,339 @@
package policy
import (
"fmt"
"net/netip"
"testing"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
user1 := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser@",
}
user2 := types.User{
Model: gorm.Model{ID: 2},
Name: "otheruser@",
}
users := []types.User{user1, user2}
node1 := &types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: user1.ID,
User: user1,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ForcedTags: []string{"tag:test"},
}
node2 := &types.Node{
ID: 2,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: user2.ID,
User: user2,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
// Create a policy that auto-approves specific routes
policyJSON := `{
"groups": {
"group:test": ["testuser@"]
},
"tagOwners": {
"tag:test": ["testuser@"]
},
"acls": [
{
"action": "accept",
"src": ["*"],
"dst": ["*:*"]
}
],
"autoApprovers": {
"routes": {
"10.0.0.0/8": ["testuser@", "tag:test"],
"10.1.0.0/24": ["testuser@"],
"10.2.0.0/24": ["testuser@"],
"192.168.0.0/24": ["tag:test"]
}
}
}`
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
assert.NoError(t, err)
tests := []struct {
name string
node *types.Node
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
description string
}{
{
name: "previously_approved_route_no_longer_advertised_should_remain",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"), // Should still be here!
},
wantChanged: false,
description: "Previously approved routes should never be removed even when no longer advertised",
},
{
name: "add_new_auto_approved_route_keeps_old_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8)
netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept
},
wantChanged: true,
description: "New auto-approved routes should be added while keeping old approved routes",
},
{
name: "no_announced_routes_keeps_all_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
},
announcedRoutes: []netip.Prefix{}, // No routes announced
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
description: "All approved routes should remain when no routes are announced",
},
{
name: "no_changes_when_announced_equals_approved",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
description: "No changes should occur when announced routes match approved routes",
},
{
name: "auto_approve_multiple_new_routes",
node: node1,
currentApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8)
netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved
netip.MustParsePrefix("172.16.0.0/24"), // Original kept
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
},
wantChanged: true,
description: "Multiple new routes should be auto-approved while keeping existing approved routes",
},
{
name: "node_without_permission_no_auto_approval",
node: node2, // Different node without the tag
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route
},
wantChanged: false,
description: "Routes should not be auto-approved for nodes without proper permissions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
tsaddr.SortPrefixes(tt.wantApproved)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
// Verify that all previously approved routes are still present
for _, prevRoute := range tt.currentApproved {
assert.Contains(t, gotApproved, prevRoute,
"previously approved route %s was removed - this should never happen", prevRoute)
}
})
}
}
func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
// Create a basic policy for edge case testing
aclPolicy := `
{
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.1.0.0/24": ["test@"],
},
},
}`
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
}{
{
name: "nil_policy_manager",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
},
{
name: "nil_current_approved",
currentApproved: nil,
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantChanged: true,
},
{
name: "nil_announced_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: nil,
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: false,
},
{
name: "duplicate_approved_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.1.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.1.0.0/24"),
},
wantChanged: true,
},
{
name: "empty_slices",
currentApproved: []netip.Prefix{},
announcedRoutes: []netip.Prefix{},
wantApproved: []netip.Prefix{},
wantChanged: false,
},
}
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
users := []types.User{user}
// Create test node
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
// Create policy manager or use nil if specified
var pm PolicyManager
var err error
if tt.name != "nil_policy_manager" {
pm, err = pmf(users, nodes.ViewSlice())
assert.NoError(t, err)
} else {
pm = nil
}
gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes)
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch")
// Handle nil vs empty slice comparison
if tt.wantApproved == nil {
assert.Nil(t, gotApproved, "expected nil approved routes")
} else {
tsaddr.SortPrefixes(tt.wantApproved)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
}
})
}
}
}

View File

@ -0,0 +1,361 @@
package policy
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
// Test policy that allows specific routes to be auto-approved
aclPolicy := `
{
"groups": {
"group:admins": ["test@"],
},
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.0.0.0/24": ["test@"],
"192.168.0.0/24": ["group:admins"],
"172.16.0.0/16": ["tag:approved"],
},
},
"tagOwners": {
"tag:approved": ["test@"],
},
}`
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
nodeHostname string
nodeUser string
nodeTags []string
wantApproved []netip.Prefix
wantChanged bool
wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result
}{
{
name: "previously_approved_route_no_longer_advertised_remains",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Should remain!
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed
},
{
name: "add_new_auto_approved_route_keeps_existing",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Still advertised
netip.MustParsePrefix("192.168.0.0/24"), // New route
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group
},
wantChanged: true,
},
{
name: "no_announced_routes_keeps_all_approved",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
},
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
},
{
name: "manually_approved_route_not_in_policy_remains",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved
netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved
},
wantChanged: true,
},
{
name: "tagged_node_gets_tag_approved_routes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route
},
nodeUser: "test",
nodeTags: []string{"tag:approved"},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
},
wantChanged: true,
},
{
name: "complex_scenario_multiple_changes",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised
netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable
netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag)
netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy
},
nodeUser: "test",
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised
netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved
netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised
},
wantChanged: true,
},
}
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: tt.nodeUser,
}
users := []types.User{user}
// Create test node
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
ForcedTags: tt.nodeTags,
}
nodes := types.Nodes{&node}
// Create policy manager
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, pm)
// Test ApproveRoutesWithPolicy
gotApproved, gotChanged := ApproveRoutesWithPolicy(
pm,
node.View(),
tt.currentApproved,
tt.announcedRoutes,
)
// Check change flag
assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch")
// Check approved routes match expected
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
t.Logf("Want: %v", tt.wantApproved)
t.Logf("Got: %v", gotApproved)
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
}
// Verify all previously approved routes are still present
for _, prevRoute := range tt.currentApproved {
assert.Contains(t, gotApproved, prevRoute,
"previously approved route %s was removed - this should NEVER happen", prevRoute)
}
// Verify no routes were incorrectly removed
for _, removedRoute := range tt.wantRemovedRoutes {
assert.NotContains(t, gotApproved, removedRoute,
"route %s should have been removed but wasn't", removedRoute)
}
})
}
}
}
func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
aclPolicy := `
{
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]},
],
"autoApprovers": {
"routes": {
"10.0.0.0/8": ["test@"],
},
},
}`
tests := []struct {
name string
currentApproved []netip.Prefix
announcedRoutes []netip.Prefix
wantApproved []netip.Prefix
wantChanged bool
}{
{
name: "nil_current_approved",
currentApproved: nil,
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true,
},
{
name: "empty_current_approved",
currentApproved: []netip.Prefix{},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true,
},
{
name: "duplicate_routes_handled",
currentApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
},
announcedRoutes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantApproved: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
},
wantChanged: true, // Duplicates are removed, so it's a change
},
}
pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy))
for _, tt := range tests {
for i, pmf := range pmfs {
t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) {
// Create test user
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
users := []types.User{user}
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
gotApproved, gotChanged := ApproveRoutesWithPolicy(
pm,
node.View(),
tt.currentApproved,
tt.announcedRoutes,
)
assert.Equal(t, tt.wantChanged, gotChanged)
if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" {
t.Errorf("unexpected approved routes (-want +got):\n%s", diff)
}
})
}
}
}
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
}
currentApproved := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
}
announcedRoutes := []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
}
node := types.Node{
ID: 1,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: currentApproved,
}
// With nil policy manager, should return current approved unchanged
gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes)
assert.False(t, gotChanged)
assert.Equal(t, currentApproved, gotApproved)
}

View File

@ -771,6 +771,29 @@ func TestNodeCanApproveRoute(t *testing.T) {
policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`,
canApprove: false,
},
{
name: "policy-without-autoApprovers-section",
node: normalNode,
route: p("10.33.0.0/16"),
policy: `{
"groups": {
"group:admin": ["user1@"]
},
"acls": [
{
"action": "accept",
"src": ["group:admin"],
"dst": ["group:admin:*"]
},
{
"action": "accept",
"src": ["group:admin"],
"dst": ["10.33.0.0/16:*"]
}
]
}`,
canApprove: false,
},
}
for _, tt := range tests {

View File

@ -239,8 +239,9 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// The fast path is when a node requests to approve a prefix
// for which there is an exact entry, e.g. 10.0.0.0/8; then
// check and return quickly.
if _, ok := pm.autoApproveMap[route]; ok {
if slices.ContainsFunc(node.IPs(), pm.autoApproveMap[route].Contains) {
if approvers, ok := pm.autoApproveMap[route]; ok {
canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains)
if canApprove {
return true
}
}
@ -253,7 +254,8 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// Check if the prefix is larger (i.e. containing) and overlaps
// the route, to see if the node can approve a subset of an auto-approver.
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
if slices.ContainsFunc(node.IPs(), approveAddrs.Contains) {
canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains)
if canApprove {
return true
}
}

View File

@ -10,7 +10,6 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
@ -112,6 +111,15 @@ func (m *mapSession) serve() {
// This is the mechanism where the node gives us information about its
// current configuration.
//
// Process the MapRequest to update node state (endpoints, hostinfo, etc.)
c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
httpError(m.w, err)
return
}
m.h.Change(c)
// If OmitPeers is true and Stream is false
// then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections.
@ -122,14 +130,6 @@ func (m *mapSession) serve() {
// the response and just wants a 200.
// !req.stream && req.OmitPeers
if m.isEndpointUpdate() {
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
if err != nil {
httpError(m.w, err)
return
}
m.h.Change(c)
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
}
@ -142,6 +142,8 @@ func (m *mapSession) serve() {
func (m *mapSession) serveLongPoll() {
m.beforeServeLongPoll()
log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("starting long poll session chan(%p)", m.ch)
// Clean up the session when the client disconnects
defer func() {
m.cancelChMu.Lock()
@ -149,18 +151,26 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh)
m.cancelChMu.Unlock()
// TODO(kradalby): This can likely be made more efficient, but most
// nodes have access to the same routes, so it might not be a big deal.
disconnectChange, err := m.h.state.Disconnect(m.node)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removing session from batcher chan(%p)", m.ch)
// Check whether we are actually closing the current session or whether
// the connection has been replaced. If the connection has been replaced,
// do not run the rest of the disconnect logic.
if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
log.Trace().Str("node", m.node.Hostname).Uint64("node.id", m.node.ID.Uint64()).Msgf("removed from batcher chan(%p)", m.ch)
// First update NodeStore to mark the node as offline
// This ensures the state is consistent before notifying the batcher
disconnectChange, err := m.h.state.Disconnect(m.node.ID)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}
// Send the disconnect change notification
m.h.Change(disconnectChange)
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}
m.h.Change(disconnectChange)
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
}()
// Set up the client stream
@ -172,25 +182,37 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)
// Add node to batcher BEFORE sending Connect change to prevent race condition
// where the change is sent before the node is in the batcher's node map
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
// Add the node to the batcher so it can receive updates.
// Doing this before connecting it to the state ensures that
// it does not miss any updates sent in the window between
// the node connecting and the batcher being ready.
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
// Send empty response to client to fail fast for invalid/non-existent nodes
select {
case m.ch <- &tailcfg.MapResponse{}:
default:
// Channel might be closed
}
return
}
// Now send the Connect change - the batcher handles NodeCameOnline internally
// but we still need to update routes and other state-level changes
connectChange := m.h.state.Connect(m.node)
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
m.h.Change(connectChange)
// Process the initial MapRequest to update node state (endpoints, hostinfo, etc.)
// CRITICAL: This must be done BEFORE calling Connect() to ensure routes are properly
// synchronized. When nodes reconnect, they send their hostinfo with announced routes
// in the MapRequest. We need this data in NodeStore before Connect() sets up the
// primary routes, otherwise SubnetRoutes() returns empty and the node is removed
// from AvailableRoutes.
mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req)
if err != nil {
m.errf(err, "failed to update node from initial MapRequest")
return
}
m.h.Change(mapReqChange)
// Connect the node after its state has been updated.
// We send two separate change notifications because these are distinct operations:
// 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo)
// 2. Connect: marks the node online and recalculates primary routes based on the updated state
// While this results in two notifications, it ensures route data is synchronized before
// primary route selection occurs, which is critical for proper HA subnet router failover.
connectChange := m.h.state.Connect(m.node.ID)
m.h.Change(connectChange)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)

View File

@ -1,11 +1,15 @@
package state
import (
"fmt"
"maps"
"strings"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
@ -21,6 +25,56 @@ const (
update = 3
)
const prometheusNamespace = "headscale"
var (
nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operations_total",
Help: "Total number of NodeStore operations",
}, []string{"operation"})
nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_operation_duration_seconds",
Help: "Duration of NodeStore operations",
Buckets: prometheus.DefBuckets,
}, []string{"operation"})
nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_size",
Help: "Size of NodeStore write batches",
Buckets: []float64{1, 2, 5, 10, 20, 50, 100},
})
nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_batch_duration_seconds",
Help: "Duration of NodeStore batch processing",
Buckets: prometheus.DefBuckets,
})
nodeStoreSnapshotBuildDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_snapshot_build_duration_seconds",
Help: "Duration of NodeStore snapshot building from nodes",
Buckets: prometheus.DefBuckets,
})
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_nodes_total",
Help: "Total number of nodes in the NodeStore",
})
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "nodestore_peers_calculation_duration_seconds",
Help: "Duration of peers calculation in NodeStore",
Buckets: prometheus.DefBuckets,
})
nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_queue_depth",
Help: "Current depth of NodeStore write queue",
})
)
// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
@ -29,13 +83,14 @@ const (
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked.
// blocked. This means that the caller of a write operation
// is responsible for ensuring an update depending on a write
// is not issued before the write is complete.
type NodeStore struct {
data atomic.Pointer[Snapshot]
peersFunc PeersFunc
writeQueue chan work
// TODO: metrics
}
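
A minimal sketch of the blocking-write contract described above; renameNode and the hostname value are hypothetical:

func renameNode(store *NodeStore, id types.NodeID, name string) {
	// UpdateNode blocks until the batch containing this write has been applied,
	// so any read or change notification issued after it returns observes the
	// new snapshot.
	store.UpdateNode(id, func(n *types.Node) {
		n.Hostname = name
	})
	// Dependent work (e.g. notifying peers of the rename) is only safe from this point on.
}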
func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
@ -50,9 +105,17 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
}
store.data.Store(&snap)
// Initialize node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))
return store
}
// Snapshot is the representation of the current state of the NodeStore.
// It contains all nodes and their relationships.
// It is a copy-on-write structure, meaning that when a write occurs,
// a new Snapshot is created with the updated state,
// and replaces the old one atomically.
type Snapshot struct {
// nodesByID is the main source of truth for nodes.
nodesByID map[types.NodeID]types.Node
@ -64,15 +127,19 @@ type Snapshot struct {
allNodes []types.NodeView
}
// PeersFunc is a function that takes a list of nodes and returns a map
// with the relationships between nodes and their peers.
// This will typically be used to calculate which nodes can see each other
// based on the current policy.
type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView
// work represents a single operation to be performed on the NodeStore.
type work struct {
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
immediate bool // For operations that need immediate processing
op int
nodeID types.NodeID
node types.Node
updateFn UpdateNodeFunc
result chan struct{}
}
// PutNode adds or updates a node in the store.
@ -80,6 +147,9 @@ type work struct {
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put"))
defer timer.ObserveDuration()
work := work{
op: put,
nodeID: n.ID,
@ -87,8 +157,12 @@ func (s *NodeStore) PutNode(n types.Node) {
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
nodeStoreOperations.WithLabelValues("put").Inc()
}
// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
@ -96,7 +170,21 @@ type UpdateNodeFunc func(n *types.Node)
// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
// This is analogous to a database "transaction": the caller should collect
// all the data they want to change and then call this function once.
// Fewer calls are better.
//
// TODO(kradalby): Technically we could have a version of this that modifies the node
// in the current snapshot if _we know_ that the change will not affect the peer relationships.
// This is because the main nodesByID map contains the struct, and every other map is using a
// pointer to the underlying struct. The gotcha with this is that we will need to introduce
// a lock around the nodesByID map to ensure that no other writes are happening
// while we are modifying the node, which means we would need to implement
// read-write locks on all read operations.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update"))
defer timer.ObserveDuration()
work := work{
op: update,
nodeID: nodeID,
@ -104,48 +192,47 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
}
nodeStoreQueueDepth.Dec()
// UpdateNodeImmediate applies a function to modify a specific node in the store
// with immediate processing (bypassing normal batching delays).
// Use this for time-sensitive updates like online status changes.
func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
work := work{
op: update,
nodeID: nodeID,
updateFn: updateFn,
result: make(chan struct{}),
immediate: true,
}
s.writeQueue <- work
<-work.result
nodeStoreOperations.WithLabelValues("update").Inc()
}
// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete"))
defer timer.ObserveDuration()
work := work{
op: del,
nodeID: id,
result: make(chan struct{}),
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
nodeStoreOperations.WithLabelValues("delete").Inc()
}
// Start initializes the NodeStore and starts processing the write queue.
func (s *NodeStore) Start() {
s.writeQueue = make(chan work)
go s.processWrite()
}
// Stop stops the NodeStore and closes the write queue.
func (s *NodeStore) Stop() {
close(s.writeQueue)
}
// processWrite processes the write queue in batches.
// It collects writes into batches and applies them periodically.
func (s *NodeStore) processWrite() {
c := time.NewTicker(batchTimeout)
batch := make([]work, 0, batchSize)
@ -157,13 +244,7 @@ func (s *NodeStore) processWrite() {
c.Stop()
return
}
// Handle immediate operations right away
if w.immediate {
s.applyBatch([]work{w})
continue
}
batch = append(batch, w)
if len(batch) >= batchSize {
s.applyBatch(batch)
@ -181,7 +262,22 @@ func (s *NodeStore) processWrite() {
}
}
// applyBatch applies a batch of work to the node store.
// This means that it takes a copy of the current nodes,
// then applies the batch of operations to that copy,
// runs any precomputation needed (like calculating peers),
// and finally replaces the snapshot in the store with the new one.
// The replacement of the snapshot is atomic, ensuring that reads
// are never blocked by writes.
// Each write item blocks until the batch is applied, so the caller
// knows the operation is complete and does not send any updates
// that depend on data that has not yet been written.
func (s *NodeStore) applyBatch(batch []work) {
timer := prometheus.NewTimer(nodeStoreBatchDuration)
defer timer.ObserveDuration()
nodeStoreBatchSize.Observe(float64(len(batch)))
nodes := make(map[types.NodeID]types.Node)
maps.Copy(nodes, s.data.Load().nodesByID)
@ -201,15 +297,25 @@ func (s *NodeStore) applyBatch(batch []work) {
}
newSnap := snapshotFromNodes(nodes, s.peersFunc)
s.data.Store(&newSnap)
// Update node count gauge
nodeStoreNodesCount.Set(float64(len(nodes)))
for _, w := range batch {
close(w.result)
}
}
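
A hedged sketch of the batching behaviour described above; concurrentRenames and its node IDs are hypothetical, and writers may be coalesced into a single snapshot swap:

func concurrentRenames(store *NodeStore) {
	var wg sync.WaitGroup
	for i := 1; i <= 3; i++ {
		wg.Add(1)
		go func(id types.NodeID) {
			defer wg.Done()
			// Each call blocks until the batch containing its write is applied;
			// concurrent writers may share one applyBatch / snapshot replacement.
			store.UpdateNode(id, func(n *types.Node) {
				n.Hostname = fmt.Sprintf("node-%d", id)
			})
		}(types.NodeID(i))
	}
	wg.Wait() // all three updates are now visible to readers
}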
// snapshotFromNodes creates a new Snapshot from the provided nodes.
// It builds a number of indexes to make lookups fast for data that is used
// frequently, such as nodesByNodeKey, peersByNode, and nodesByUser.
// This is not a fast operation; it is the "slow" part of our copy-on-write
// structure, but it allows us to have fast reads and efficient lookups.
func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration)
defer timer.ObserveDuration()
allNodes := make([]types.NodeView, 0, len(nodes))
for _, n := range nodes {
allNodes = append(allNodes, n.View())
@ -219,8 +325,17 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
nodesByID: nodes,
allNodes: allNodes,
nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
peersByNode: peersFunc(allNodes),
nodesByUser: make(map[types.UserID][]types.NodeView),
// peersByNode is most likely the most expensive operation;
// it uses the list of all nodes, combined with the current policy,
// to precalculate which nodes are peers and can see each other.
peersByNode: func() map[types.NodeID][]types.NodeView {
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
defer peersTimer.ObserveDuration()
return peersFunc(allNodes)
}(),
nodesByUser: make(map[types.UserID][]types.NodeView),
}
// Build nodesByUser and nodesByNodeKey maps
@ -234,9 +349,21 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
}
// GetNode retrieves a node by its ID.
func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
n := s.data.Load().nodesByID[id]
return n.View()
// The returned bool reports whether the node exists in the store (the equivalent of a "not found" error).
// The NodeView may still be invalid even when the bool is true, so callers must also
// check .Valid() to guard against a broken or partially populated node.
func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("get").Inc()
n, exists := s.data.Load().nodesByID[id]
if !exists {
return types.NodeView{}, false
}
return n.View(), true
}
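
A minimal caller-side sketch of the two checks described above (the existence bool plus NodeView.Valid()); lookupNode is a hypothetical helper:

func lookupNode(store *NodeStore, id types.NodeID) (types.NodeView, error) {
	node, ok := store.GetNode(id)
	if !ok {
		// Node does not exist in the store ("not found").
		return types.NodeView{}, fmt.Errorf("node %d not found", id)
	}
	if !node.Valid() {
		// Node exists but the view is broken; treat it as an error rather than using it.
		return types.NodeView{}, fmt.Errorf("node %d exists but its view is invalid", id)
	}
	return node, nil
}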
// GetNodeByNodeKey retrieves a node by its NodeKey.
@ -306,15 +433,30 @@ func (s *NodeStore) DebugString() string {
// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list").Inc()
return views.SliceOf(s.data.Load().allNodes)
}
// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list_peers").Inc()
return views.SliceOf(s.data.Load().peersByNode[id])
}
// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user"))
defer timer.ObserveDuration()
nodeStoreOperations.WithLabelValues("list_by_user").Inc()
return views.SliceOf(s.data.Load().nodesByUser[uid])
}

View File

@ -24,6 +24,7 @@ func TestSnapshotFromNodes(t *testing.T) {
peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
return make(map[types.NodeID][]types.NodeView)
}
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -61,6 +62,7 @@ func TestSnapshotFromNodes(t *testing.T) {
1: createTestNode(1, 1, "user1", "node1"),
2: createTestNode(2, 1, "user1", "node2"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -85,6 +87,7 @@ func TestSnapshotFromNodes(t *testing.T) {
2: createTestNode(2, 2, "user2", "node2"),
3: createTestNode(3, 1, "user1", "node3"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -113,6 +116,7 @@ func TestSnapshotFromNodes(t *testing.T) {
4: createTestNode(4, 4, "user4", "node4"),
}
peersFunc := oddEvenPeersFunc
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@ -191,6 +195,7 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
}
ret[node.ID()] = peers
}
return ret
}
@ -214,6 +219,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
}
ret[node.ID()] = peers
}
return ret
}
@ -329,6 +335,7 @@ func TestNodeStoreOperations(t *testing.T) {
node2 := createTestNode(2, 1, "user1", "node2")
node3 := createTestNode(3, 2, "user2", "node3")
initialNodes := types.Nodes{&node1, &node2, &node3}
return NewNodeStore(initialNodes, allowAllPeersFunc)
},
steps: []testStep{

File diff suppressed because it is too large Load Diff

View File

@ -45,6 +45,7 @@ func (c Change) AlsoSelf() bool {
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
return true
}
return false
}

View File

@ -104,6 +104,7 @@ type Node struct {
// headscale. It is best effort and not persisted.
LastSeen *time.Time `gorm:"column:last_seen"`
// ApprovedRoutes is a list of routes that the node is allowed to announce
// as a subnet router. They are not necessarily the routes that the node
// announces at the moment.
@ -420,6 +421,11 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix {
}
// SubnetRoutes returns the list of routes that the node announces and are approved.
//
// IMPORTANT: This method is used for internal data structures and should NOT be used
// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually
// with PrimaryRoutes to ensure it includes only routes actively served by the node.
// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto.
func (node *Node) SubnetRoutes() []netip.Prefix {
var routes []netip.Prefix
@ -525,7 +531,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
}
node.Hostname = hostInfo.Hostname
log.Trace().
Str("node_id", node.ID.String()).
Str("new_hostname", node.Hostname).