1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

Update packetfilter when peers change

Previously we did not update the packet filter
when nodes changed, which would cause new nodes
to be missing from packet filters of old nodes.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-08-09 20:37:41 +02:00 committed by Kristoffer Dalby
parent a8079a2096
commit 3b0749a320
2 changed files with 28 additions and 13 deletions

View File

@ -382,28 +382,31 @@ func (m *Mapper) DERPMapResponse(
func (m *Mapper) PeerChangedResponse( func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *types.Machine, machine *types.Machine,
machineKeys []uint64, machineIDs []uint64,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
) ([]byte, error) { ) ([]byte, error) {
var err error var err error
changed := make(types.Machines, len(machineKeys)) changed := make(types.Machines, len(machineIDs))
lastSeen := make(map[tailcfg.NodeID]bool) lastSeen := make(map[tailcfg.NodeID]bool)
for idx, machineKey := range machineKeys {
peer, err := m.db.GetMachineByID(machineKey)
if err != nil {
return nil, err
}
changed[idx] = *peer peersList, err := m.db.ListPeers(machine)
if err != nil {
// We have just seen the node, let the peers update their list. return nil, err
lastSeen[tailcfg.NodeID(peer.ID)] = true
} }
rules, _, err := policy.GenerateFilterAndSSHRules( peers := peersList.IDMap()
for idx, machineID := range machineIDs {
changed[idx] = peers[machineID]
// We have just seen the node, let the peers update their list.
lastSeen[tailcfg.NodeID(machineID)] = true
}
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol, pol,
machine, machine,
changed, peersList,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -434,6 +437,8 @@ func (m *Mapper) PeerChangedResponse(
resp := m.baseMapResponse(machine) resp := m.baseMapResponse(machine)
resp.PeersChanged = tailPeers resp.PeersChanged = tailPeers
resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
resp.SSHPolicy = sshPolicy
// resp.PeerSeenChange = lastSeen // resp.PeerSeenChange = lastSeen
return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)

View File

@ -353,3 +353,13 @@ func (machines MachinesP) String() string {
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (machines Machines) IDMap() map[uint64]Machine {
ret := map[uint64]Machine{}
for _, machine := range machines {
ret[machine.ID] = machine
}
return ret
}