
cleanup poll, remove silly notify ctx

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Kristoffer Dalby 2025-05-26 16:10:39 +02:00
parent 72f473d0c9
commit 21d49ea5c0
12 changed files with 174 additions and 1075 deletions
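In short: every notifier entry point drops its context.Context parameter, and the NotifyCtx helper (deleted from hscontrol/types at the bottom of this diff) goes with it; its only job was to attach origin labels and a three-second timeout that the batcher-backed notifier no longer consults. A before/after sketch of a typical call site, lifted from the SetPolicy hunk below:

// before: every notification wrapped a fresh context with a 3s timeout
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())

// after: the update is handed straight to the notifier
api.h.nodeNotifier.NotifyAll(types.UpdateFull())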

View File

@ -287,8 +287,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
h.nodeNotifier.NotifyAll(update)
}
case <-derpTickerChan:
@ -299,8 +298,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
h.DERPMap.Regions[region.RegionID] = &region
}
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
h.nodeNotifier.NotifyAll(types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: h.DERPMap,
})
@ -311,10 +309,9 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
h.nodeNotifier.NotifyAll(types.UpdateFull())
}
}
}
@ -517,8 +514,7 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
notif.NotifyAll(types.UpdateFull())
}
return nil
@ -544,8 +540,7 @@ func nodesChangedHook(
}
if filterChanged {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.UpdateFull())
notif.NotifyAll(types.UpdateFull())
return true, nil
}
@ -581,13 +576,9 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
mapp := mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
h.mapBatcher = mapper.NewBatcher(mapp)
h.mapBatcher = mapper.NewBatcherAndMapper(h.db, h.cfg, h.DERPMap, h.polMan, h.primaryRoutes)
h.nodeNotifier = notifier.NewNotifier(h.cfg, h.mapBatcher)
// TODO(kradalby): I dont like this. Right now its done to access online status.
mapp.SetBatcher(h.mapBatcher)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
@ -878,8 +869,7 @@ func (h *Headscale) Serve() error {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
h.nodeNotifier.NotifyAll(types.UpdateFull())
}
default:
info := func(msg string) { log.Info().Msg(msg) }

View File

@ -90,8 +90,7 @@ func (h *Headscale) handleExistingNode(
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
}
expired = true
@ -102,8 +101,7 @@ func (h *Headscale) handleExistingNode(
return nil, fmt.Errorf("setting node expiry: %w", err)
}
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, requestExpiry), node.ID)
}
return nodeToRegisterResponse(node), nil
@ -262,8 +260,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
}
if !updateSent || routesChanged {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
}
return &tailcfg.RegisterResponse{

View File

@ -286,8 +286,7 @@ func (api headscaleV1APIServer) RegisterNode(
}
if !updateSent || routesChanged {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
api.h.nodeNotifier.NotifyAll(types.UpdatePeerChanged(node.ID))
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
@ -336,8 +335,7 @@ func (api headscaleV1APIServer) SetTags(
}, status.Error(codes.InvalidArgument, err.Error())
}
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -382,11 +380,9 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
}
if api.h.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...) {
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.nodeNotifier.NotifyAll(types.UpdateFull())
} else {
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
}
proto := node.Proto()
@ -422,8 +418,7 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err
}
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
api.h.nodeNotifier.NotifyAll(types.UpdatePeerRemoved(node.ID))
return &v1.DeleteNodeResponse{}, nil
}
@ -447,14 +442,11 @@ func (api headscaleV1APIServer) ExpireNode(
return nil, err
}
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(types.UpdateExpire(node.ID, now), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -484,8 +476,7 @@ func (api headscaleV1APIServer) RenameNode(
return nil, err
}
ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
log.Trace().
Str("node", node.Hostname).
@ -581,13 +572,10 @@ func (api headscaleV1APIServer) MoveNode(
return nil, err
}
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
}
@ -770,8 +758,7 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.nodeNotifier.NotifyAll(types.UpdateFull())
}
response := &v1.SetPolicyResponse{

View File

@ -3,6 +3,9 @@ package mapper
import (
"time"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
@ -39,7 +42,7 @@ type nodeConn struct {
type Batcher struct {
mu deadlock.RWMutex
mapper *Mapper
mapper *mapper
// connected is a map of NodeID to the time they closed a connection.
// This is used to track which nodes are currently connected.
@ -58,7 +61,21 @@ type Batcher struct {
workCh chan *ChangeWork
}
func NewBatcher(mapper *Mapper) *Batcher {
func NewBatcherAndMapper(
db *db.HSDatabase,
cfg *types.Config,
derpMap *tailcfg.DERPMap,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
) *Batcher {
mapper := newMapper(db, cfg, derpMap, polMan, primary)
b := NewBatcher(mapper)
mapper.batcher = b
return b
}
func NewBatcher(mapper *mapper) *Batcher {
return &Batcher{
mapper: mapper,
cancelCh: make(chan struct{}),
@ -164,7 +181,7 @@ func (b *Batcher) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
ret := xsync.NewMapOf[types.NodeID, bool]()
for id, _ := range b.connected {
for id := range b.connected {
ret.Store(id, b.isLikelyConnectedLocked(id))
}
@ -250,7 +267,7 @@ func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte,
switch work.Update.Type {
case types.StateFullUpdate:
data, err = b.mapper.FullMapResponse(req, node)
data, err = b.mapper.fullMapResponse(req, node)
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(work.Update.ChangeNodes))
@ -258,21 +275,21 @@ func (b *Batcher) resp(id types.NodeID, nc *nodeConn, work *ChangeWork) ([]byte,
changed[nodeID] = true
}
data, err = b.mapper.PeerChangedResponse(req, node, changed, work.Update.ChangePatches)
data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
case types.StatePeerChangedPatch:
data, err = b.mapper.PeerChangedPatchResponse(req, node, work.Update.ChangePatches)
data, err = b.mapper.peerChangedPatchResponse(req, node, work.Update.ChangePatches)
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(work.Update.Removed))
for _, nodeID := range work.Update.Removed {
changed[nodeID] = false
}
data, err = b.mapper.PeerChangedResponse(req, node, changed, work.Update.ChangePatches)
data, err = b.mapper.peerChangedResponse(req, node, changed, work.Update.ChangePatches)
case types.StateSelfUpdate:
data, err = b.mapper.PeerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
// case types.StateDERPUpdated:
// data, err = b.mapper.DERPMapResponse(req, node, b.mapper.DERPMap)
data, err = b.mapper.peerChangedResponse(req, node, make(map[types.NodeID]bool), work.Update.ChangePatches)
case types.StateDERPUpdated:
data, err = b.mapper.derpMapResponse(req, node, work.Update.DERPMap)
}
return data, err
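The constructor change above replaces the two-step wiring in Serve() (NewMapper, NewBatcher, then mapp.SetBatcher) with a single call that resolves the circular reference internally; per the removed TODO, the back-reference exists so the mapper can read online status. A condensed, compilable sketch of the pattern, with the real constructor arguments (db, cfg, DERP map, policy manager, primary routes) elided:

package mapper

// mapper and Batcher hold references to each other: the batcher drives
// map generation, while the mapper reads online status back from it.
type mapper struct{ batcher *Batcher }

type Batcher struct{ mapper *mapper }

func NewBatcherAndMapper() *Batcher {
	m := &mapper{}           // newMapper(db, cfg, derpMap, polMan, primary)
	b := &Batcher{mapper: m} // NewBatcher(m)
	m.batcher = b            // back-reference, set before anything runs
	return b
}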

View File

@ -13,7 +13,6 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/db"
@ -49,7 +48,7 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct {
type mapper struct {
// Configuration
// TODO(kradalby): figure out if this is the format we want this in
db *db.HSDatabase
@ -59,9 +58,7 @@ type Mapper struct {
primary *routes.PrimaryRoutes
batcher *Batcher
uid string
created time.Time
seq uint64
}
type patch struct {
@ -69,36 +66,25 @@ type patch struct {
change *tailcfg.PeerChange
}
func NewMapper(
func newMapper(
db *db.HSDatabase,
cfg *types.Config,
derpMap *tailcfg.DERPMap,
polMan policy.PolicyManager,
primary *routes.PrimaryRoutes,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
) *mapper {
return &Mapper{
return &mapper{
db: db,
cfg: cfg,
derpMap: derpMap,
polMan: polMan,
primary: primary,
uid: uid,
created: time.Now(),
seq: 0,
}
}
func (m *Mapper) SetBatcher(batcher *Batcher) {
m.batcher = batcher
}
func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
}
func generateUserProfiles(
node *types.Node,
peers types.Nodes,
@ -163,14 +149,18 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
}
}
// fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
peers types.Nodes,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)
messages ...string,
) ([]byte, error) {
peers, err := m.listPeers(node.ID)
if err != nil {
return nil, err
}
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
@ -181,7 +171,7 @@ func (m *Mapper) fullMapResponse(
m.polMan,
m.primary,
node,
capVer,
mapRequest.Version,
peers,
m.cfg,
)
@ -189,55 +179,10 @@ func (m *Mapper) fullMapResponse(
return nil, err
}
return resp, nil
return marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// FullMapResponse returns a MapResponse for the given node.
func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
messages ...string,
) ([]byte, error) {
peers, err := m.ListPeers(node.ID)
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// ReadOnlyMapResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.KeepAlive = true
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) DERPMapResponse(
func (m *mapper) derpMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
derpMap *tailcfg.DERPMap,
@ -247,10 +192,10 @@ func (m *Mapper) DERPMapResponse(
resp := m.baseMapResponse()
resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) PeerChangedResponse(
func (m *mapper) peerChangedResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed map[types.NodeID]bool,
@ -273,7 +218,7 @@ func (m *Mapper) PeerChangedResponse(
}
changedNodes := types.Nodes{}
if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
changedNodes, err = m.listNodes(changedIDs...)
if err != nil {
return nil, err
}
@ -324,12 +269,12 @@ func (m *Mapper) PeerChangedResponse(
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
}
// PeerChangedPatchResponse creates a patch MapResponse with
// peerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
func (m *mapper) peerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed []*tailcfg.PeerChange,
@ -337,18 +282,16 @@ func (m *Mapper) PeerChangedPatchResponse(
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
return marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) marshalMapResponse(
func marshalMapResponse(
mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
node *types.Node,
compression string,
messages ...string,
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
jsonBody, err := json.Marshal(resp)
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
@ -392,7 +335,7 @@ func (m *Mapper) marshalMapResponse(
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
fmt.Sprintf("%s-%s.json", now, responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -443,7 +386,7 @@ var zstdEncoderPool = &sync.Pool{
// baseMapResponse returns a tailcfg.MapResponse with
// KeepAlive false and ControlTime set to now.
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
func (m *mapper) baseMapResponse() tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
@ -459,7 +402,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
// with the basic configuration from headscale set.
// It is used for bigger updates, such as full and lite, not
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
func (m *mapper) baseWithConfigMapResponse(
node *types.Node,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
@ -494,10 +437,10 @@ func (m *Mapper) baseWithConfigMapResponse(
return &resp, nil
}
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
@ -513,9 +456,9 @@ func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.
return peers, nil
}
// ListNodes queries the database for either all nodes if no parameters are given
// listNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
func (m *mapper) listNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.db.ListNodes(nodeIDs...)
if err != nil {
return nil, err

View File

@ -1,450 +0,0 @@
package mapper
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/key"
)
var iap = func(ipStr string) *netip.Addr {
ip := netip.MustParseAddr(ipStr)
return &ip
}
func TestDNSConfigMapResponse(t *testing.T) {
tests := []struct {
magicDNS bool
want *tailcfg.DNSConfig
}{
{
magicDNS: true,
want: &tailcfg.DNSConfig{
Routes: map[string][]*dnstype.Resolver{},
Domains: []string{
"foobar.headscale.net",
},
Proxied: true,
},
},
{
magicDNS: false,
want: &tailcfg.DNSConfig{
Domains: []string{"foobar.headscale.net"},
Proxied: false,
},
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: userid,
User: types.User{
Name: username,
},
}
}
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: tt.magicDNS,
}
nodeInShared1 := mach("test_get_shared_nodes_1", "shared1", 1)
got := generateDNSConfig(
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("expandAlias() unexpected result (-want +got):\n%s", diff)
}
})
}
}
func Test_fullMapResponse(t *testing.T) {
mustNK := func(str string) key.NodePublic {
var k key.NodePublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustDK := func(str string) key.DiscoPublic {
var k key.DiscoPublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustMK := func(str string) key.MachinePublic {
var k key.MachinePublic
_ = k.UnmarshalText([]byte(str))
return k
}
hiview := func(hoin tailcfg.Hostinfo) tailcfg.HostinfoView {
return hoin.View()
}
created := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2"}
mini := &types.Node{
ID: 1,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
},
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
}
tailMini := &tailcfg.Node{
ID: 1,
StableID: "1",
Name: "mini",
User: tailcfg.UserID(user1.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
KeyExpiry: expire,
Machine: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
}),
Created: created,
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
}
peer1 := &types.Node{
ID: 2,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.2"),
Hostname: "peer1",
GivenName: "peer1",
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{},
CreatedAt: created,
}
tailPeer1 := &tailcfg.Node{
ID: 2,
StableID: "2",
Name: "peer1",
User: tailcfg.UserID(user2.ID),
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
KeyExpiry: expire,
Machine: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Created: created,
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
}
tests := []struct {
name string
pol []byte
node *types.Node
peers types.Nodes
derpMap *tailcfg.DERPMap
cfg *types.Config
want *tailcfg.MapResponse
wantErr bool
}{
// {
// name: "empty-node",
// node: types.Node{},
// pol: &policyv2.Policy{},
// dnsConfig: &tailcfg.DNSConfig{},
// baseDomain: "",
// want: nil,
// wantErr: true,
// },
{
name: "no-pol-no-peers-map-response",
node: mini,
peers: types.Nodes{},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
Node: tailMini,
KeepAlive: false,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
UserProfiles: []tailcfg.UserProfile{
{
ID: tailcfg.UserID(user1.ID),
LoginName: "user1",
DisplayName: "user1",
},
},
ControlTime: &time.Time{},
PacketFilters: map[string][]tailcfg.FilterRule{"base": tailcfg.FilterAllowAll},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
{
name: "no-pol-with-peer-map-response",
node: mini,
peers: types.Nodes{
peer1,
},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{
tailPeer1,
},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
ControlTime: &time.Time{},
PacketFilters: map[string][]tailcfg.FilterRule{"base": tailcfg.FilterAllowAll},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
{
name: "with-pol-map-response",
pol: []byte(`
{
"acls": [
{
"action": "accept",
"src": ["100.64.0.2"],
"dst": ["user1@:*"],
},
{
"action": "accept",
"src": ["100.64.0.1"],
"dst": ["192.168.0.0/24:*"],
},
],
}
`),
node: mini,
peers: types.Nodes{
peer1,
},
derpMap: &tailcfg.DERPMap{},
cfg: &types.Config{
BaseDomain: "",
TailcfgDNSConfig: &tailcfg.DNSConfig{},
LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{
KeepAlive: false,
Node: tailMini,
DERPMap: &tailcfg.DERPMap{},
Peers: []*tailcfg.Node{
tailPeer1,
},
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
PacketFilters: map[string][]tailcfg.FilterRule{
"base": {
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{
{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny},
},
},
{
SrcIPs: []string{"100.64.0.1/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "192.168.0.0/24", Ports: tailcfg.PortRangeAny}},
},
},
},
SSHPolicy: nil,
UserProfiles: []tailcfg.UserProfile{
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
},
ControlTime: &time.Time{},
Debug: &tailcfg.Debug{
DisableLogTail: true,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
require.NoError(t, err)
primary := routes.New()
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
for _, peer := range tt.peers {
primary.SetRoutes(peer.ID, peer.SubnetRoutes()...)
}
mappy := NewMapper(
nil,
tt.cfg,
tt.derpMap,
polMan,
primary,
)
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
0,
)
if (err != nil) != tt.wantErr {
t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(
tt.want,
got,
cmpopts.EquateEmpty(),
// Ignore ControlTime, it is set to now and we dont really need to mock it.
cmpopts.IgnoreFields(tailcfg.MapResponse{}, "ControlTime"),
); diff != "" {
t.Errorf("fullMapResponse() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View File

@ -1,311 +0,0 @@
package mapper
import (
"encoding/json"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func TestTailNode(t *testing.T) {
mustNK := func(str string) key.NodePublic {
var k key.NodePublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustDK := func(str string) key.DiscoPublic {
var k key.DiscoPublic
_ = k.UnmarshalText([]byte(str))
return k
}
mustMK := func(str string) key.MachinePublic {
var k key.MachinePublic
_ = k.UnmarshalText([]byte(str))
return k
}
hiview := func(hoin tailcfg.Hostinfo) tailcfg.HostinfoView {
return hoin.View()
}
created := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
tests := []struct {
name string
node *types.Node
pol []byte
dnsConfig *tailcfg.DNSConfig
baseDomain string
want *tailcfg.Node
wantErr bool
}{
{
name: "empty-node",
node: &types.Node{
GivenName: "empty",
Hostinfo: &tailcfg.Hostinfo{},
},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: &tailcfg.Node{
Name: "empty",
StableID: "0",
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
},
wantErr: false,
},
{
name: "minimal-node",
node: &types.Node{
ID: 0,
MachineKey: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
NodeKey: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: 0,
User: types.User{
Name: "mini",
},
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
},
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
CreatedAt: created,
},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
want: &tailcfg.Node{
ID: 0,
StableID: "0",
Name: "mini",
User: 0,
Key: mustNK(
"nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
),
KeyExpiry: expire,
Machine: mustMK(
"mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507",
),
DiscoKey: mustDK(
"discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),
},
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("172.0.0.0/10"),
},
}),
Created: created,
Tags: []string{},
LastSeen: &lastSeen,
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
},
wantErr: false,
},
{
name: "check-dot-suffix-on-node-name",
node: &types.Node{
GivenName: "minimal",
Hostinfo: &tailcfg.Hostinfo{},
},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "example.com",
want: &tailcfg.Node{
// a node name should have a dot appended
Name: "minimal.example.com.",
StableID: "0",
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
},
},
wantErr: false,
},
// TODO: Add tests to check other aspects of the node conversion:
// - With tags and policy
// - dnsconfig and basedomain
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
require.NoError(t, err)
primary := routes.New()
cfg := &types.Config{
BaseDomain: tt.baseDomain,
TailcfgDNSConfig: tt.dnsConfig,
RandomizeClientPort: false,
}
_ = primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
// This is a hack to avoid having a second node to test the primary route.
// This should be baked into the test case proper if it is extended in the future.
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
got, err := tailNode(
tt.node,
0,
polMan,
func(id types.NodeID) []netip.Prefix {
return primary.PrimaryRoutes(id)
},
cfg,
)
if (err != nil) != tt.wantErr {
t.Errorf("tailNode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("tailNode() unexpected result (-want +got):\n%s", diff)
}
})
}
}
func TestNodeExpiry(t *testing.T) {
tp := func(t time.Time) *time.Time {
return &t
}
tests := []struct {
name string
exp *time.Time
wantTime time.Time
wantTimeZero bool
}{
{
name: "no-expiry",
exp: nil,
wantTimeZero: true,
},
{
name: "zero-expiry",
exp: &time.Time{},
wantTimeZero: true,
},
{
name: "localtime",
exp: tp(time.Time{}.Local()),
wantTimeZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
node := &types.Node{
ID: 0,
GivenName: "test",
Expiry: tt.exp,
}
polMan, err := policy.NewPolicyManager(nil, nil, nil)
require.NoError(t, err)
tn, err := tailNode(
node,
0,
polMan,
func(id types.NodeID) []netip.Prefix {
return []netip.Prefix{}
},
&types.Config{},
)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)
}
// Round trip the node through JSON to ensure the time is serialized correctly
seri, err := json.Marshal(tn)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)
}
var deseri tailcfg.Node
err = json.Unmarshal(seri, &deseri)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)
}
if tt.wantTimeZero {
if !deseri.KeyExpiry.IsZero() {
t.Errorf("nodeExpiry() = %v, want zero", deseri.KeyExpiry)
}
} else if deseri.KeyExpiry != tt.wantTime {
t.Errorf("nodeExpiry() = %v, want %v", deseri.KeyExpiry, tt.wantTime)
}
})
}
}

View File

@ -1,12 +1,12 @@
package notifier
import (
"context"
"sort"
"strings"
"sync"
"time"
"slices"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
@ -97,12 +97,11 @@ func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.mbatcher.LikelyConnectedMap()
}
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update)
}
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignoreNodeIDs ...types.NodeID,
) {
@ -110,12 +109,10 @@ func (n *Notifier) NotifyWithIgnore(
return
}
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
n.b.addOrPassthrough(update)
}
func (n *Notifier) NotifyByNodeID(
ctx context.Context,
update types.StateUpdate,
nodeID types.NodeID,
) {
@ -140,9 +137,7 @@ func (n *Notifier) String() string {
var b strings.Builder
var keys []types.NodeID
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
slices.Sort(keys)
return b.String()
}
@ -227,9 +222,7 @@ func (b *batcher) flush() {
}
changedNodes := b.changedNodeIDs.Slice().AsSlice()
sort.Slice(changedNodes, func(i, j int) bool {
return changedNodes[i] < changedNodes[j]
})
slices.Sort(changedNodes)
if b.changedNodeIDs.Slice().Len() > 0 {
update := types.UpdatePeerChanged(changedNodes...)

View File

@ -545,15 +545,12 @@ func (a *AuthProviderOIDC) handleRegistration(
}
if !updateSent || routesChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
a.notifier.NotifyWithIgnore(types.UpdatePeerChanged(node.ID), node.ID)
}
return newNode, nil

View File

@ -2,12 +2,12 @@ package hscontrol
import (
"context"
"encoding/json"
"math/rand/v2"
"net/http"
"net/netip"
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
@ -15,6 +15,8 @@ import (
xslices "golang.org/x/exp/slices"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
"tailscale.com/util/zstdframe"
)
const (
@ -30,7 +32,6 @@ type mapSession struct {
req tailcfg.MapRequest
ctx context.Context
capVer tailcfg.CapabilityVersion
mapper *mapper.Mapper
cancelChMu deadlock.Mutex
@ -58,17 +59,6 @@ func (h *Headscale) newMapSession(
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
// TODO(kradalby): This needs to happen in the batcher now, give a full
// map on start.
// var updateChan chan []byte
// if req.Stream {
// // Use a buffered channel in case a node is not fully ready
// // to receive a message to make sure we dont block the entire
// // notifier.
// updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
// updateChan <- types.UpdateFull()
// }
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
return &mapSession{
@ -115,15 +105,11 @@ func (m *mapSession) close() {
}
func (m *mapSession) isStreaming() bool {
return m.req.Stream && !m.req.ReadOnly
return m.req.Stream
}
func (m *mapSession) isEndpointUpdate() bool {
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
}
func (m *mapSession) isReadOnlyUpdate() bool {
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
return !m.req.Stream && m.req.OmitPeers
}
func (m *mapSession) resetKeepAlive() {
@ -144,13 +130,10 @@ func (m *mapSession) afterServeLongPoll() {
// serve handles non-streaming requests.
func (m *mapSession) serve() {
// TODO(kradalby): A set todos to harden:
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
// This is the mechanism where the node gives us information about its
// current configuration.
//
// If OmitPeers is true, Stream is false, and ReadOnly is false,
// If OmitPeers is true and Stream is false
// then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections.
// In this case, the server can omit the entire response; the client
@ -158,27 +141,12 @@ func (m *mapSession) serve() {
//
// This is what Tailscale calls a Lite update, the client ignores
// the response and just wants a 200.
// !req.stream && !req.ReadOnly && req.OmitPeers
//
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
// !req.stream && req.OmitPeers
if m.isEndpointUpdate() {
m.handleEndpointUpdate()
return
}
// ReadOnly is whether the client just wants to fetch the
// MapResponse, without updating their Endpoints. The
// Endpoints field will be ignored and LastSeen will not be
// updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at
// start-up before their first real endpoint update.
if m.isReadOnlyUpdate() {
m.handleReadOnlyRequest()
return
}
}
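With ReadOnly support removed, serve() distinguishes only two request shapes. A hypothetical classifier (not in the diff) summarizing the new logic of isStreaming and isEndpointUpdate:

package hscontrol

import "tailscale.com/tailcfg"

// classifyMapRequest is illustrative only; it mirrors the predicates above.
func classifyMapRequest(req tailcfg.MapRequest) string {
	switch {
	case req.Stream:
		return "long-poll" // handled by serveLongPoll
	case req.OmitPeers:
		return "endpoint-update" // lite update; client ignores the body
	default:
		return "unhandled" // formerly the read-only DERP-discovery path
	}
}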
// serveLongPoll ensures the node gets the appropriate updates from either
@ -200,6 +168,9 @@ func (m *mapSession) serveLongPoll() {
// reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online.
if m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) {
// TODO(kradalby): All of this handling should be moved out of here
// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
// Failover the node's routes if any.
m.h.updateNodeOnlineStatus(false, m.node)
@ -208,8 +179,7 @@ func (m *mapSession) serveLongPoll() {
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes have access to the same routes, so it might not be a big deal.
if m.h.primaryRoutes.SetRoutes(m.node.ID) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
m.h.nodeNotifier.NotifyAll(types.UpdateFull())
}
}
@ -221,11 +191,13 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
// TODO(kradalby): All of this handling should be moved out of here
// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
m.h.nodeNotifier.NotifyAll(types.UpdateFull())
}
// TODO(kradalby): I think this didn't really work and can be reverted back to a normal write.
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
@ -239,6 +211,9 @@ func (m *mapSession) serveLongPoll() {
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.req.Compress, m.req.Version)
// TODO(kradalby): All of this handling should be moved out of here
// to the mapBatcher(?), where there is more state (with the goal of removing it from here).
go m.h.updateNodeOnlineStatus(true, m.node)
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -266,54 +241,63 @@ func (m *mapSession) serveLongPoll() {
return
}
startWrite := time.Now()
_, err := m.w.Write(update)
if err != nil {
m.errf(err, "could not write the map response, for mapSession: %p", m)
if err := m.write(rc, update); err != nil {
m.errf(err, "cannot write update to client")
return
}
err = rc.Flush()
if err != nil {
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
m.tracef("update sent")
m.resetKeepAlive()
// TODO(kradalby): This needs to be rethought now that we do not have a mapper,
// maybe keepalive can be a static function? I do not think it needs state?
// case <-m.keepAliveTicker.C:
// data, err := m.mapper.KeepAliveResponse(m.req, m.node)
// if err != nil {
// m.errf(err, "Error generating the keep alive msg")
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
// _, err = m.w.Write(data)
// if err != nil {
// m.errf(err, "Cannot write keep alive message")
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
// err = rc.Flush()
// if err != nil {
// m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
// mapResponseSent.WithLabelValues("error", "keepalive").Inc()
// return
// }
case <-m.keepAliveTicker.C:
var err error
switch m.req.Compress {
case "zstd":
err = m.write(rc, keepAliveZstd)
default:
err = m.write(rc, keepAlivePlain)
}
// if debugHighCardinalityMetrics {
// mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
// }
// mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
if err != nil {
m.errf(err, "cannot write keep alive")
return
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
}
}
func (m *mapSession) write(rc *http.ResponseController, data []byte) error {
startWrite := time.Now()
_, err := m.w.Write(data)
if err != nil {
return err
}
err = rc.Flush()
if err != nil {
return err
}
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
return nil
}
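The new write helper folds together the raw write, the ResponseController flush, and the trace log that were previously inlined in the update branch, so the map-response and keep-alive paths above share one code path and one error surface.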
var keepAlivePlain = must.Get(json.Marshal(tailcfg.MapResponse{
KeepAlive: true,
}))
var keepAliveZstd = (func() []byte {
msg := must.Get(json.Marshal(tailcfg.MapResponse{
KeepAlive: true,
}))
return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression)
})()
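Both keep-alive payloads are computed once at package init: keepAlivePlain is the marshalled {KeepAlive: true} response, and keepAliveZstd is the same bytes wrapped in a zstd frame via zstdframe.AppendEncode. The ticker branch in serveLongPoll then just picks one by m.req.Compress and writes it, which is what lets keep-alives work without any per-session mapper state (answering the TODO above).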
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about change in their online/offline status.
// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged.
@ -335,8 +319,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
h.db.SetLastSeen(node.ID, *node.LastSeen)
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname)
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerPatch(change), node.ID)
h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerPatch(change), node.ID)
}
func (m *mapSession) handleEndpointUpdate() {
@ -393,19 +376,15 @@ func (m *mapSession) handleEndpointUpdate() {
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
m.h.nodeNotifier.NotifyAll(types.UpdateFull())
} else {
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(m.node.ID), m.node.ID)
m.h.nodeNotifier.NotifyWithIgnore(types.UpdatePeerChanged(m.node.ID), m.node.ID)
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(m.node.ID),
m.node.ID)
}
@ -425,9 +404,7 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", m.node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.UpdatePeerChanged(m.node.ID),
m.node.ID,
)
@ -436,30 +413,6 @@ func (m *mapSession) handleEndpointUpdate() {
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
}
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
m.w.WriteHeader(http.StatusOK)
_, err = m.w.Write(mapResp)
if err != nil {
m.errf(err, "Failed to write response")
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.WriteHeader(http.StatusOK)
mapResponseReadOnly.WithLabelValues("ok").Inc()
}
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
@ -512,7 +465,6 @@ func logPollFunc(
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -522,7 +474,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -532,7 +483,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -542,7 +492,6 @@ func logPollFunc(
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).

hscontrol/types/change.go (new file, +21 lines)
View File

@ -0,0 +1,21 @@
package types
type Change struct {
NodeChange NodeChange
UserChange UserChange
}
type NodeChangeWhat string
const (
NodeChangeCameOnline NodeChangeWhat = "node-online"
)
type NodeChange struct {
ID NodeID
What NodeChangeWhat
}
type UserChange struct {
ID UserID
}
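Nothing in this commit constructs a Change yet; the new file only stakes out the shape. A hypothetical construction, for illustration:

package example

import "github.com/juanfont/headscale/hscontrol/types"

// nodeOnlineChange is hypothetical; it shows how the new types compose.
func nodeOnlineChange(id types.NodeID) types.Change {
	return types.Change{
		NodeChange: types.NodeChange{
			ID:   id,
			What: types.NodeChangeCameOnline,
		},
	}
}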

View File

@ -1,14 +1,12 @@
package types
import (
"context"
"errors"
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/util/ctxkey"
)
const (
@ -148,18 +146,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
var (
NotifyOriginKey = ctxkey.New("notify.origin", "")
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
)
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
return ctx2
}
const RegistrationIDLength = 24
type RegistrationID string
@ -196,23 +182,3 @@ type RegisterNode struct {
Node Node
Registered chan *Node
}
type Change struct {
NodeChange NodeChange
UserChange UserChange
}
type NodeChangeWhat string
const (
NodeChangeCameOnline NodeChangeWhat = "node-online"
)
type NodeChange struct {
ID NodeID
What NodeChangeWhat
}
type UserChange struct {
ID UserID
}