mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Split up MapResponse
This commits extends the mapper with functions for creating "delta" MapResponses for different purposes (peer changed, peer removed, derp). This wires up the new state management with a new StateUpdate struct letting the poll worker know what kind of update to send to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									66ff1fcd40
								
							
						
					
					
						commit
						4b65cf48d0
					
				| @ -257,7 +257,10 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { | |||||||
| 				h.DERPMap.Regions[region.RegionID] = ®ion | 				h.DERPMap.Regions[region.RegionID] = ®ion | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			h.nodeNotifier.NotifyAll() | 			h.nodeNotifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 				Type:    types.StateDERPUpdated, | ||||||
|  | 				DERPMap: *h.DERPMap, | ||||||
|  | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -721,7 +724,9 @@ func (h *Headscale) Serve() error { | |||||||
| 						Str("path", aclPath). | 						Str("path", aclPath). | ||||||
| 						Msg("ACL policy successfully reloaded, notifying nodes of change") | 						Msg("ACL policy successfully reloaded, notifying nodes of change") | ||||||
| 
 | 
 | ||||||
| 					h.nodeNotifier.NotifyAll() | 					h.nodeNotifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 						Type: types.StateFullUpdate, | ||||||
|  | 					}) | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 			default: | 			default: | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ import ( | |||||||
| 	"github.com/patrickmn/go-cache" | 	"github.com/patrickmn/go-cache" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
|  | 	"tailscale.com/tailcfg" | ||||||
| 	"tailscale.com/types/key" | 	"tailscale.com/types/key" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| @ -218,7 +219,10 @@ func (hsdb *HSDatabase) SetTags( | |||||||
| 	} | 	} | ||||||
| 	machine.ForcedTags = newTags | 	machine.ForcedTags = newTags | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||||
| 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | 		return fmt.Errorf("failed to update tags for machine in the database: %w", err) | ||||||
| @ -232,7 +236,10 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { | |||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 	machine.Expiry = &now | 	machine.Expiry = &now | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||||
| 		return fmt.Errorf("failed to expire machine in the database: %w", err) | 		return fmt.Errorf("failed to expire machine in the database: %w", err) | ||||||
| @ -259,7 +266,10 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er | |||||||
| 	} | 	} | ||||||
| 	machine.GivenName = newName | 	machine.GivenName = newName | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||||
| 		return fmt.Errorf("failed to rename machine in the database: %w", err) | 		return fmt.Errorf("failed to rename machine in the database: %w", err) | ||||||
| @ -275,7 +285,10 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) | |||||||
| 	machine.LastSuccessfulUpdate = &now | 	machine.LastSuccessfulUpdate = &now | ||||||
| 	machine.Expiry = &expiry | 	machine.Expiry = &expiry | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	if err := hsdb.db.Save(machine).Error; err != nil { | 	if err := hsdb.db.Save(machine).Error; err != nil { | ||||||
| 		return fmt.Errorf( | 		return fmt.Errorf( | ||||||
| @ -549,6 +562,27 @@ func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { | ||||||
|  | 	ret := make(map[tailcfg.NodeID]bool) | ||||||
|  | 
 | ||||||
|  | 	for _, peer := range peers { | ||||||
|  | 		ret[tailcfg.NodeID(peer.ID)] = peer.IsOnline() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return ret | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (hsdb *HSDatabase) ListOnlineMachines( | ||||||
|  | 	machine *types.Machine, | ||||||
|  | ) (map[tailcfg.NodeID]bool, error) { | ||||||
|  | 	peers, err := hsdb.ListPeers(machine) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return OnlineMachineMap(peers), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // enableRoutes enables new routes based on a list of new routes.
 | // enableRoutes enables new routes based on a list of new routes.
 | ||||||
| func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { | func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error { | ||||||
| 	newRoutes := make([]netip.Prefix, len(routeStrs)) | 	newRoutes := make([]netip.Prefix, len(routeStrs)) | ||||||
| @ -600,7 +634,10 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	hsdb.notifier.NotifyWithIgnore(machine.MachineKey) | 	hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ | ||||||
|  | 		Type:    types.StatePeerChanged, | ||||||
|  | 		Changed: []uint64{machine.ID}, | ||||||
|  | 	}, machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @ -676,12 +713,13 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		expiredFound := false | 		expired := make([]tailcfg.NodeID, 0) | ||||||
| 		for idx, machine := range machines { | 		for idx, machine := range machines { | ||||||
| 			if machine.IsEphemeral() && machine.LastSeen != nil && | 			if machine.IsEphemeral() && machine.LastSeen != nil && | ||||||
| 				time.Now(). | 				time.Now(). | ||||||
| 					After(machine.LastSeen.Add(inactivityThreshhold)) { | 					After(machine.LastSeen.Add(inactivityThreshhold)) { | ||||||
| 				expiredFound = true | 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||||
|  | 
 | ||||||
| 				log.Info(). | 				log.Info(). | ||||||
| 					Str("machine", machine.Hostname). | 					Str("machine", machine.Hostname). | ||||||
| 					Msg("Ephemeral client removed from database") | 					Msg("Ephemeral client removed from database") | ||||||
| @ -696,8 +734,11 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if expiredFound { | 		if len(expired) > 0 { | ||||||
| 			hsdb.notifier.NotifyAll() | 			hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 				Type:    types.StatePeerRemoved, | ||||||
|  | 				Removed: expired, | ||||||
|  | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -726,11 +767,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | |||||||
| 			return time.Unix(0, 0) | 			return time.Unix(0, 0) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		expiredFound := false | 		expired := make([]tailcfg.NodeID, 0) | ||||||
| 		for index, machine := range machines { | 		for index, machine := range machines { | ||||||
| 			if machine.IsExpired() && | 			if machine.IsExpired() && | ||||||
| 				machine.Expiry.After(lastCheck) { | 				machine.Expiry.After(lastCheck) { | ||||||
| 				expiredFound = true | 				expired = append(expired, tailcfg.NodeID(machine.ID)) | ||||||
| 
 | 
 | ||||||
| 				err := hsdb.ExpireMachine(&machines[index]) | 				err := hsdb.ExpireMachine(&machines[index]) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| @ -748,8 +789,11 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if expiredFound { | 		if len(expired) > 0 { | ||||||
| 			hsdb.notifier.NotifyAll() | 			hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 				Type:    types.StatePeerRemoved, | ||||||
|  | 				Removed: expired, | ||||||
|  | 			}) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -274,7 +274,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | |||||||
| 		log.Error().Err(err).Msg("error getting routes") | 		log.Error().Err(err).Msg("error getting routes") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	routesChanged := false | 	changedMachines := make([]uint64, 0) | ||||||
| 	for pos, route := range routes { | 	for pos, route := range routes { | ||||||
| 		if route.IsExitRoute() { | 		if route.IsExitRoute() { | ||||||
| 			continue | 			continue | ||||||
| @ -295,7 +295,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | |||||||
| 					return err | 					return err | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				routesChanged = true | 				changedMachines = append(changedMachines, route.MachineID) | ||||||
| 
 | 
 | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| @ -369,12 +369,15 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { | |||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			routesChanged = true | 			changedMachines = append(changedMachines, route.MachineID) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if routesChanged { | 	if len(changedMachines) > 0 { | ||||||
| 		hsdb.notifier.NotifyAll() | 		hsdb.notifier.NotifyAll(types.StateUpdate{ | ||||||
|  | 			Type:    types.StatePeerChanged, | ||||||
|  | 			Changed: changedMachines, | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| @ -129,45 +130,35 @@ func fullMapResponse( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Peers is always returned sorted by Node.ID.
 | ||||||
|  | 	sort.SliceStable(tailPeers, func(x, y int) bool { | ||||||
|  | 		return tailPeers[x].ID < tailPeers[y].ID | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	now := time.Now() | 	now := time.Now() | ||||||
| 
 | 
 | ||||||
| 	resp := tailcfg.MapResponse{ | 	resp := tailcfg.MapResponse{ | ||||||
| 		KeepAlive: false, | 		Node:  tailnode, | ||||||
| 		Node:      tailnode, |  | ||||||
| 
 |  | ||||||
| 		// TODO: Only send if updated
 |  | ||||||
| 		DERPMap: derpMap, |  | ||||||
| 
 |  | ||||||
| 		// TODO: Only send if updated
 |  | ||||||
| 		Peers: tailPeers, | 		Peers: tailPeers, | ||||||
| 
 | 
 | ||||||
| 		// TODO(kradalby): Implement:
 | 		DERPMap: derpMap, | ||||||
| 		// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
 |  | ||||||
| 		// PeersChanged
 |  | ||||||
| 		// PeersRemoved
 |  | ||||||
| 		// PeersChangedPatch
 |  | ||||||
| 		// PeerSeenChange
 |  | ||||||
| 		// OnlineChange
 |  | ||||||
| 
 | 
 | ||||||
| 		// TODO: Only send if updated
 |  | ||||||
| 		DNSConfig: dnsConfig, | 		DNSConfig: dnsConfig, | ||||||
|  | 		Domain:    baseDomain, | ||||||
| 
 | 
 | ||||||
| 		// TODO: Only send if updated
 | 		// Do not instruct clients to collect services we do not
 | ||||||
| 		Domain: baseDomain, |  | ||||||
| 
 |  | ||||||
| 		// Do not instruct clients to collect services, we do not
 |  | ||||||
| 		// support or do anything with them
 | 		// support or do anything with them
 | ||||||
| 		CollectServices: "false", | 		CollectServices: "false", | ||||||
| 
 | 
 | ||||||
| 		// TODO: Only send if updated
 |  | ||||||
| 		PacketFilter: policy.ReduceFilterRules(machine, rules), | 		PacketFilter: policy.ReduceFilterRules(machine, rules), | ||||||
| 
 | 
 | ||||||
| 		UserProfiles: profiles, | 		UserProfiles: profiles, | ||||||
| 
 | 
 | ||||||
| 		// TODO: Only send if updated
 |  | ||||||
| 		SSHPolicy: sshPolicy, | 		SSHPolicy: sshPolicy, | ||||||
| 
 | 
 | ||||||
| 		ControlTime: &now, | 		ControlTime:  &now, | ||||||
|  | 		KeepAlive:    false, | ||||||
|  | 		OnlineChange: db.OnlineMachineMap(peers), | ||||||
| 
 | 
 | ||||||
| 		Debug: &tailcfg.Debug{ | 		Debug: &tailcfg.Debug{ | ||||||
| 			DisableLogTail:      !logtail, | 			DisableLogTail:      !logtail, | ||||||
| @ -271,8 +262,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CreateMapResponse returns a MapResponse for the given machine.
 | // FullMapResponse returns a MapResponse for the given machine.
 | ||||||
| func (m Mapper) CreateMapResponse( | func (m Mapper) FullMapResponse( | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| 	pol *policy.ACLPolicy, | 	pol *policy.ACLPolicy, | ||||||
| @ -302,39 +293,107 @@ func (m Mapper) CreateMapResponse( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if m.isNoise { | 	if m.isNoise { | ||||||
| 		return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) | 		return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var machineKey key.MachinePublic | 	return m.marshalMapResponse(mapResponse, machine, mapRequest.Compress) | ||||||
| 	err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot parse client key") |  | ||||||
| 
 |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Mapper) CreateKeepAliveResponse( | func (m Mapper) KeepAliveResponse( | ||||||
| 	mapRequest tailcfg.MapRequest, | 	mapRequest tailcfg.MapRequest, | ||||||
| 	machine *types.Machine, | 	machine *types.Machine, | ||||||
| ) ([]byte, error) { | ) ([]byte, error) { | ||||||
| 	keepAliveResponse := tailcfg.MapResponse{ | 	resp := m.baseMapResponse(machine) | ||||||
| 		KeepAlive: true, | 	resp.KeepAlive = true | ||||||
|  | 
 | ||||||
|  | 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Mapper) DERPMapResponse( | ||||||
|  | 	mapRequest tailcfg.MapRequest, | ||||||
|  | 	machine *types.Machine, | ||||||
|  | 	derpMap tailcfg.DERPMap, | ||||||
|  | ) ([]byte, error) { | ||||||
|  | 	resp := m.baseMapResponse(machine) | ||||||
|  | 	resp.DERPMap = &derpMap | ||||||
|  | 
 | ||||||
|  | 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Mapper) PeerChangedResponse( | ||||||
|  | 	mapRequest tailcfg.MapRequest, | ||||||
|  | 	machine *types.Machine, | ||||||
|  | 	machineKeys []uint64, | ||||||
|  | 	pol *policy.ACLPolicy, | ||||||
|  | ) ([]byte, error) { | ||||||
|  | 	var err error | ||||||
|  | 	changed := make(types.Machines, len(machineKeys)) | ||||||
|  | 	lastSeen := make(map[tailcfg.NodeID]bool) | ||||||
|  | 	for idx, machineKey := range machineKeys { | ||||||
|  | 		peer, err := m.db.GetMachineByID(machineKey) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		changed[idx] = *peer | ||||||
|  | 
 | ||||||
|  | 		// We have just seen the node, let the peers update their list.
 | ||||||
|  | 		lastSeen[tailcfg.NodeID(peer.ID)] = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if m.isNoise { | 	rules, _, err := policy.GenerateFilterAndSSHRules( | ||||||
| 		return m.marshalMapResponse( | 		pol, | ||||||
| 			keepAliveResponse, | 		machine, | ||||||
| 			key.MachinePublic{}, | 		changed, | ||||||
| 			mapRequest.Compress, | 	) | ||||||
| 		) | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Filter out peers that have expired.
 | ||||||
|  | 	changed = lo.Filter(changed, func(item types.Machine, index int) bool { | ||||||
|  | 		return !item.IsExpired() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// If there are filter rules present, see if there are any machines that cannot
 | ||||||
|  | 	// access eachother at all and remove them from the changed.
 | ||||||
|  | 	if len(rules) > 0 { | ||||||
|  | 		changed = policy.FilterMachinesByACL(machine, changed, rules) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Peers is always returned sorted by Node.ID.
 | ||||||
|  | 	sort.SliceStable(tailPeers, func(x, y int) bool { | ||||||
|  | 		return tailPeers[x].ID < tailPeers[y].ID | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	resp := m.baseMapResponse(machine) | ||||||
|  | 	resp.PeersChanged = tailPeers | ||||||
|  | 	resp.PeerSeenChange = lastSeen | ||||||
|  | 
 | ||||||
|  | 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Mapper) PeerRemovedResponse( | ||||||
|  | 	mapRequest tailcfg.MapRequest, | ||||||
|  | 	machine *types.Machine, | ||||||
|  | 	removed []tailcfg.NodeID, | ||||||
|  | ) ([]byte, error) { | ||||||
|  | 	resp := m.baseMapResponse(machine) | ||||||
|  | 	resp.PeersRemoved = removed | ||||||
|  | 
 | ||||||
|  | 	return m.marshalMapResponse(&resp, machine, mapRequest.Compress) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m Mapper) marshalMapResponse( | ||||||
|  | 	resp *tailcfg.MapResponse, | ||||||
|  | 	machine *types.Machine, | ||||||
|  | 	compression string, | ||||||
|  | ) ([]byte, error) { | ||||||
| 	var machineKey key.MachinePublic | 	var machineKey key.MachinePublic | ||||||
| 	err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) | 	err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -346,40 +405,6 @@ func (m Mapper) CreateKeepAliveResponse( | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // MarshalResponse takes an Tailscale Response, marhsal it to JSON.
 |  | ||||||
| // If isNoise is set, then the JSON body will be returned
 |  | ||||||
| // If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
 |  | ||||||
| func MarshalResponse( |  | ||||||
| 	resp interface{}, |  | ||||||
| 	isNoise bool, |  | ||||||
| 	privateKey2019 *key.MachinePrivate, |  | ||||||
| 	machineKey key.MachinePublic, |  | ||||||
| ) ([]byte, error) { |  | ||||||
| 	jsonBody, err := json.Marshal(resp) |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error(). |  | ||||||
| 			Caller(). |  | ||||||
| 			Err(err). |  | ||||||
| 			Msg("Cannot marshal response") |  | ||||||
| 
 |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !isNoise && privateKey2019 != nil { |  | ||||||
| 		return privateKey2019.SealTo(machineKey, jsonBody), nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return jsonBody, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m Mapper) marshalMapResponse( |  | ||||||
| 	resp interface{}, |  | ||||||
| 	machineKey key.MachinePublic, |  | ||||||
| 	compression string, |  | ||||||
| ) ([]byte, error) { |  | ||||||
| 	jsonBody, err := json.Marshal(resp) | 	jsonBody, err := json.Marshal(resp) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| @ -409,6 +434,32 @@ func (m Mapper) marshalMapResponse( | |||||||
| 	return data, nil | 	return data, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // MarshalResponse takes an Tailscale Response, marhsal it to JSON.
 | ||||||
|  | // If isNoise is set, then the JSON body will be returned
 | ||||||
|  | // If !isNoise and privateKey2019 is set, the JSON body will be sealed in a Nacl box.
 | ||||||
|  | func MarshalResponse( | ||||||
|  | 	resp interface{}, | ||||||
|  | 	isNoise bool, | ||||||
|  | 	privateKey2019 *key.MachinePrivate, | ||||||
|  | 	machineKey key.MachinePublic, | ||||||
|  | ) ([]byte, error) { | ||||||
|  | 	jsonBody, err := json.Marshal(resp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Err(err). | ||||||
|  | 			Msg("Cannot marshal response") | ||||||
|  | 
 | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !isNoise && privateKey2019 != nil { | ||||||
|  | 		return privateKey2019.SealTo(machineKey, jsonBody), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return jsonBody, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func zstdEncode(in []byte) []byte { | func zstdEncode(in []byte) []byte { | ||||||
| 	encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) | 	encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| @ -433,3 +484,19 @@ var zstdEncoderPool = &sync.Pool{ | |||||||
| 		return encoder | 		return encoder | ||||||
| 	}, | 	}, | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (m *Mapper) baseMapResponse(machine *types.Machine) tailcfg.MapResponse { | ||||||
|  | 	now := time.Now() | ||||||
|  | 
 | ||||||
|  | 	resp := tailcfg.MapResponse{ | ||||||
|  | 		KeepAlive:   false, | ||||||
|  | 		ControlTime: &now, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	online, err := m.db.ListOnlineMachines(machine) | ||||||
|  | 	if err == nil { | ||||||
|  | 		resp.OnlineChange = online | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return resp | ||||||
|  | } | ||||||
|  | |||||||
| @ -387,6 +387,7 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 				DNSConfig:       &tailcfg.DNSConfig{}, | 				DNSConfig:       &tailcfg.DNSConfig{}, | ||||||
| 				Domain:          "", | 				Domain:          "", | ||||||
| 				CollectServices: "false", | 				CollectServices: "false", | ||||||
|  | 				OnlineChange:    map[tailcfg.NodeID]bool{tailPeer1.ID: false}, | ||||||
| 				PacketFilter:    []tailcfg.FilterRule{}, | 				PacketFilter:    []tailcfg.FilterRule{}, | ||||||
| 				UserProfiles:    []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, | 				UserProfiles:    []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, | ||||||
| 				SSHPolicy:       &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, | 				SSHPolicy:       &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, | ||||||
| @ -428,6 +429,7 @@ func Test_fullMapResponse(t *testing.T) { | |||||||
| 				DNSConfig:       &tailcfg.DNSConfig{}, | 				DNSConfig:       &tailcfg.DNSConfig{}, | ||||||
| 				Domain:          "", | 				Domain:          "", | ||||||
| 				CollectServices: "false", | 				CollectServices: "false", | ||||||
|  | 				OnlineChange:    map[tailcfg.NodeID]bool{tailPeer1.ID: false}, | ||||||
| 				PacketFilter: []tailcfg.FilterRule{ | 				PacketFilter: []tailcfg.FilterRule{ | ||||||
| 					{ | 					{ | ||||||
| 						SrcIPs: []string{"100.64.0.2/32"}, | 						SrcIPs: []string{"100.64.0.2/32"}, | ||||||
|  | |||||||
| @ -3,24 +3,25 @@ package notifier | |||||||
| import ( | import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Notifier struct { | type Notifier struct { | ||||||
| 	l     sync.RWMutex | 	l     sync.RWMutex | ||||||
| 	nodes map[string]chan<- struct{} | 	nodes map[string]chan<- types.StateUpdate | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewNotifier() *Notifier { | func NewNotifier() *Notifier { | ||||||
| 	return &Notifier{} | 	return &Notifier{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) { | func (n *Notifier) AddNode(machineKey string, c chan<- types.StateUpdate) { | ||||||
| 	n.l.Lock() | 	n.l.Lock() | ||||||
| 	defer n.l.Unlock() | 	defer n.l.Unlock() | ||||||
| 
 | 
 | ||||||
| 	if n.nodes == nil { | 	if n.nodes == nil { | ||||||
| 		n.nodes = make(map[string]chan<- struct{}) | 		n.nodes = make(map[string]chan<- types.StateUpdate) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	n.nodes[machineKey] = c | 	n.nodes[machineKey] = c | ||||||
| @ -37,11 +38,11 @@ func (n *Notifier) RemoveNode(machineKey string) { | |||||||
| 	delete(n.nodes, machineKey) | 	delete(n.nodes, machineKey) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (n *Notifier) NotifyAll() { | func (n *Notifier) NotifyAll(update types.StateUpdate) { | ||||||
| 	n.NotifyWithIgnore() | 	n.NotifyWithIgnore(update) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (n *Notifier) NotifyWithIgnore(ignore ...string) { | func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { | ||||||
| 	n.l.RLock() | 	n.l.RLock() | ||||||
| 	defer n.l.RUnlock() | 	defer n.l.RUnlock() | ||||||
| 
 | 
 | ||||||
| @ -50,6 +51,6 @@ func (n *Notifier) NotifyWithIgnore(ignore ...string) { | |||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		c <- struct{}{} | 		c <- update | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -116,7 +116,7 @@ func (h *Headscale) handlePoll( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) | 	mapResp, err := mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logErr(err, "Failed to create MapResponse") | 		logErr(err, "Failed to create MapResponse") | ||||||
| 		http.Error(writer, "", http.StatusInternalServerError) | 		http.Error(writer, "", http.StatusInternalServerError) | ||||||
| @ -163,7 +163,12 @@ func (h *Headscale) handlePoll( | |||||||
| 			Inc() | 			Inc() | ||||||
| 
 | 
 | ||||||
| 		// Tell all the other nodes about the new endpoint, but dont update ourselves.
 | 		// Tell all the other nodes about the new endpoint, but dont update ourselves.
 | ||||||
| 		h.nodeNotifier.NotifyWithIgnore(machine.MachineKey) | 		h.nodeNotifier.NotifyWithIgnore( | ||||||
|  | 			types.StateUpdate{ | ||||||
|  | 				Type:    types.StatePeerChanged, | ||||||
|  | 				Changed: []uint64{machine.ID}, | ||||||
|  | 			}, | ||||||
|  | 			machine.MachineKey) | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} else if mapRequest.OmitPeers && mapRequest.Stream { | 	} else if mapRequest.OmitPeers && mapRequest.Stream { | ||||||
| @ -220,7 +225,7 @@ func (h *Headscale) pollNetMapStream( | |||||||
| 	keepAliveTicker := time.NewTicker(keepAliveInterval) | 	keepAliveTicker := time.NewTicker(keepAliveInterval) | ||||||
| 
 | 
 | ||||||
| 	const chanSize = 8 | 	const chanSize = 8 | ||||||
| 	updateChan := make(chan struct{}, chanSize) | 	updateChan := make(chan types.StateUpdate, chanSize) | ||||||
| 
 | 
 | ||||||
| 	h.pollNetMapStreamWG.Add(1) | 	h.pollNetMapStreamWG.Add(1) | ||||||
| 	defer h.pollNetMapStreamWG.Done() | 	defer h.pollNetMapStreamWG.Done() | ||||||
| @ -238,7 +243,7 @@ func (h *Headscale) pollNetMapStream( | |||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		case <-keepAliveTicker.C: | 		case <-keepAliveTicker.C: | ||||||
| 			data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) | 			data, err := mapp.KeepAliveResponse(mapRequest, machine) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logErr(err, "Error generating the keep alive msg") | 				logErr(err, "Error generating the keep alive msg") | ||||||
| 
 | 
 | ||||||
| @ -263,10 +268,23 @@ func (h *Headscale) pollNetMapStream( | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 		case <-updateChan: | 		case update := <-updateChan: | ||||||
| 			data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) | 			var data []byte | ||||||
|  | 			var err error | ||||||
|  | 
 | ||||||
|  | 			switch update.Type { | ||||||
|  | 			case types.StateFullUpdate: | ||||||
|  | 				data, err = mapp.FullMapResponse(mapRequest, machine, h.ACLPolicy) | ||||||
|  | 			case types.StatePeerChanged: | ||||||
|  | 				data, err = mapp.PeerChangedResponse(mapRequest, machine, update.Changed, h.ACLPolicy) | ||||||
|  | 			case types.StatePeerRemoved: | ||||||
|  | 				data, err = mapp.PeerRemovedResponse(mapRequest, machine, update.Removed) | ||||||
|  | 			case types.StateDERPUpdated: | ||||||
|  | 				data, err = mapp.DERPMapResponse(mapRequest, machine, update.DERPMap) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				logErr(err, "Could not get the map update") | 				logErr(err, "Could not get the create map update") | ||||||
| 
 | 
 | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @ -317,7 +335,7 @@ func (h *Headscale) pollNetMapStream( | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { | func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, machine, name string) { | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Str("handler", "PollNetMap"). | 		Str("handler", "PollNetMap"). | ||||||
| 		Str("machine", machine). | 		Str("machine", machine). | ||||||
|  | |||||||
| @ -106,3 +106,32 @@ func (i StringList) Value() (driver.Value, error) { | |||||||
| 
 | 
 | ||||||
| 	return string(bytes), err | 	return string(bytes), err | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type StateUpdateType int | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	StateFullUpdate StateUpdateType = iota | ||||||
|  | 	StatePeerChanged | ||||||
|  | 	StatePeerRemoved | ||||||
|  | 	StateDERPUpdated | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // StateUpdate is an internal message containing information about
 | ||||||
|  | // a state change that has happened to the network.
 | ||||||
|  | type StateUpdate struct { | ||||||
|  | 	// The type of update
 | ||||||
|  | 	Type StateUpdateType | ||||||
|  | 
 | ||||||
|  | 	// Changed must be set when Type is StatePeerChanged and
 | ||||||
|  | 	// contain the Machine IDs of machines that has changed.
 | ||||||
|  | 	Changed []uint64 | ||||||
|  | 
 | ||||||
|  | 	// Removed must be set when Type is StatePeerRemoved and
 | ||||||
|  | 	// contain a list of the nodes that has been removed from
 | ||||||
|  | 	// the network.
 | ||||||
|  | 	Removed []tailcfg.NodeID | ||||||
|  | 
 | ||||||
|  | 	// DERPMap must be set when Type is StateDERPUpdated and
 | ||||||
|  | 	// contain the new DERP Map.
 | ||||||
|  | 	DERPMap tailcfg.DERPMap | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user