mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Split up MapResponse
This commits extends the mapper with functions for creating "delta" MapResponses for different purposes (peer changed, peer removed, derp). This wires up the new state management with a new StateUpdate struct letting the poll worker know what kind of update to send to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									66ff1fcd40
								
							
						
					
					
						commit
						4b65cf48d0
					
				| @ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { | ||||
| 				h.DERPMap.Regions[region.RegionID] = ®ion | ||||
| 			} | ||||
| 
 | ||||
| 			h.nodeNotifier.NotifyAll() | ||||
| 			h.nodeNotifier.NotifyAll(types.StateUpdate{ | ||||
| 				Type:    types.StateDERPUpdated, | ||||
| 				DERPMap: *h.DERPMap, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -721,7 +724,9 @@ func (h *Headscale) Serve() error { | ||||
| 						Str("path", aclPath). | ||||
| 						Msg("ACL policy successfully reloaded, notifying nodes of change") | ||||
| 
 | ||||
| 					h.nodeNotifier.NotifyAll() | ||||
| 					h.nodeNotifier.NotifyAll(types.StateUpdate{ | ||||
| 						Type: types.StateFullUpdate, | ||||
| 					}) | ||||
| 				} | ||||
| 
 | ||||
| 			default: | ||||
|  | ||||
| @ -13,6 +13,7 @@ import ( | ||||
| 	"github.com/patrickmn/go-cache" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"gorm.io/gorm" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/key" | ||||
| ) | ||||
| 
 | ||||
| @ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags( | ||||
| 	} | ||||
| 	machine.ForcedTags = newTags | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | ||||
| @ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { | ||||
| 	now := time.Now() | ||||
| 	machine.Expiry = &now | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to expire machine in the database: %w", err) | ||||
| @ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er | ||||
| 	} | ||||
| 	machine.GivenName = newName | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf("failed to rename machine in the database: %w", err) | ||||
| @ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) | ||||
| 	machine.LastSuccessfulUpdate = &now | ||||
| 	machine.Expiry = &expiry | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||
| 		return fmt.Errorf( | ||||
| @ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { | ||||
| 	ret := make(map[tailcfg.NodeID]bool) | ||||
| 
 | ||||
| 	for _, peer := range peers { | ||||
| 		ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline() | ||||
| 	} | ||||
| 
 | ||||
| 	return ret | ||||
| } | ||||
| 
 | ||||
| func (hsdb *HSDatabase) ListOnlineMachines( | ||||
| 	machine *types.Machine, | ||||
| ) (map[tailcfg.NodeID]bool, error) { | ||||
| 	peers, err := hsdb.ListPeers(machine) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return OnlineMachineMap(peers), nil | ||||
| } | ||||
| 
 | ||||
| // enableRoutes enables new routes based on a list of new routes.
 | ||||
| func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { | ||||
| 	newRoutes := make([]netip.Prefix, len(routeStrs)) | ||||
| @ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||
| 		Type:    types.StatePeerChanged, | ||||
| 		Changed: []uint64{machine.ID}, | ||||
| 	}, machine.MachineKey) | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| @ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		expiredFound := false | ||||
| 		expired := make([]tailcfg.NodeID, 0) | ||||
| 		for idx, machine := range machines { | ||||
| 			if machine.IsEphemeral() && machine.LastSeen != nil && | ||||
| 				time.Now(). | ||||
| 					After(machine.LastSeen.Add(inactivityThreshhold)) { | ||||
| 				expiredFound = true | ||||
| 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||
| 
 | ||||
| 				log.Info(). | ||||
| 					Str("machine", machine.Hostname). | ||||
| 					Msg("Ephemeral client removed from database") | ||||
| @ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if expiredFound { | ||||
| 			hsdb.notifier.NotifyAll() | ||||
| 		if len(expired) > 0 { | ||||
| 			hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||
| 				Type:    types.StatePeerRemoved, | ||||
| 				Removed: expired, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||
| 			return time.Unix(0, 0) | ||||
| 		} | ||||
| 
 | ||||
| 		expiredFound := false | ||||
| 		expired := make([]tailcfg.NodeID, 0) | ||||
| 		for index, machine := range machines { | ||||
| 			if machine.IsExpired() && | ||||
| 				machine.Expiry.After(lastCheck) { | ||||
| 				expiredFound = true | ||||
| 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||
| 
 | ||||
| 				err := hsdb.ExpireMachine(&machines[index]) | ||||
| 				if err != nil { | ||||
| @ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if expiredFound { | ||||
| 			hsdb.notifier.NotifyAll() | ||||
| 		if len(expired) > 0 { | ||||
| 			hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||
| 				Type:    types.StatePeerRemoved, | ||||
| 				Removed: expired, | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  | ||||
| @ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||
| 		log.Error().Err(err).Msg("error getting routes") | ||||
| 	} | ||||
| 
 | ||||
| 	routesChanged := false | ||||
| 	changedMachines := make([]uint64, 0) | ||||
| 	for pos, route := range routes { | ||||
| 		if route.IsExitRoute() { | ||||
| 			continue | ||||
| @ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||
| 					return err | ||||
| 				} | ||||
| 
 | ||||
| 				routesChanged = true | ||||
| 				changedMachines = append(changedMachines, route.MachineID) | ||||
| 
 | ||||
| 				continue | ||||
| 			} | ||||
| @ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | ||||
| 				return err | ||||
| 			} | ||||
| 
 | ||||
| 			routesChanged = true | ||||
| 			changedMachines = append(changedMachines, route.MachineID) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if routesChanged { | ||||
| 		hsdb.notifier.NotifyAll() | ||||
| 	if len(changedMachines) > 0 { | ||||
| 		hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||
| 			Type:    types.StatePeerChanged, | ||||
| 			Changed: changedMachines, | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| @ -129,45 +130,35 @@ func fullMapResponse( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Peers is always returned sorted by Node.ID.
 | ||||
| 	sort.SliceStable(tailPeers, func(x, y int) bool { | ||||
| 		return tailPeers[x].ID < tailPeers[y].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	resp := tailcfg.MapResponse{ | ||||
| 		KeepAlive: false, | ||||
| 		Node:  tailnode, | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		DERPMap: derpMap, | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		Peers: tailPeers, | ||||
| 
 | ||||
| 		// TODO(kradalby): Implement:
 | ||||
| 		// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
 | ||||
| 		// PeersChanged
 | ||||
| 		// PeersRemoved
 | ||||
| 		// PeersChangedPatch
 | ||||
| 		// PeerSeenChange
 | ||||
| 		// OnlineChange
 | ||||
| 		DERPMap: derpMap, | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		DNSConfig: dnsConfig, | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		Domain:    baseDomain, | ||||
| 
 | ||||
| 		// Do not instruct clients to collect services, we do not
 | ||||
| 		// Do not instruct clients to collect services we do not
 | ||||
| 		// support or do anything with them
 | ||||
| 		CollectServices: "false", | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		PacketFilter: policy.ReduceFilterRules(machine, rules), | ||||
| 
 | ||||
| 		UserProfiles: profiles, | ||||
| 
 | ||||
| 		// TODO: Only send if updated
 | ||||
| 		SSHPolicy: sshPolicy, | ||||
| 
 | ||||
| 		ControlTime:  &now, | ||||
| 		KeepAlive:    false, | ||||
| 		OnlineChange: db.OnlineMachineMap(peers), | ||||
| 
 | ||||
| 		Debug: &tailcfg.Debug{ | ||||
| 			DisableLogTail:      !logtail, | ||||
| @ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CreateMapResponse returns a MapResponse for the given machine.
 | ||||
| func (m Mapper) CreateMapResponse( | ||||
| // FullMapResponse returns a MapResponse for the given machine.
 | ||||
| func (m Mapper) FullMapResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *types.Machine, | ||||
| 	pol *policy.ACLPolicy, | ||||
| @ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse( | ||||
| 	} | ||||
| 
 | ||||
| 	if m.isNoise { | ||||
| 		return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) | ||||
| 		return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) | ||||
| 	} | ||||
| 
 | ||||
| 	var machineKey key.MachinePublic | ||||
| 	err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot parse client key") | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) | ||||
| 	return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) CreateKeepAliveResponse( | ||||
| func (m Mapper) KeepAliveResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *types.Machine, | ||||
| ) ([]byte, error) { | ||||
| 	keepAliveResponse := tailcfg.MapResponse{ | ||||
| 		KeepAlive: true, | ||||
| 	resp := m.baseMapResponse(machine) | ||||
| 	resp.KeepAlive = true | ||||
| 
 | ||||
| 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) DERPMapResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *types.Machine, | ||||
| 	derpMap tailcfg.DERPMap, | ||||
| ) ([]byte, error) { | ||||
| 	resp := m.baseMapResponse(machine) | ||||
| 	resp.DERPMap = &derpMap | ||||
| 
 | ||||
| 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) PeerChangedResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *types.Machine, | ||||
| 	machineKeys []uint64, | ||||
| 	pol *policy.ACLPolicy, | ||||
| ) ([]byte, error) { | ||||
| 	var err error | ||||
| 	changed := make(types.Machines, len(machineKeys)) | ||||
| 	lastSeen := make(map[tailcfg.NodeID]bool) | ||||
| 	for idx, machineKey := range machineKeys { | ||||
| 		peer, err := m.db.GetMachineByID(machineKey) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 	if m.isNoise { | ||||
| 		return m.marshalMapResponse( | ||||
| 			keepAliveResponse, | ||||
| 			key.MachinePublic{}, | ||||
| 			mapRequest.Compress, | ||||
| 		changed[idx] = *peer | ||||
| 
 | ||||
| 		// We have just seen the node, let the peers update their list.
 | ||||
| 		lastSeen[tailcfg.NodeID(peer.ID)] = true | ||||
| 	} | ||||
| 
 | ||||
| 	rules, _, err := policy.GenerateFilterAndSSHRules( | ||||
| 		pol, | ||||
| 		machine, | ||||
| 		changed, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Filter out peers that have expired.
 | ||||
| 	changed = lo.Filter(changed, func(item types.Machine, index int) bool { | ||||
| 		return !item.IsExpired() | ||||
| 	}) | ||||
| 
 | ||||
| 	// If there are filter rules present, see if there are any machines that cannot
 | ||||
| 	// access eachother at all and remove them from the changed.
 | ||||
| 	if len(rules) > 0 { | ||||
| 		changed = policy.FilterMachinesByACL(machine, changed, rules) | ||||
| 	} | ||||
| 
 | ||||
| 	tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Peers is always returned sorted by Node.ID.
 | ||||
| 	sort.SliceStable(tailPeers, func(x, y int) bool { | ||||
| 		return tailPeers[x].ID < tailPeers[y].ID | ||||
| 	}) | ||||
| 
 | ||||
| 	resp := m.baseMapResponse(machine) | ||||
| 	resp.PeersChanged = tailPeers | ||||
| 	resp.PeerSeenChange = lastSeen | ||||
| 
 | ||||
| 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) PeerRemovedResponse( | ||||
| 	mapRequest tailcfg.MapRequest, | ||||
| 	machine *types.Machine, | ||||
| 	removed []tailcfg.NodeID, | ||||
| ) ([]byte, error) { | ||||
| 	resp := m.baseMapResponse(machine) | ||||
| 	resp.PeersRemoved = removed | ||||
| 
 | ||||
| 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) marshalMapResponse( | ||||
| 	resp *tailcfg.MapResponse, | ||||
| 	machine *types.Machine, | ||||
| 	compression string, | ||||
| ) ([]byte, error) { | ||||
| 	var machineKey key.MachinePublic | ||||
| 	err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) | ||||
| 	if err != nil { | ||||
| @ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse( | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) | ||||
| } | ||||
| 
 | ||||
| // MarshalResponse takes an Tailscale Response, marhsal it to JSON.
 | ||||
| // If isNoise is set, then the JSON body will be returned
 | ||||
| // If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
 | ||||
| func MarshalResponse( | ||||
| 	resp interface{}, | ||||
| 	isNoise bool, | ||||
| 	privateKey2019 *key.MachinePrivate, | ||||
| 	machineKey key.MachinePublic, | ||||
| ) ([]byte, error) { | ||||
| 	jsonBody, err := json.Marshal(resp) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot marshal response") | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !isNoise && privateKey2019 != nil { | ||||
| 		return privateKey2019.SealTo(machineKey, jsonBody), nil | ||||
| 	} | ||||
| 
 | ||||
| 	return jsonBody, nil | ||||
| } | ||||
| 
 | ||||
| func (m Mapper) marshalMapResponse( | ||||
| 	resp interface{}, | ||||
| 	machineKey key.MachinePublic, | ||||
| 	compression string, | ||||
| ) ([]byte, error) { | ||||
| 	jsonBody, err := json.Marshal(resp) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| @ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse( | ||||
| 	return data, nil | ||||
| } | ||||
| 
 | ||||
| // MarshalResponse takes an Tailscale Response, marhsal it to JSON.
 | ||||
| // If isNoise is set, then the JSON body will be returned
 | ||||
| // If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
 | ||||
| func MarshalResponse( | ||||
| 	resp interface{}, | ||||
| 	isNoise bool, | ||||
| 	privateKey2019 *key.MachinePrivate, | ||||
| 	machineKey key.MachinePublic, | ||||
| ) ([]byte, error) { | ||||
| 	jsonBody, err := json.Marshal(resp) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Cannot marshal response") | ||||
| 
 | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !isNoise && privateKey2019 != nil { | ||||
| 		return privateKey2019.SealTo(machineKey, jsonBody), nil | ||||
| 	} | ||||
| 
 | ||||
| 	return jsonBody, nil | ||||
| } | ||||
| 
 | ||||
| func zstdEncode(in []byte) []byte { | ||||
| 	encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) | ||||
| 	if !ok { | ||||
| @ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{ | ||||
| 		return encoder | ||||
| 	}, | ||||
| } | ||||
| 
 | ||||
| func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse { | ||||
| 	now := time.Now() | ||||
| 
 | ||||
| 	resp := tailcfg.MapResponse{ | ||||
| 		KeepAlive:   false, | ||||
| 		ControlTime: &now, | ||||
| 	} | ||||
| 
 | ||||
| 	online, err := m.db.ListOnlineMachines(machine) | ||||
| 	if err == nil { | ||||
| 		resp.OnlineChange = online | ||||
| 	} | ||||
| 
 | ||||
| 	return resp | ||||
| } | ||||
|  | ||||
| @ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 				DNSConfig:       &tailcfg.DNSConfig{}, | ||||
| 				Domain:          "", | ||||
| 				CollectServices: "false", | ||||
| 				OnlineChange:    map[tailcfg.NodeID]bool{tailPeer1.ID: false}, | ||||
| 				PacketFilter:    []tailcfg.FilterRule{}, | ||||
| 				UserProfiles:    []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, | ||||
| 				SSHPolicy:       &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, | ||||
| @ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 				DNSConfig:       &tailcfg.DNSConfig{}, | ||||
| 				Domain:          "", | ||||
| 				CollectServices: "false", | ||||
| 				OnlineChange:    map[tailcfg.NodeID]bool{tailPeer1.ID: false}, | ||||
| 				PacketFilter: []tailcfg.FilterRule{ | ||||
| 					{ | ||||
| 						SrcIPs: []string{"100.64.0.2/32"}, | ||||
|  | ||||
| @ -3,24 +3,25 @@ package notifier | ||||
| import ( | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| ) | ||||
| 
 | ||||
| type Notifier struct { | ||||
| 	l     sync.RWMutex | ||||
| 	nodes map[string]chan<- struct{} | ||||
| 	nodes map[string]chan<- types.StateUpdate | ||||
| } | ||||
| 
 | ||||
| func NewNotifier() *Notifier { | ||||
| 	return &Notifier{} | ||||
| } | ||||
| 
 | ||||
| func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) { | ||||
| func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) { | ||||
| 	n.l.Lock() | ||||
| 	defer n.l.Unlock() | ||||
| 
 | ||||
| 	if n.nodes == nil { | ||||
| 		n.nodes = make(map[string]chan<- struct{}) | ||||
| 		n.nodes = make(map[string]chan<- types.StateUpdate) | ||||
| 	} | ||||
| 
 | ||||
| 	n.nodes[machineKey] = c | ||||
| @ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) { | ||||
| 	delete(n.nodes, machineKey) | ||||
| } | ||||
| 
 | ||||
| func (n *Notifier) NotifyAll() { | ||||
| 	n.NotifyWithIgnore() | ||||
| func (n *Notifier) NotifyAll(update types.StateUpdate) { | ||||
| 	n.NotifyWithIgnore(update) | ||||
| } | ||||
| 
 | ||||
| func (n *Notifier) NotifyWithIgnore(ignore ...string) { | ||||
| func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { | ||||
| 	n.l.RLock() | ||||
| 	defer n.l.RUnlock() | ||||
| 
 | ||||
| @ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		c <- struct{}{} | ||||
| 		c <- update | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -116,7 +116,7 @@ func (h *Headscale) handlePoll( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) | ||||
| 	mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) | ||||
| 	if err != nil { | ||||
| 		logErr(err, "Failed to create MapResponse") | ||||
| 		http.Error(writer, "", http.StatusInternalServerError) | ||||
| @ -163,7 +163,12 @@ func (h *Headscale) handlePoll( | ||||
| 			Inc() | ||||
| 
 | ||||
| 		// Tell all the other nodes about the new endpoint, but dont update ourselves.
 | ||||
| 		h.nodeNotifier.NotifyWithIgnore(machine.MachineKey) | ||||
| 		h.nodeNotifier.NotifyWithIgnore( | ||||
| 			types.StateUpdate{ | ||||
| 				Type:    types.StatePeerChanged, | ||||
| 				Changed: []uint64{machine.ID}, | ||||
| 			}, | ||||
| 			machine.MachineKey) | ||||
| 
 | ||||
| 		return | ||||
| 	} else if mapRequest.OmitPeers && mapRequest.Stream { | ||||
| @ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream( | ||||
| 	keepAliveTicker := time.NewTicker(keepAliveInterval) | ||||
| 
 | ||||
| 	const chanSize = 8 | ||||
| 	updateChan := make(chan struct{}, chanSize) | ||||
| 	updateChan := make(chan types.StateUpdate, chanSize) | ||||
| 
 | ||||
| 	h.pollNetMapStreamWG.Add(1) | ||||
| 	defer h.pollNetMapStreamWG.Done() | ||||
| @ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream( | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-keepAliveTicker.C: | ||||
| 			data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) | ||||
| 			data, err := mapp.KeepAliveResponse(mapRequest, machine) | ||||
| 			if err != nil { | ||||
| 				logErr(err, "Error generating the keep alive msg") | ||||
| 
 | ||||
| @ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream( | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 		case <-updateChan: | ||||
| 			data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) | ||||
| 		case update := <-updateChan: | ||||
| 			var data []byte | ||||
| 			var err error | ||||
| 
 | ||||
| 			switch update.Type { | ||||
| 			case types.StateFullUpdate: | ||||
| 				data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) | ||||
| 			case types.StatePeerChanged: | ||||
| 				data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy) | ||||
| 			case types.StatePeerRemoved: | ||||
| 				data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed) | ||||
| 			case types.StateDERPUpdated: | ||||
| 				data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap) | ||||
| 			} | ||||
| 
 | ||||
| 			if err != nil { | ||||
| 				logErr(err, "Could not get the map update") | ||||
| 				logErr(err, "Could not get the create map update") | ||||
| 
 | ||||
| 				return | ||||
| 			} | ||||
| @ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream( | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { | ||||
| func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) { | ||||
| 	log.Trace(). | ||||
| 		Str("handler", "PollNetMap"). | ||||
| 		Str("machine", machine). | ||||
|  | ||||
| @ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) { | ||||
| 
 | ||||
| 	return string(bytes), err | ||||
| } | ||||
| 
 | ||||
| type StateUpdateType int | ||||
| 
 | ||||
| const ( | ||||
| 	StateFullUpdate StateUpdateType = iota | ||||
| 	StatePeerChanged | ||||
| 	StatePeerRemoved | ||||
| 	StateDERPUpdated | ||||
| ) | ||||
| 
 | ||||
| // StateUpdate is an internal message containing information about
 | ||||
| // a state change that has happened to the network.
 | ||||
| type StateUpdate struct { | ||||
| 	// The type of update
 | ||||
| 	Type StateUpdateType | ||||
| 
 | ||||
| 	// Changed must be set when Type is StatePeerChanged and
 | ||||
| 	// contain the Machine IDs of machines that has changed.
 | ||||
| 	Changed []uint64 | ||||
| 
 | ||||
| 	// Removed must be set when Type is StatePeerRemoved and
 | ||||
| 	// contain a list of the nodes that has been removed from
 | ||||
| 	// the network.
 | ||||
| 	Removed []tailcfg.NodeID | ||||
| 
 | ||||
| 	// DERPMap must be set when Type is StateDERPUpdated and
 | ||||
| 	// contain the new DERP Map.
 | ||||
| 	DERPMap tailcfg.DERPMap | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user