Mirror of https://github.com/juanfont/headscale.git, synced 2025-08-24 13:46:53 +02:00

state/nodestore: in memory representation of nodes

Initial work on a NodeStore, which keeps all of the nodes and their relations in memory, with peer relationships precalculated. It is a copy-on-write structure, replacing the "snapshot" when a change to the structure occurs. It is optimised for reads, and while batches are not fast, they are grouped together to do less of the expensive peer calculation when many changes arrive in quick succession. Writes will block until committed, while reads are never blocked.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

This commit is contained in:
parent 1f7ed7e4e5
commit 0c539d7993
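The mechanism described above can be illustrated with a minimal, self-contained Go sketch (the names and types below are stand-ins, not the actual headscale code, which appears later in this diff): readers dereference an immutable snapshot through an atomic pointer and never block, while the single writer copies the snapshot, applies its change, and swaps the pointer. The real NodeStore additionally funnels writes through a queue and applies them in batches, so the expensive peer-map recalculation runs once per batch rather than once per write.

// Minimal copy-on-write store (illustrative only). Readers are lock-free;
// writes assume a single writer, which the real NodeStore guarantees by
// serialising all writes through a queue.
package main

import (
    "fmt"
    "sync/atomic"
)

type snapshot struct {
    nodes map[int]string // nodeID -> hostname, a stand-in for the real data
}

type cowStore struct {
    data atomic.Pointer[snapshot]
}

func newCowStore() *cowStore {
    s := &cowStore{}
    s.data.Store(&snapshot{nodes: map[int]string{}})
    return s
}

// get never blocks: it only loads the current snapshot pointer.
func (s *cowStore) get(id int) string {
    return s.data.Load().nodes[id]
}

// put copies the current snapshot, applies the change, and publishes the
// new snapshot atomically; in-flight readers keep using the old one.
func (s *cowStore) put(id int, hostname string) {
    old := s.data.Load()
    next := &snapshot{nodes: make(map[int]string, len(old.nodes)+1)}
    for k, v := range old.nodes {
        next.nodes[k] = v
    }
    next.nodes[id] = hostname
    s.data.Store(next)
}

func main() {
    s := newCowStore()
    s.put(1, "node1")
    fmt.Println(s.get(1)) // node1
}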
@ -548,8 +548,8 @@ func (h *Headscale) Serve() error {
 	if err != nil {
 		return fmt.Errorf("failed to list ephemeral nodes: %w", err)
 	}
-	for _, node := range ephmNodes {
-		h.ephemeralGC.Schedule(node.ID, h.cfg.EphemeralNodeInactivityTimeout)
+	for _, node := range ephmNodes.All() {
+		h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
 	}

 	if h.cfg.DNSConfig.ExtraRecordsPath != "" {
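This hunk is one instance of a change that recurs throughout the diff: callers now receive a read-only views.Slice[types.NodeView] instead of a types.Nodes slice, and read fields through accessor methods (node.ID()) rather than struct fields. A small sketch of the view-slice API, using a stand-in element type (Len/At are from tailscale.com/types/views; the hunk above also ranges over the view with All()):

package main

import (
    "fmt"

    "tailscale.com/types/views"
)

type item struct{ id int }

func main() {
    backing := []item{{id: 1}, {id: 2}, {id: 3}}

    // A read-only view over the backing slice: consumers can iterate and
    // index, but cannot mutate the data they were handed.
    v := views.SliceOf(backing)

    for i := 0; i < v.Len(); i++ {
        fmt.Println(v.At(i).id)
    }
}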
@ -34,7 +34,7 @@ func (h *Headscale) handleRegister(
|
||||
return nil, fmt.Errorf("looking up node in database: %w", err)
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
if node.Valid() {
|
||||
// If an existing node is trying to register with an auth key,
|
||||
// we need to validate the auth key even for existing nodes
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
@ -50,7 +50,7 @@ func (h *Headscale) handleRegister(
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||
resp, err := h.handleExistingNode(node.AsStruct(), regReq, machineKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||
}
|
||||
@ -107,7 +107,7 @@ func (h *Headscale) handleExistingNode(
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
c, err := h.state.DeleteNode(node)
|
||||
c, err := h.state.DeleteNode(node.View())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
@ -124,9 +124,9 @@ func (h *Headscale) handleExistingNode(
|
||||
}
|
||||
|
||||
h.Change(c)
|
||||
}
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
|
||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
@ -177,7 +177,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
)
|
||||
@ -194,7 +194,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
}
|
||||
|
||||
// If node is nil, it means an ephemeral node was deleted during logout
|
||||
if node == nil {
|
||||
if node.Valid() {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
}
|
||||
@ -218,21 +218,23 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
}
|
||||
|
||||
if routeChange && changed.Empty() {
|
||||
changed = change.NodeAdded(node.ID)
|
||||
changed = change.NodeAdded(node.ID())
|
||||
}
|
||||
h.Change(changed)
|
||||
|
||||
// If policy changed due to node registration, send a separate policy change
|
||||
if policyChanged {
|
||||
policyChange := change.PolicyChange()
|
||||
h.Change(policyChange)
|
||||
}
|
||||
// TODO(kradalby): I think this is covered above, but we need to validate that.
|
||||
// // If policy changed due to node registration, send a separate policy change
|
||||
// if policyChanged {
|
||||
// policyChange := change.PolicyChange()
|
||||
// h.Change(policyChange)
|
||||
// }
|
||||
|
||||
user := node.User()
|
||||
return &tailcfg.RegisterResponse{
|
||||
MachineAuthorized: true,
|
||||
NodeKeyExpired: node.IsExpired(),
|
||||
User: *node.User.TailscaleUser(),
|
||||
Login: *node.User.TailscaleLogin(),
|
||||
User: *user.TailscaleUser(),
|
||||
Login: *user.TailscaleLogin(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -73,14 +73,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
}
|
||||
|
||||
sshPol := make(map[string]*tailcfg.SSHPolicy)
|
||||
for _, node := range nodes {
|
||||
pol, err := h.state.SSHPolicy(node.View())
|
||||
for _, node := range nodes.All() {
|
||||
pol, err := h.state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID, node.Hostname, node.GivenName)] = pol
|
||||
sshPol[fmt.Sprintf("id:%d hostname:%s givenname:%s", node.ID(), node.Hostname(), node.GivenName())] = pol
|
||||
}
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPol, "", " ")
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
@ -59,9 +60,10 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
}
|
||||
|
||||
|
||||
c := change.UserAdded(types.UserID(user.ID))
|
||||
if policyChanged {
|
||||
|
||||
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
|
||||
if !policyChanged.Empty() {
|
||||
c.Change = change.Policy
|
||||
}
|
||||
|
||||
@ -79,15 +81,13 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||
_, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
api.h.Change(c)
|
||||
|
||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||
if err != nil {
|
||||
@ -297,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID())
|
||||
|
||||
return &v1.GetNodeResponse{Node: resp}, nil
|
||||
}
|
||||
@ -323,7 +323,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Str("node", node.Hostname()).
|
||||
Strs("tags", request.GetTags()).
|
||||
Msg("Changing tags of node")
|
||||
|
||||
@ -357,7 +357,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID(), node.SubnetRoutes()...)
|
||||
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
|
||||
proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID()))
|
||||
|
||||
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
||||
}
|
||||
@ -420,8 +420,8 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Time("expiry", *node.Expiry).
|
||||
Str("node", node.Hostname()).
|
||||
Time("expiry", *node.AsStruct().Expiry).
|
||||
Msg("node expired")
|
||||
|
||||
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
|
||||
@ -440,7 +440,7 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
Str("node", node.Hostname()).
|
||||
Str("new_name", request.GetNewName()).
|
||||
Msg("node renamed")
|
||||
|
||||
@ -477,36 +477,36 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sort.Slice(nodes, func(i, j int) bool {
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
func nodesToProto(state *state.State, isLikelyConnected *xsync.Map[types.NodeID, bool], nodes views.Slice[types.NodeView]) []*v1.Node {
|
||||
response := make([]*v1.Node, nodes.Len())
|
||||
for index, node := range nodes.All() {
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||
if val, ok := isLikelyConnected.Load(node.ID()); ok && val {
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
if state.NodeCanHaveTag(node.View(), tag) {
|
||||
if state.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
|
||||
response[index] = resp
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
})
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@ -683,8 +683,8 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, fmt.Errorf("setting policy: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
_, err = api.h.state.SSHPolicy(nodes[0].View())
|
||||
if nodes.Len() > 0 {
|
||||
_, err = api.h.state.SSHPolicy(nodes.At(0))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||
}
|
||||
|
@ -99,8 +99,17 @@ func (h *Headscale) handleVerifyRequest(
|
||||
return fmt.Errorf("cannot list nodes: %w", err)
|
||||
}
|
||||
|
||||
// Check if any node has the requested NodeKey
|
||||
var nodeKeyFound bool
|
||||
for _, node := range nodes.All() {
|
||||
if node.NodeKey() == derpAdmitClientRequest.NodePublic {
|
||||
nodeKeyFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp := &tailcfg.DERPAdmitClientResponse{
|
||||
Allow: nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic),
|
||||
Allow: nodeKeyFound,
|
||||
}
|
||||
|
||||
return json.NewEncoder(writer).Encode(resp)
|
||||
|
@ -1940,7 +1940,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("Error listing peers for node %d: %v", i, err)
|
||||
} else {
|
||||
t.Logf("Node %d should see %d peers from state", i, len(peers))
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
}
|
||||
|
||||
@ -2106,9 +2106,9 @@ func TestBatcherWorkQueueTracing(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("Error getting peers from state: %v", err)
|
||||
} else {
|
||||
t.Logf("State shows %d peers available for this node", len(peers))
|
||||
if len(peers) > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", len(peers))
|
||||
t.Logf("State shows %d peers available for this node", peers.Len())
|
||||
if peers.Len() > 0 && len(data.Peers) == 0 {
|
||||
t.Errorf("CRITICAL: State has %d peers but response has 0 peers!", peers.Len())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -63,9 +63,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
node, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
@ -112,7 +112,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
@ -135,7 +135,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
@ -161,14 +161,14 @@ func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
@ -181,7 +181,7 @@ func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
@ -193,8 +193,8 @@ func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuil
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||
// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -206,15 +206,15 @@ func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node,
|
||||
// access each-other at all and remove them from the peers.
|
||||
var changedViews views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||
changedViews = policy.ReduceNodes(node, peers, matchers)
|
||||
} else {
|
||||
changedViews = peers.ViewSlice()
|
||||
changedViews = peers
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changedViews, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -68,16 +69,18 @@ func newMapper(
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@ -94,7 +97,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@ -114,12 +117,12 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
}
|
||||
|
||||
if len(node.IPs()) > 0 {
|
||||
@ -137,7 +140,7 @@ func (m *mapper) fullMapResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID)
|
||||
peers, err := m.state.ListPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -182,7 +185,7 @@ func (m *mapper) peerChangeResponse(
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
peers, err := m.state.ListPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -256,25 +259,6 @@ func writeDebugMapResponse(
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notification.
|
||||
for _, peer := range peers {
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
|
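The removed listPeers wrapper was where the mapper annotated each peer with its online status via the batcher; per the TODO above, that was dropped to avoid a circular dependency between the mapper and the notification layer. One possible way to reinstate the behaviour outside the mapper is a small decorator owned by whichever component already tracks connectivity (hypothetical sketch, not part of this commit; it assumes tailcfg.Node's Online field and an injected connectivity callback):

package mapper_sketch

import (
    "github.com/juanfont/headscale/hscontrol/types"
    "tailscale.com/tailcfg"
)

// markOnline annotates already-built peer entries with online status.
// isConnected stands in for whatever connectivity source the caller has
// (e.g. the batcher's IsConnected); this helper does not exist in the commit.
func markOnline(peers []*tailcfg.Node, isConnected func(types.NodeID) bool) {
    for _, p := range peers {
        online := isConnected(types.NodeID(p.ID))
        p.Online = &online
    }
}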
@ -71,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1,
|
||||
nodeInShared1.View(),
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
|
@ -296,7 +296,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||
// and validates that it matches the MachineKey from the Noise session.
|
||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
nv, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
@ -304,8 +304,6 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
|
||||
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||
}
|
||||
|
||||
nv := node.View()
|
||||
|
||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||
if ns.machineKey != nv.MachineKey() {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||
|
@ -298,7 +298,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
user, policyChanged, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
user, c, err := a.createOrUpdateUserFromClaim(&claims)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
@ -318,9 +318,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
a.h.Change(change.PolicyChange())
|
||||
}
|
||||
a.h.Change(c)
|
||||
|
||||
// TODO(kradalby): Is this comment right?
|
||||
// If the node exists, then the node should be reauthenticated,
|
||||
@ -483,14 +481,14 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
|
||||
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
claims *types.OIDCClaims,
|
||||
) (*types.User, bool, error) {
|
||||
) (*types.User, change.ChangeSet, error) {
|
||||
var user *types.User
|
||||
var err error
|
||||
var newUser bool
|
||||
var policyChanged bool
|
||||
var c change.ChangeSet
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, false, fmt.Errorf("creating or updating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
// if the user is still not found, create a new empty user.
|
||||
@ -504,21 +502,21 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
user.FromClaim(claims)
|
||||
|
||||
if newUser {
|
||||
user, policyChanged, err = a.h.state.CreateUser(*user)
|
||||
user, c, err = a.h.state.CreateUser(*user)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
*u = *user
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("updating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return user, policyChanged, nil
|
||||
return user, c, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) handleRegistration(
|
||||
|
@ -235,6 +235,7 @@ func (m *mapSession) serveLongPoll() {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
hscontrol/state/node_store.go (new file, 260 lines)
@ -0,0 +1,260 @@
package state

import (
    "maps"
    "sync/atomic"
    "time"

    "github.com/juanfont/headscale/hscontrol/types"
    "tailscale.com/types/key"
    "tailscale.com/types/views"
)

const (
    batchSize    = 10
    batchTimeout = 500 * time.Millisecond
)

const (
    put    = 1
    del    = 2
    update = 3
)

// NodeStore is a thread-safe store for nodes.
// It is a copy-on-write structure, replacing the "snapshot"
// when a change to the structure occurs. It is optimised for reads,
// and while batches are not fast, they are grouped together
// to do less of the expensive peer calculation if there are many
// changes rapidly.
//
// Writes will block until committed, while reads are never
// blocked.
type NodeStore struct {
    data atomic.Pointer[Snapshot]

    peersFunc  PeersFunc
    writeQueue chan work
    // TODO: metrics
}

func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc) *NodeStore {
    nodes := make(map[types.NodeID]types.Node, len(allNodes))
    for _, n := range allNodes {
        nodes[n.ID] = *n
    }
    snap := snapshotFromNodes(nodes, peersFunc)

    store := &NodeStore{
        peersFunc: peersFunc,
    }
    store.data.Store(&snap)

    return store
}

type Snapshot struct {
    // nodesByID is the main source of truth for nodes.
    nodesByID map[types.NodeID]types.Node

    // calculated from nodesByID
    nodesByNodeKey map[key.NodePublic]types.NodeView
    peersByNode    map[types.NodeID][]types.NodeView
    nodesByUser    map[types.UserID][]types.NodeView
    allNodes       []types.NodeView
}

type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView

type work struct {
    op        int
    nodeID    types.NodeID
    node      types.Node
    updateFn  UpdateNodeFunc
    result    chan struct{}
    immediate bool // For operations that need immediate processing
}

// PutNode adds or updates a node in the store.
// If the node already exists, it will be replaced.
// If the node does not exist, it will be added.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) PutNode(n types.Node) {
    work := work{
        op:     put,
        nodeID: n.ID,
        node:   n,
        result: make(chan struct{}),
    }

    s.writeQueue <- work
    <-work.result
}

// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it.
type UpdateNodeFunc func(n *types.Node)

// UpdateNode applies a function to modify a specific node in the store.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) {
    work := work{
        op:       update,
        nodeID:   nodeID,
        updateFn: updateFn,
        result:   make(chan struct{}),
    }

    s.writeQueue <- work
    <-work.result
}

// UpdateNodeImmediate applies a function to modify a specific node in the store
// with immediate processing (bypassing normal batching delays).
// Use this for time-sensitive updates like online status changes.
func (s *NodeStore) UpdateNodeImmediate(nodeID types.NodeID, updateFn func(n *types.Node)) {
    work := work{
        op:        update,
        nodeID:    nodeID,
        updateFn:  updateFn,
        result:    make(chan struct{}),
        immediate: true,
    }

    s.writeQueue <- work
    <-work.result
}

// DeleteNode removes a node from the store by its ID.
// This is a blocking operation that waits for the write to complete.
func (s *NodeStore) DeleteNode(id types.NodeID) {
    work := work{
        op:     del,
        nodeID: id,
        result: make(chan struct{}),
    }

    s.writeQueue <- work
    <-work.result
}

func (s *NodeStore) Start() {
    s.writeQueue = make(chan work)
    go s.processWrite()
}

func (s *NodeStore) Stop() {
    close(s.writeQueue)
}

func (s *NodeStore) processWrite() {
    c := time.NewTicker(batchTimeout)
    batch := make([]work, 0, batchSize)

    for {
        select {
        case w, ok := <-s.writeQueue:
            if !ok {
                c.Stop()
                return
            }

            // Handle immediate operations right away
            if w.immediate {
                s.applyBatch([]work{w})
                continue
            }

            batch = append(batch, w)
            if len(batch) >= batchSize {
                s.applyBatch(batch)
                batch = batch[:0]
                c.Reset(batchTimeout)
            }

        case <-c.C:
            if len(batch) != 0 {
                s.applyBatch(batch)
                batch = batch[:0]
            }
            c.Reset(batchTimeout)
        }
    }
}

func (s *NodeStore) applyBatch(batch []work) {
    nodes := make(map[types.NodeID]types.Node)
    maps.Copy(nodes, s.data.Load().nodesByID)

    for _, w := range batch {
        switch w.op {
        case put:
            nodes[w.nodeID] = w.node
        case update:
            // Update the specific node identified by nodeID
            if n, exists := nodes[w.nodeID]; exists {
                w.updateFn(&n)
                nodes[w.nodeID] = n
            }
        case del:
            delete(nodes, w.nodeID)
        }
    }

    newSnap := snapshotFromNodes(nodes, s.peersFunc)

    s.data.Store(&newSnap)

    for _, w := range batch {
        close(w.result)
    }
}

func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot {
    allNodes := make([]types.NodeView, 0, len(nodes))
    for _, n := range nodes {
        allNodes = append(allNodes, n.View())
    }

    newSnap := Snapshot{
        nodesByID:      nodes,
        allNodes:       allNodes,
        nodesByNodeKey: make(map[key.NodePublic]types.NodeView),
        peersByNode:    peersFunc(allNodes),
        nodesByUser:    make(map[types.UserID][]types.NodeView),
    }

    // Build nodesByUser and nodesByNodeKey maps
    for _, n := range nodes {
        nodeView := n.View()
        newSnap.nodesByUser[types.UserID(n.UserID)] = append(newSnap.nodesByUser[types.UserID(n.UserID)], nodeView)
        newSnap.nodesByNodeKey[n.NodeKey] = nodeView
    }

    return newSnap
}

// GetNode retrieves a node by its ID.
func (s *NodeStore) GetNode(id types.NodeID) types.NodeView {
    n := s.data.Load().nodesByID[id]
    return n.View()
}

// GetNodeByNodeKey retrieves a node by its NodeKey.
func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) types.NodeView {
    return s.data.Load().nodesByNodeKey[nodeKey]
}

// ListNodes returns a slice of all nodes in the store.
func (s *NodeStore) ListNodes() views.Slice[types.NodeView] {
    return views.SliceOf(s.data.Load().allNodes)
}

// ListPeers returns a slice of all peers for a given node ID.
func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] {
    return views.SliceOf(s.data.Load().peersByNode[id])
}

// ListNodesByUser returns a slice of all nodes for a given user ID.
func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] {
    return views.SliceOf(s.data.Load().nodesByUser[uid])
}
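To make the API above concrete, here is a short usage sketch written as if it lived in package state, mirroring the tests added later in this commit. The allow-all peers function and the node values are illustrative; note that each blocking write waits for the batch that contains it (up to batchTimeout).

package state

import "github.com/juanfont/headscale/hscontrol/types"

func exampleNodeStoreUsage() {
    // Every node sees every other node as a peer (illustrative policy).
    allowAll := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
        peers := make(map[types.NodeID][]types.NodeView, len(nodes))
        for _, n := range nodes {
            for _, m := range nodes {
                if m.ID() != n.ID() {
                    peers[n.ID()] = append(peers[n.ID()], m)
                }
            }
        }
        return peers
    }

    store := NewNodeStore(nil, allowAll)
    store.Start()
    defer store.Stop()

    // Writes block until the batch containing them has been committed.
    store.PutNode(types.Node{ID: 1, Hostname: "node1", UserID: 1})
    store.PutNode(types.Node{ID: 2, Hostname: "node2", UserID: 1})

    // Reads never block; they dereference the current snapshot.
    _ = store.GetNode(1)         // types.NodeView for node 1
    _ = store.ListPeers(1).Len() // 1: node 2 is node 1's only peer
}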
hscontrol/state/node_store_test.go (new file, 494 lines)
@ -0,0 +1,494 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestSnapshotFromNodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
|
||||
validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
|
||||
}{
|
||||
{
|
||||
name: "empty nodes",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := make(map[types.NodeID]types.Node)
|
||||
peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
return make(map[types.NodeID][]types.NodeView)
|
||||
}
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single node",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
}
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||
assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
|
||||
assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple nodes same user",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 1, "user1", "node2"),
|
||||
}
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
// Each node sees the other as peer (but not itself)
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple nodes different users",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 2, "user2", "node2"),
|
||||
3: createTestNode(3, 1, "user1", "node3"),
|
||||
}
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// Each node should have 2 peers (all others, but not itself)
|
||||
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||
|
||||
// User groupings
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "odd-even peers filtering",
|
||||
setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
|
||||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
2: createTestNode(2, 2, "user2", "node2"),
|
||||
3: createTestNode(3, 3, "user3", "node3"),
|
||||
4: createTestNode(4, 4, "user4", "node4"),
|
||||
}
|
||||
peersFunc := oddEvenPeersFunc
|
||||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
assert.Len(t, snapshot.allNodes, 4)
|
||||
assert.Len(t, snapshot.peersByNode, 4)
|
||||
assert.Len(t, snapshot.nodesByUser, 4)
|
||||
|
||||
// Odd nodes should only see other odd nodes as peers
|
||||
require.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
// Even nodes should only see other even nodes as peers
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nodes, peersFunc := tt.setupFunc()
|
||||
snapshot := snapshotFromNodes(nodes, peersFunc)
|
||||
tt.validate(t, nodes, snapshot)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
|
||||
now := time.Now()
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
ipv4 := netip.MustParseAddr("100.64.0.1")
|
||||
ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
||||
|
||||
return types.Node{
|
||||
ID: nodeID,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
DiscoKey: discoKey.Public(),
|
||||
Hostname: hostname,
|
||||
GivenName: hostname,
|
||||
UserID: userID,
|
||||
User: types.User{
|
||||
Name: username,
|
||||
DisplayName: username,
|
||||
},
|
||||
RegisterMethod: "test",
|
||||
IPv4: &ipv4,
|
||||
IPv6: &ipv6,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
// Peer functions
|
||||
|
||||
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
for _, n := range nodes {
|
||||
if n.ID() != node.ID() {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
nodeIsOdd := node.ID()%2 == 1
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() == node.ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIsOdd := n.ID()%2 == 1
|
||||
|
||||
// Only add peer if both are odd or both are even
|
||||
if nodeIsOdd == peerIsOdd {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestNodeStoreOperations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(t *testing.T) *NodeStore
|
||||
steps []testStep
|
||||
}{
|
||||
{
|
||||
name: "create empty store and add single node",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
return NewNodeStore(nil, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify empty store",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add first node",
|
||||
action: func(store *NodeStore) {
|
||||
node := createTestNode(1, 1, "user1", "node1")
|
||||
store.PutNode(node)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
require.Contains(t, snapshot.nodesByID, types.NodeID(1))
|
||||
assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
|
||||
assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create store with initial node and add more",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
initialNodes := types.Nodes{&node1}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial state",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
assert.Empty(t, snapshot.peersByNode[1])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add second node same user",
|
||||
action: func(store *NodeStore) {
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
store.PutNode(node2)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 1)
|
||||
|
||||
// Now both nodes should see each other as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add third node different user",
|
||||
action: func(store *NodeStore) {
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
store.PutNode(node3)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// All nodes should see the other 2 as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 2)
|
||||
assert.Len(t, snapshot.peersByNode[2], 2)
|
||||
assert.Len(t, snapshot.peersByNode[3], 2)
|
||||
|
||||
// User groupings
|
||||
assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test node deletion",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
node3 := createTestNode(3, 2, "user2", "node3")
|
||||
initialNodes := types.Nodes{&node1, &node2, &node3}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial 3 nodes",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete middle node",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(2)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
assert.Len(t, snapshot.nodesByUser, 2)
|
||||
|
||||
// Node 2 should be gone
|
||||
assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))
|
||||
|
||||
// Remaining nodes should see each other as peers
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
assert.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
// User groupings updated
|
||||
assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
|
||||
assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete all remaining nodes",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(1)
|
||||
store.DeleteNode(3)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
assert.Empty(t, snapshot.nodesByUser)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test node updates",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
node1 := createTestNode(1, 1, "user1", "node1")
|
||||
node2 := createTestNode(2, 1, "user1", "node2")
|
||||
initialNodes := types.Nodes{&node1, &node2}
|
||||
return NewNodeStore(initialNodes, allowAllPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "verify initial hostnames",
|
||||
action: func(store *NodeStore) {
|
||||
snapshot := store.data.Load()
|
||||
assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update node hostname",
|
||||
action: func(store *NodeStore) {
|
||||
store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "updated-node1"
|
||||
n.GivenName = "updated-node1"
|
||||
})
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
|
||||
assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
|
||||
assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged
|
||||
|
||||
// Peers should still work correctly
|
||||
assert.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Len(t, snapshot.peersByNode[2], 1)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "test with odd-even peers filtering",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
return NewNodeStore(nil, oddEvenPeersFunc)
|
||||
},
|
||||
steps: []testStep{
|
||||
{
|
||||
name: "add nodes with odd-even filtering",
|
||||
action: func(store *NodeStore) {
|
||||
// Add nodes in sequence
|
||||
store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
|
||||
// Verify odd-even peer relationships
|
||||
require.Len(t, snapshot.peersByNode[1], 1)
|
||||
assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[3], 1)
|
||||
assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
|
||||
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete odd node and verify even nodes unaffected",
|
||||
action: func(store *NodeStore) {
|
||||
store.DeleteNode(1)
|
||||
|
||||
snapshot := store.data.Load()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
|
||||
// Node 3 (odd) should now have no peers
|
||||
assert.Empty(t, snapshot.peersByNode[3])
|
||||
|
||||
// Even nodes should still see each other
|
||||
require.Len(t, snapshot.peersByNode[2], 1)
|
||||
assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
|
||||
require.Len(t, snapshot.peersByNode[4], 1)
|
||||
assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := tt.setupFunc(t)
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
for _, step := range tt.steps {
|
||||
t.Run(step.name, func(t *testing.T) {
|
||||
step.action(store)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testStep struct {
|
||||
name string
|
||||
action func(store *NodeStore)
|
||||
}
|
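The tests above drive the store through tables of named steps. A hypothetical extra case, reusing createTestNode and allowAllPeersFunc from this file, would slot into the TestNodeStoreOperations table like this (not part of the commit; as in the existing cases, t is the enclosing test's *testing.T, and the assertion relies on applyBatch treating an update of an unknown ID as a no-op):

{
    name: "update is a no-op for unknown node",
    setupFunc: func(t *testing.T) *NodeStore {
        node1 := createTestNode(1, 1, "user1", "node1")
        return NewNodeStore(types.Nodes{&node1}, allowAllPeersFunc)
    },
    steps: []testStep{
        {
            name: "updating a missing ID leaves the snapshot unchanged",
            action: func(store *NodeStore) {
                store.UpdateNode(99, func(n *types.Node) { n.Hostname = "ghost" })

                snapshot := store.data.Load()
                assert.Len(t, snapshot.nodesByID, 1)
                assert.NotContains(t, snapshot.nodesByID, types.NodeID(99))
            },
        },
    },
},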
@ -27,6 +27,7 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
zcache "zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
@ -49,6 +50,9 @@ type State struct {
|
||||
// cfg holds the current Headscale configuration
|
||||
cfg *types.Config
|
||||
|
||||
// nodeStore provides an in-memory cache for nodes.
|
||||
nodeStore *NodeStore
|
||||
|
||||
// subsystem keeping state
|
||||
// db provides persistent storage and database operations
|
||||
db *hsdb.HSDatabase
|
||||
@ -107,6 +111,12 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return nil, fmt.Errorf("init policy manager: %w", err)
|
||||
}
|
||||
|
||||
nodeStore := NewNodeStore(nodes, func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
|
||||
_, matchers := polMan.Filter()
|
||||
return policy.BuildPeerMap(views.SliceOf(nodes), matchers)
|
||||
})
|
||||
nodeStore.Start()
|
||||
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
@ -117,11 +127,14 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
polMan: polMan,
|
||||
registrationCache: registrationCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the State instance and releases all resources.
|
||||
func (s *State) Close() error {
|
||||
s.nodeStore.Stop()
|
||||
|
||||
if err := s.db.Close(); err != nil {
|
||||
return fmt.Errorf("closing database: %w", err)
|
||||
}
|
||||
@ -204,43 +217,42 @@ func (s *State) AutoApproveNodes() error {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user and updates the policy manager.
|
||||
// Returns the created user, whether policies changed, and any error.
|
||||
func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
// Returns the created user, change set, and any error.
|
||||
func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerUsers()
|
||||
c, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
// Log the error but don't fail the user creation
|
||||
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
|
||||
return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err)
|
||||
}
|
||||
|
||||
// Even if the policy manager doesn't detect a filter change, SSH policies
|
||||
// might now be resolvable when they weren't before. If there are existing
|
||||
// nodes, we should send a policy change to ensure they get updated SSH policies.
|
||||
if !policyChanged {
|
||||
if c.Empty() {
|
||||
nodes, err := s.ListNodes()
|
||||
if err == nil && len(nodes) > 0 {
|
||||
policyChanged = true
|
||||
if err == nil && nodes.Len() > 0 {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
|
||||
log.Info().Str("user", user.Name).Bool("policyChanged", !c.Empty()).Msg("User created, policy manager updated")
|
||||
|
||||
// TODO(kradalby): implement the user in-memory cache
|
||||
|
||||
return &user, policyChanged, nil
|
||||
return &user, c, nil
|
||||
}
|
||||
|
||||
// UpdateUser modifies an existing user using the provided update function within a transaction.
|
||||
// Returns the updated user, whether policies changed, and any error.
|
||||
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, bool, error) {
|
||||
// Returns the updated user, change set, and any error.
|
||||
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@ -261,18 +273,18 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
|
||||
return user, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerUsers()
|
||||
c, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return user, false, fmt.Errorf("failed to update policy manager after user update: %w", err)
|
||||
return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the user in-memory cache
|
||||
|
||||
return user, policyChanged, nil
|
||||
return user, c, nil
|
||||
}
|
||||
|
||||
// DeleteUser permanently removes a user and all associated data (nodes, API keys, etc).
|
||||
@ -282,7 +294,7 @@ func (s *State) DeleteUser(userID types.UserID) error {
|
||||
}
|
||||
|
||||
// RenameUser changes a user's name. The new name must be unique.
|
||||
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, bool, error) {
|
||||
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) {
|
||||
return s.UpdateUser(userID, func(user *types.User) error {
|
||||
user.Name = newName
|
||||
return nil
|
||||
@ -315,28 +327,31 @@ func (s *State) ListAllUsers() ([]types.User, error) {
|
||||
}
|
||||
|
||||
// CreateNode creates a new node and updates the policy manager.
|
||||
// Returns the created node, whether policies changed, and any error.
|
||||
func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Returns the created node, change set, and any error.
|
||||
func (s *State) CreateNode(node *types.Node) (types.NodeView, change.ChangeSet, error) {
|
||||
s.nodeStore.PutNode(*node)
|
||||
|
||||
if err := s.db.DB.Save(node).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("creating node: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("creating node: %w", err)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
c, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node creation: %w", err)
|
||||
return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node creation: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
if c.Empty() {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return node.View(), c, nil
|
||||
}
|
||||
|
||||
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (types.NodeView, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()

@ -357,70 +372,72 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
return node, nil
})
if err != nil {
return nil, change.EmptySet, err
return types.NodeView{}, change.EmptySet, err
}

// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
c, err := s.updatePolicyManagerNodes()
if err != nil {
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
return node.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
}

// TODO(kradalby): implement the node in-memory cache

var c change.ChangeSet
if policyChanged {
c = change.PolicyChange()
} else {
if c.Empty() {
// Basic node change without specific details since this is a generic update
c = change.NodeAdded(node.ID)
}

return node.View(), c, nil
}

// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()

nodePtr := node.AsStruct()
s.nodeStore.PutNode(*nodePtr)

if err := s.db.DB.Save(nodePtr).Error; err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err)
}

// Check if policy manager needs updating
c, err := s.updatePolicyManagerNodes()
if err != nil {
return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}

// TODO(kradalby): implement the node in-memory cache

if c.Empty() {
c = change.NodeAdded(node.ID())
}

return node, c, nil
}

// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()

if err := s.db.DB.Save(node).Error; err != nil {
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
}

// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}

// TODO(kradalby): implement the node in-memory cache

if policyChanged {
return node, change.PolicyChange(), nil
}

return node, change.EmptySet, nil
}

// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
err := s.db.DeleteNode(node)
func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
s.nodeStore.DeleteNode(node.ID())

err := s.db.DeleteNode(node.AsStruct())
if err != nil {
return change.EmptySet, err
}

c := change.NodeRemoved(node.ID)
c := change.NodeRemoved(node.ID())

// Check if policy manager needs updating after node deletion
policyChanged, err := s.updatePolicyManagerNodes()
policyChange, err := s.updatePolicyManagerNodes()
if err != nil {
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}

if policyChanged {
c = change.PolicyChange()
if !policyChange.Empty() {
c = policyChange
}

return c, nil
@ -434,12 +451,26 @@ func (s *State) Connect(node *types.Node) change.ChangeSet {
c = change.NodeAdded(node.ID)
}

// Update nodestore with online status - node is connecting so it's online
// Use immediate update to ensure online status changes are not delayed by batching
s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
// Set the online status in the node's ephemeral field
n.IsOnline = ptr.To(true)
})

return c
}

func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
c := change.NodeOffline(node.ID)

// Update nodestore with offline status
// Use immediate update to ensure online status changes are not delayed by batching
s.nodeStore.UpdateNodeImmediate(node.ID, func(n *types.Node) {
// Set the online status to false in the node's ephemeral field
n.IsOnline = ptr.To(false)
})

_, _, err := s.SetLastSeen(node.ID, time.Now())
if err != nil {
return c, fmt.Errorf("disconnecting node: %w", err)
@ -454,70 +485,95 @@ func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
}

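Connect and Disconnect use UpdateNodeImmediate so that online/offline flips are visible to readers right away, while other writes go through the batched UpdateNode path. A minimal sketch of the batching idea, assuming a work channel and a periodic flush; the names and the 500ms window are illustrative only (the window matches the timeout mentioned in the integration tests further down), not the actual nodestore internals:

    type queuedUpdate struct {
        id types.NodeID
        fn func(*types.Node)
    }

    // batchLoop drains queued mutations and applies them to a fresh copy of the
    // snapshot, so readers keep using the previous snapshot until the swap.
    func batchLoop(updates <-chan queuedUpdate, flush func([]queuedUpdate)) {
        ticker := time.NewTicker(500 * time.Millisecond)
        defer ticker.Stop()
        var pending []queuedUpdate
        for {
            select {
            case u := <-updates:
                pending = append(pending, u)
            case <-ticker.C:
                if len(pending) > 0 {
                    flush(pending) // rebuild snapshot and peer relations once per batch
                    pending = nil
                }
            }
        }
    }
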
// GetNodeByID retrieves a node by ID.
func (s *State) GetNodeByID(nodeID types.NodeID) (*types.Node, error) {
return s.db.GetNodeByID(nodeID)
}

// GetNodeViewByID retrieves a node view by ID.
func (s *State) GetNodeViewByID(nodeID types.NodeID) (types.NodeView, error) {
node, err := s.db.GetNodeByID(nodeID)
if err != nil {
return types.NodeView{}, err
}

return node.View(), nil
func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, error) {
return s.nodeStore.GetNode(nodeID), nil
}

// GetNodeByNodeKey retrieves a node by its Tailscale public key.
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return s.db.GetNodeByNodeKey(nodeKey)
}

// GetNodeViewByNodeKey retrieves a node view by its Tailscale public key.
func (s *State) GetNodeViewByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
node, err := s.db.GetNodeByNodeKey(nodeKey)
if err != nil {
return types.NodeView{}, err
}

return node.View(), nil
func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, error) {
return s.nodeStore.GetNodeByNodeKey(nodeKey), nil
}

// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
func (s *State) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
func (s *State) ListNodes(nodeIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
if len(nodeIDs) == 0 {
return s.db.ListNodes()
return s.nodeStore.ListNodes(), nil
}

return s.db.ListNodes(nodeIDs...)
// Filter nodes by the requested IDs
allNodes := s.nodeStore.ListNodes()
nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs))
for _, id := range nodeIDs {
nodeIDSet[id] = struct{}{}
}

var filteredNodes []types.NodeView
for _, node := range allNodes.All() {
if _, exists := nodeIDSet[node.ID()]; exists {
filteredNodes = append(filteredNodes, node)
}
}

return views.SliceOf(filteredNodes), nil
}

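The ID-set filtering above is repeated almost verbatim in ListPeers below. A small shared helper could factor it out; this is a sketch using the same views package, not part of the change itself:

    // filterByID keeps only the node views whose ID is in the requested set.
    func filterByID(all views.Slice[types.NodeView], ids []types.NodeID) views.Slice[types.NodeView] {
        want := make(map[types.NodeID]struct{}, len(ids))
        for _, id := range ids {
            want[id] = struct{}{}
        }
        var out []types.NodeView
        for _, n := range all.All() {
            if _, ok := want[n.ID()]; ok {
                out = append(out, n)
            }
        }
        return views.SliceOf(out)
    }
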
// ListNodesByUser retrieves all nodes belonging to a specific user.
func (s *State) ListNodesByUser(userID types.UserID) (types.Nodes, error) {
return hsdb.Read(s.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return hsdb.ListNodesByUser(rx, userID)
})
func (s *State) ListNodesByUser(userID types.UserID) (views.Slice[types.NodeView], error) {
return s.nodeStore.ListNodesByUser(userID), nil
}

// ListPeers retrieves nodes that can communicate with the specified node based on policy.
func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return s.db.ListPeers(nodeID, peerIDs...)
func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (views.Slice[types.NodeView], error) {
if len(peerIDs) == 0 {
return s.nodeStore.ListPeers(nodeID), nil
}

// For specific peerIDs, filter from all nodes
allNodes := s.nodeStore.ListNodes()
nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs))
for _, id := range peerIDs {
nodeIDSet[id] = struct{}{}
}

var filteredNodes []types.NodeView
for _, node := range allNodes.All() {
if _, exists := nodeIDSet[node.ID()]; exists {
filteredNodes = append(filteredNodes, node)
}
}

return views.SliceOf(filteredNodes), nil
}

// ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system.
func (s *State) ListEphemeralNodes() (types.Nodes, error) {
return s.db.ListEphemeralNodes()
func (s *State) ListEphemeralNodes() (views.Slice[types.NodeView], error) {
allNodes := s.nodeStore.ListNodes()
var ephemeralNodes []types.NodeView

for _, node := range allNodes.All() {
// Check if node is ephemeral by checking its AuthKey
if node.AuthKey().Valid() && node.AuthKey().Ephemeral() {
ephemeralNodes = append(ephemeralNodes, node)
}
}

return views.SliceOf(ephemeralNodes), nil
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Expiry = &expiry
|
||||
})
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.KeyExpiry(nodeID)
|
||||
}
|
||||
@ -526,14 +582,19 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Nod
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetTags(tx, nodeID, tags)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.ForcedTags = tags
|
||||
})
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
@ -542,16 +603,21 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, ch
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.ApprovedRoutes = routes
|
||||
})
|
||||
|
||||
// Update primary routes after changing approved routes
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.AsStruct().SubnetRoutes()...)
|
||||
|
||||
if routeChange || !c.IsFull() {
|
||||
c = change.PolicyChange()
|
||||
@ -561,14 +627,19 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.RenameNode(tx, nodeID, newName)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.GivenName = newName
|
||||
})
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
@ -577,21 +648,19 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, ch
|
||||
}
|
||||
|
||||
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (types.NodeView, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("setting last seen: %w", err)
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.LastSeen = &lastSeen
|
||||
})
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
@ -599,6 +668,32 @@ func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*typ
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (types.NodeView, change.ChangeSet, error) {
|
||||
node, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Update nodestore with the same change
|
||||
// Get the updated user information from the database
|
||||
user, err := s.GetUserByID(userID)
|
||||
if err == nil {
|
||||
s.nodeStore.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.UserID = uint(userID)
|
||||
n.User = *user
|
||||
})
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return node, c, nil
|
||||
}
|
||||
|
||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||
func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
return s.db.BackfillNodeIPs(s.ipAlloc)
|
||||
@ -631,8 +726,16 @@ func (s *State) SetPolicy(pol []byte) (bool, error) {
|
||||
}
|
||||
|
||||
// AutoApproveRoutes checks if a node's routes should be auto-approved.
|
||||
func (s *State) AutoApproveRoutes(node *types.Node) bool {
|
||||
return policy.AutoApproveRoutes(s.polMan, node)
|
||||
func (s *State) AutoApproveRoutes(node types.NodeView) bool {
|
||||
nodePtr := node.AsStruct()
|
||||
changed := policy.AutoApproveRoutes(s.polMan, nodePtr)
|
||||
if changed {
|
||||
s.nodeStore.PutNode(*nodePtr)
|
||||
// Update primaryRoutes manager with the newly approved routes
|
||||
// This is essential for actual packet forwarding to work
|
||||
s.primaryRoutes.SetRoutes(nodePtr.ID, nodePtr.SubnetRoutes()...)
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// PolicyDebugString returns a debug representation of the current policy.
|
||||
@ -742,34 +845,42 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, err
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
return s.db.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := s.db.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
userID,
|
||||
expiry,
|
||||
util.RegisterMethodOIDC,
|
||||
ipv4, ipv6,
|
||||
)
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Update nodestore with the newly registered/updated node
|
||||
s.nodeStore.PutNode(*node)
|
||||
|
||||
return node.View(), nodeChange, nil
|
||||
}
|
||||
|
||||
// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key.
|
||||
func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, change.ChangeSet, bool, error) {
|
||||
) (types.NodeView, change.ChangeSet, error) {
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, err
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
err = pak.Validate()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, err
|
||||
return types.NodeView{}, change.EmptySet, err
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
@ -796,7 +907,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
|
||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@ -818,38 +929,44 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// This is a logout request for an ephemeral node, delete it immediately
|
||||
c, err := s.DeleteNode(node)
|
||||
c, err := s.DeleteNode(node.View())
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
return nil, c, false, nil
|
||||
return types.NodeView{}, c, nil
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
// This is necessary because we just created a new node.
|
||||
// We need to ensure that the policy manager is aware of this new node.
|
||||
// Also update users to ensure all users are known when evaluating policies.
|
||||
usersChanged, err := s.updatePolicyManagerUsers()
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
}
|
||||
|
||||
nodesChanged, err := s.updatePolicyManagerNodes()
|
||||
nodesChange, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
}
|
||||
|
||||
policyChanged := usersChanged || nodesChanged
|
||||
var c change.ChangeSet
|
||||
if !usersChange.Empty() || !nodesChange.Empty() {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
c := change.NodeAdded(node.ID)
|
||||
// Update nodestore with the newly registered node
|
||||
s.nodeStore.PutNode(*node)
|
||||
|
||||
return node, c, policyChanged, nil
|
||||
return node.View(), c, nil
|
||||
}
|
||||
|
||||
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
@ -863,22 +980,25 @@ func (s *State) AllocateNextIPs() (*netip.Addr, *netip.Addr, error) {
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for users.
// updatePolicyManagerUsers refreshes the policy manager with current user data.
func (s *State) updatePolicyManagerUsers() (bool, error) {
func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
users, err := s.ListAllUsers()
if err != nil {
return false, fmt.Errorf("listing users for policy update: %w", err)
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
}

log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")

changed, err := s.polMan.SetUsers(users)
if err != nil {
return false, fmt.Errorf("updating policy manager users: %w", err)
return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err)
}

log.Debug().Bool("changed", changed).Msg("Policy manager users updated")

return changed, nil
if changed {
return change.PolicyChange(), nil
}
return change.EmptySet, nil
}

// updatePolicyManagerNodes updates the policy manager with current nodes.
@ -887,18 +1007,21 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for nodes.
// updatePolicyManagerNodes refreshes the policy manager with current node data.
func (s *State) updatePolicyManagerNodes() (bool, error) {
func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) {
nodes, err := s.ListNodes()
if err != nil {
return false, fmt.Errorf("listing nodes for policy update: %w", err)
return change.EmptySet, fmt.Errorf("listing nodes for policy update: %w", err)
}

changed, err := s.polMan.SetNodes(nodes.ViewSlice())
changed, err := s.polMan.SetNodes(nodes)
if err != nil {
return false, fmt.Errorf("updating policy manager nodes: %w", err)
return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err)
}

return changed, nil
if changed {
return change.PolicyChange(), nil
}
return change.EmptySet, nil
}

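With both helpers returning change.ChangeSet instead of a bool, callers can combine the results directly. A sketch of that pattern, which is effectively what HandleNodeFromPreAuthKey above does after registration; nodeID here is assumed to be the freshly registered node's ID:

    usersChange, err := s.updatePolicyManagerUsers()
    if err != nil {
        return change.EmptySet, err
    }
    nodesChange, err := s.updatePolicyManagerNodes()
    if err != nil {
        return change.EmptySet, err
    }

    c := change.NodeAdded(nodeID)
    if !usersChange.Empty() || !nodesChange.Empty() {
        c = change.PolicyChange()
    }
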
// PingDB checks if the database connection is healthy.
@ -923,6 +1046,9 @@ func (s *State) autoApproveNodes() error {
// TODO(kradalby): This change should probably be sent to the rest of the system.
changed := policy.AutoApproveRoutes(s.polMan, node)
if changed {
// Update nodestore first if available
s.nodeStore.PutNode(*node)

err = tx.Save(node).Error
if err != nil {
return err
@ -985,7 +1111,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques
if routesChanged {
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
_ = s.AutoApproveRoutes(node)
_ = s.AutoApproveRoutes(node.View())

// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
@ -998,7 +1124,7 @@ func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapReques
// the hostname change.
node.ApplyHostnameFromHostInfo(req.Hostinfo)

_, policyChange, err := s.SaveNode(node)
_, policyChange, err := s.SaveNode(node.View())
if err != nil {
return change.EmptySet, err
}
|
@ -13,6 +13,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/tsaddr"
@ -511,11 +512,25 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
}

if node.Hostname != hostInfo.Hostname {
log.Trace().
Str("node_id", node.ID.String()).
Str("old_hostname", node.Hostname).
Str("new_hostname", hostInfo.Hostname).
Str("old_given_name", node.GivenName).
Bool("given_name_changed", node.GivenNameHasBeenChanged()).
Msg("Updating hostname from hostinfo")

if node.GivenNameHasBeenChanged() {
node.GivenName = util.ConvertWithFQDNRules(hostInfo.Hostname)
}

node.Hostname = hostInfo.Hostname

log.Trace().
Str("node_id", node.ID.String()).
Str("new_hostname", node.Hostname).
Str("new_given_name", node.GivenName).
Msg("Hostname updated")
}
}

@ -759,6 +774,22 @@ func (v NodeView) ExitRoutes() []netip.Prefix {
return v.ж.ExitRoutes()
}

// RequestTags returns the ACL tags that the node is requesting.
func (v NodeView) RequestTags() []string {
if !v.Valid() || !v.Hostinfo().Valid() {
return []string{}
}
return v.Hostinfo().RequestTags().AsSlice()
}

// Proto converts the NodeView to a protobuf representation.
func (v NodeView) Proto() *v1.Node {
if !v.Valid() {
return nil
}
return v.ж.Proto()
}

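Both accessors guard on Valid(), so a zero-value NodeView (for example, a lookup miss from the nodestore) can be used without a nil check: RequestTags returns an empty slice and Proto returns nil. A small usage sketch; nv is assumed to come from one of the State getters above:

    for _, tag := range nv.RequestTags() {
        log.Trace().Str("tag", tag).Msg("node requested tag")
    }
    if p := nv.Proto(); p != nil {
        // hand the protobuf form to the gRPC layer
    }
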
// HasIP reports if a node has a given IP address.
func (v NodeView) HasIP(i netip.Addr) bool {
if !v.Valid() {
|
@ -112,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
}

// Parse each hop line
hopRegex := regexp.MustCompile("^\\s*(\\d+)\\s+(?:([^ ]+) \\(([^)]+)\\)|(\\*))(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?(?:\\s+(\\d+\\.\\d+) ms)?")
hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`)

for i := 1; i < len(lines); i++ {
matches := hopRegex.FindStringSubmatch(lines[i])

|
@ -176,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, "NeedsLogin", status.BackendState)
|
||||
}
|
||||
assertTailscaleNodesLogout(t, allClients)
|
||||
}, shortAccessTTL+10*time.Second, 5*time.Second)
|
||||
}
|
||||
|
||||
|
@ -535,6 +535,8 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
var nodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
err := executeAndUnmarshal(
|
||||
@ -630,27 +632,34 @@ func TestUpdateHostnameFromClient(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"node",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodes,
|
||||
)
|
||||
// Wait for nodestore batch processing to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 1 second
|
||||
assert.Eventually(t, func() bool {
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"node",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&nodes,
|
||||
)
|
||||
|
||||
assertNoErr(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
if err != nil || len(nodes) != 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
||||
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
||||
assert.Equal(t, hostname+"NEW", node.GetName())
|
||||
assert.Equal(t, givenName, node.GetGivenName())
|
||||
}
|
||||
for _, node := range nodes {
|
||||
hostname := hostnames[strconv.FormatUint(node.GetId(), 10)]
|
||||
givenName := fmt.Sprintf("%d-givenname", node.GetId())
|
||||
if node.GetName() != hostname+"NEW" || node.GetGivenName() != givenName {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with NEW suffix")
|
||||
}
|
||||
|
||||
func TestExpireNode(t *testing.T) {
|
||||
|
@ -122,22 +122,22 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
assert.Len(t, node.GetSubnetRoutes(), 1)
|
||||
}
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.NotNil(t, peerStatus.PrimaryRoutes)
|
||||
|
||||
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
|
||||
|
||||
_, err = headscale.ApproveRoutes(
|
||||
1,
|
||||
@ -151,26 +151,27 @@ func TestEnablingRoutes(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.GetId() == 1 {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(t, node.GetSubnetRoutes())
|
||||
} else if node.GetId() == 2 {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(t, node.GetApprovedRoutes())
|
||||
assert.Empty(t, node.GetSubnetRoutes())
|
||||
} else {
|
||||
assert.Len(t, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(t, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(t, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
||||
for _, node := range nodes {
|
||||
if node.GetId() == 1 {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24
|
||||
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(c, node.GetSubnetRoutes())
|
||||
} else if node.GetId() == 2 {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24
|
||||
assert.Empty(c, node.GetApprovedRoutes())
|
||||
assert.Empty(c, node.GetSubnetRoutes())
|
||||
} else {
|
||||
assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24
|
||||
assert.Len(c, node.GetSubnetRoutes(), 1) // 10.0.2.0/24
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
@ -283,15 +284,17 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assertNoErrSync(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route configuration changes after advertising routes
|
||||
var nodes []*v1.Node
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err := headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "all routes should be available but not yet approved")
|
||||
|
||||
// Verify that no routes has been sent to the client,
|
||||
// they are not yet enabled.
|
||||
@ -315,15 +318,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on first subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "first subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine and can access
|
||||
// the webservice.
|
||||
@ -371,15 +375,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on second subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "second subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1 = subRouter1.MustStatus()
|
||||
@ -427,15 +432,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
// Wait for route approval on third subnet router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0)
|
||||
}, 3*time.Second, 200*time.Millisecond, "third subnet router should have approved route")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1 = subRouter1.MustStatus()
|
||||
@ -469,9 +475,27 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 13)
|
||||
|
||||
tr, err = client.Traceroute(webip)
|
||||
require.NoError(t, err)
|
||||
assertTracerouteViaIP(t, tr, subRouter1.MustIPv4())
|
||||
// Wait for traceroute to work correctly through the expected router
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Get the expected router IP - use a more robust approach to handle temporary disconnections
|
||||
ips, err := subRouter1.IPs()
|
||||
assert.NoError(c, err)
|
||||
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
|
||||
|
||||
var expectedIP netip.Addr
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
expectedIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
|
||||
}, 10*time.Second, 500*time.Millisecond, "traceroute should go through subRouter1")
|
||||
|
||||
// Take down the current primary
|
||||
t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
|
||||
@ -479,18 +503,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter1.Down()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r1 goes down
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
srs2 = subRouter2.MustStatus()
|
||||
clientStatus = client.MustStatus()
|
||||
|
||||
srs2 = subRouter2.MustStatus()
|
||||
clientStatus = client.MustStatus()
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 up")
|
||||
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
|
||||
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 goes down")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
require.NotNil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -520,22 +545,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter2.Down()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r2 goes down
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// TODO(kradalby): Check client status
|
||||
// Both are expected to be down
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
// Verify that the route is not presented from either router
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 down, r2 down")
|
||||
assert.False(c, srs1PeerStatus.Online, "r1 should be offline")
|
||||
assert.False(c, srs2PeerStatus.Online, "r2 should be offline")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r2 goes down")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -559,19 +581,19 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter1.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for router status changes after r1 comes back up
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
|
||||
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 is back up, r3 available")
|
||||
assert.True(c, srs1PeerStatus.Online, "r1 should be back online")
|
||||
assert.False(c, srs2PeerStatus.Online, "r2 should still be offline")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should still be online")
|
||||
}, 5*time.Second, 200*time.Millisecond, "router status should update after r1 comes back up")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -601,19 +623,20 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
err = subRouter2.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing to complete and online status to be updated
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
clientStatus, err = client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
require.NoError(t, err)
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
|
||||
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
|
||||
srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey]
|
||||
|
||||
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(t, srs3PeerStatus.Online, "r1 up, r2 up")
|
||||
assert.True(c, srs1PeerStatus.Online, "r1 should be online")
|
||||
assert.True(c, srs2PeerStatus.Online, "r2 should be online")
|
||||
assert.True(c, srs3PeerStatus.Online, "r3 should be online")
|
||||
}, 10*time.Second, 500*time.Millisecond, "all routers should be online after bringing up r2")
|
||||
|
||||
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
|
||||
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
|
||||
@ -641,15 +664,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing and route state changes to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
// After disabling route on r3, r1 should become primary with 1 subnet route
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route should failover to r1 after disabling r3")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -686,15 +712,18 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for nodestore batch processing and route state changes to complete
|
||||
// NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
// After disabling route on r1, r2 should become primary with 1 subnet route
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route should failover to r2 after disabling r1")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -735,15 +764,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
|
||||
)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes after re-enabling r1
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
}, 5*time.Second, 200*time.Millisecond, "route state should stabilize after re-enabling r1, expecting r2 to still be primary to avoid flapping")
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -894,14 +924,15 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 2)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 0, 0, 0)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0)
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes")
|
||||
|
||||
// Verify that the client has routes from the primary machine
|
||||
srs1, _ := subRouter1.Status()
|
||||
@ -1070,22 +1101,23 @@ func TestEnablingExitRoutes(t *testing.T) {
|
||||
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
|
||||
requireNodeRouteCount(t, nodes[1], 2, 2, 2)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the clients can see the new routes
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
require.NotNil(t, peerStatus.AllowedIPs)
|
||||
assert.Len(t, peerStatus.AllowedIPs.AsSlice(), 4)
|
||||
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
|
||||
assert.Contains(t, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
|
||||
assert.NotNil(c, peerStatus.AllowedIPs)
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
|
||||
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
|
||||
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6())
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "clients should see new routes")
|
||||
}
|
||||
|
||||
// TestSubnetRouterMultiNetwork is an evolution of the subnet router test.
|
||||
@ -1178,23 +1210,24 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes and clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
// Verify that the routes have been sent to the client
|
||||
status, err = user2c.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = user2c.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.Contains(t, peerStatus.PrimaryRoutes.AsSlice(), *pref)
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{*pref})
|
||||
}
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
|
||||
|
||||
usernet1, err := scenario.Network("usernet1")
|
||||
require.NoError(t, err)
|
||||
@ -1298,22 +1331,23 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
||||
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
// Wait for route state changes to propagate to nodes and clients
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2)
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
requireNodeRouteCount(t, nodes[0], 2, 2, 2)
|
||||
// Verify that the routes have been sent to the client
|
||||
status, err = user2c.Status()
|
||||
assert.NoError(c, err)
|
||||
|
||||
// Verify that the routes have been sent to the client.
|
||||
status, err = user2c.Status()
|
||||
require.NoError(t, err)
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
requirePeerSubnetRoutes(t, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
|
||||
}
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes and clients")
|
||||
|
||||
// Tell user2c to use user1c as an exit node.
|
||||
command = []string{
|
||||
@ -1621,6 +1655,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)

var nodes []*v1.Node
opts := []hsic.Option{
hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(),

@ -1753,13 +1788,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to advertise route: %s", err)
}

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err := headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Verify that the routes have been sent to the client.
status, err := client.Status()

@ -1793,13 +1829,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol)
require.NoError(t, err)

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Verify that the routes have been sent to the client.
status, err = client.Status()

@ -1834,13 +1871,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
)
require.NoError(t, err)

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Verify that the routes have been sent to the client.
status, err = client.Status()

@ -1870,13 +1908,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = headscale.SetPolicy(tt.pol)
require.NoError(t, err)

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Verify that the routes have been sent to the client.
status, err = client.Status()

@ -1915,13 +1954,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 1)

// Verify that the routes have been sent to the client.

@ -1951,13 +1991,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerSubRoute.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)

time.Sleep(5 * time.Second)

// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 0, 0, 0)

@ -1985,13 +2026,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
_, _, err = routerExitNode.Execute(command)
require.NoErrorf(t, err, "failed to advertise route: %s", err)

time.Sleep(5 * time.Second)

nodes, err = headscale.ListNodes()
require.NoError(t, err)
requireNodeRouteCount(t, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 2, 2, 2)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Verify that the routes have been sent to the client.
status, err = client.Status()

@ -2025,6 +2067,15 @@ func assertTracerouteViaIP(t *testing.T, tr util.Traceroute, ip netip.Addr) {
require.Equal(t, tr.Route[0].IP, ip)
}

// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT
func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) {
assert.NotNil(c, tr)
assert.True(c, tr.Success)
assert.NoError(c, tr.Err)
assert.NotEmpty(c, tr.Route)
assert.Equal(c, tr.Route[0].IP, ip)
}

// requirePeerSubnetRoutes asserts that the peer has the expected subnet routes.
func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected []netip.Prefix) {
t.Helper()

@ -2049,6 +2100,28 @@ func requirePeerSubnetRoutes(t *testing.T, status *ipnstate.PeerStatus, expected
}
}

func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) {
if status.AllowedIPs.Len() <= 2 && len(expected) != 0 {
assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected))
return
}

if len(expected) == 0 {
expected = []netip.Prefix{}
}

got := slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool {
if tsaddr.IsExitRoute(p) {
return true
}
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
})

if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" {
assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff))
}
}

func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, subnet int) {
t.Helper()
require.Lenf(t, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
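requirePeerSubnetRoutesWithCollect, like the existing requirePeerSubnetRoutes above it, derives a peer's "subnet routes" from its AllowedIPs: prefixes that merely cover the peer's own Tailscale addresses are dropped, exit routes are always kept, and the remainder is diff-compared against the expected set. A self-contained sketch of that filter using only the standard library (isExitRoute is an assumption standing in for tsaddr.IsExitRoute):

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// isExitRoute is a stand-in for tsaddr.IsExitRoute: a default route in
// either address family.
func isExitRoute(p netip.Prefix) bool { return p.Bits() == 0 }

func main() {
	tailscaleIPs := []netip.Addr{netip.MustParseAddr("100.64.0.1")}
	allowedIPs := []netip.Prefix{
		netip.MustParsePrefix("100.64.0.1/32"), // the peer's own address: dropped
		netip.MustParsePrefix("10.33.0.0/16"),  // advertised subnet: kept
		netip.MustParsePrefix("0.0.0.0/0"),     // exit route: always kept
	}

	var got []netip.Prefix
	for _, p := range allowedIPs {
		if isExitRoute(p) || !slices.ContainsFunc(tailscaleIPs, p.Contains) {
			got = append(got, p)
		}
	}

	fmt.Println(got) // [10.33.0.0/16 0.0.0.0/0]
}
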
@ -2056,6 +2129,12 @@ func requireNodeRouteCount(t *testing.T, node *v1.Node, announced, approved, sub
require.Lenf(t, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}

func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) {
assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes()))
assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes()))
assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes()))
}

// TestSubnetRouteACLFiltering tests that a node can only access subnet routes
// that are explicitly allowed in the ACL.
func TestSubnetRouteACLFiltering(t *testing.T) {
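The three numbers passed to requireNodeRouteCount and requireNodeRouteCountWithCollect throughout this file are, in order, the announced, approved and currently served route counts, read from the same three v1.Node getters the helpers assert on. A small sketch of that mapping (routeCounts is a hypothetical interface standing in for the generated v1.Node type):

package example

import "fmt"

// routeCounts is a hypothetical stand-in for the three v1.Node getters used
// by the helpers above.
type routeCounts interface {
	GetAvailableRoutes() []string // routes the node announces
	GetApprovedRoutes() []string  // routes approved by policy or an admin
	GetSubnetRoutes() []string    // routes currently served to peers
}

// describeRouteCounts renders the triple the tests assert with
// requireNodeRouteCount(t, node, announced, approved, subnet).
func describeRouteCounts(n routeCounts) string {
	return fmt.Sprintf("announced=%d approved=%d serving=%d",
		len(n.GetAvailableRoutes()),
		len(n.GetApprovedRoutes()),
		len(n.GetSubnetRoutes()))
}

So a call such as requireNodeRouteCount(t, node, 1, 0, 0) reads: one route announced, none approved, none served yet.
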
@ -2208,19 +2287,19 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
)
require.NoError(t, err)

// Give some time for the routes to propagate
time.Sleep(5 * time.Second)

// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
require.NoError(t, err)
require.Len(t, nodes, 2)

// Find the router node
routerNode = nodes[routerUser][0]

// Check that the router has 3 routes now approved and available
requireNodeRouteCount(t, routerNode, 3, 3, 3)
// Wait for route state changes to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// List nodes and verify the router has 3 available routes
nodes, err = headscale.NodesByUser()
assert.NoError(c, err)
assert.Len(c, nodes, 2)

// Find the router node
routerNode = nodes[routerUser][0]

// Check that the router has 3 routes now approved and available
requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")

// Now check the client node status
nodeStatus, err := nodeClient.Status()

@ -14,7 +14,6 @@ import (
"net/netip"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"

@ -279,16 +278,16 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) {
return nil, fmt.Errorf("no network named: %s", name)
}

for _, ipam := range net.Network.IPAM.Config {
pref, err := netip.ParsePrefix(ipam.Subnet)
if err != nil {
return nil, err
}

return &pref, nil
}

return nil, fmt.Errorf("no prefix found in network: %s", name)
if len(net.Network.IPAM.Config) == 0 {
return nil, fmt.Errorf("no IPAM config found in network: %s", name)
}

pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet)
if err != nil {
return nil, err
}

return &pref, nil
}

func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) {
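The reworked SubnetOfNetwork no longer loops over the IPAM configs: it fails fast with an explicit error when the Docker network has no IPAM config at all, and otherwise parses the first configured subnet. A sketch of the new shape under that assumption (ipamSubnets is a hypothetical stand-in for the net.Network.IPAM.Config entries):

package example

import (
	"fmt"
	"net/netip"
)

// firstSubnet mirrors the reworked lookup: error when nothing is configured,
// otherwise parse and return the first subnet.
func firstSubnet(name string, ipamSubnets []string) (*netip.Prefix, error) {
	if len(ipamSubnets) == 0 {
		return nil, fmt.Errorf("no IPAM config found in network: %s", name)
	}

	pref, err := netip.ParsePrefix(ipamSubnets[0])
	if err != nil {
		return nil, err
	}

	return &pref, nil
}

Callers such as the route tests above dereference the result (*pref) as the expected subnet route for the router in that network.
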
@ -696,7 +695,6 @@ func (s *Scenario) createHeadscaleEnv(
return err
}

sort.Strings(s.spec.Users)
for _, user := range s.spec.Users {
u, err := s.CreateUser(user)
if err != nil {