Mirror of https://github.com/juanfont/headscale.git, synced 2025-10-28 10:51:44 +01:00
move MapResponse peer logic into function and reuse

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

parent: 387aa03adb
commit: 432e975a7f
@@ -92,6 +92,8 @@ type Headscale struct {
 	shutdownChan       chan struct{}
 	pollNetMapStreamWG sync.WaitGroup
+
+	pollStreamOpenMu sync.Mutex
 }
 
 func NewHeadscale(cfg *types.Config) (*Headscale, error) {
@@ -340,6 +340,8 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 			continue
 		}
 
+		machine := &route.Machine
+
 		if !route.IsPrimary {
 			_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
 			if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
@@ -355,7 +357,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 					return err
 				}
 
-				changedMachines = append(changedMachines, &route.Machine)
+				changedMachines = append(changedMachines, machine)
 
 				continue
 			}
@@ -429,7 +431,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
 				return err
 			}
 
-			changedMachines = append(changedMachines, &route.Machine)
+			changedMachines = append(changedMachines, machine)
 		}
 	}
 
@@ -38,6 +38,16 @@ const (
 
 var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
 
+// TODO: Optimise
+// As this work continues, the idea is that there will be one Mapper instance
+// per node, attached to the open stream between the control and client.
+// This means that this can hold a state per machine and we can use that to
+// improve the mapresponses sent.
+// We could:
+// - Keep information about the previous mapresponse so we can send a diff
+// - Store hashes
+// - Create a "minifier" that removes info not needed for the node
+
 type Mapper struct {
 	privateKey2019 *key.MachinePrivate
 	isNoise        bool
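The TODO block above is design intent, not code that exists in this commit. Purely as a loose, self-contained sketch of the "keep the previous mapresponse so we can send a diff" idea it describes (simplified stand-in types and hypothetical names, nothing from headscale's actual API):

package main

import "fmt"

// peerState is a stand-in for whatever per-peer data a MapResponse carries.
type peerState struct {
	ID       uint64
	Hostname string
}

// diffMapper remembers what it last sent for each peer so the next update
// can be limited to peers that actually changed.
type diffMapper struct {
	lastSent map[uint64]peerState
}

// changedPeers returns only the peers that differ from the previous send
// and records the new state for the next round.
func (m *diffMapper) changedPeers(current []peerState) []peerState {
	var changed []peerState
	for _, p := range current {
		if prev, ok := m.lastSent[p.ID]; !ok || prev != p {
			changed = append(changed, p)
		}
		m.lastSent[p.ID] = p
	}

	return changed
}

func main() {
	m := &diffMapper{lastSent: map[uint64]peerState{}}

	first := []peerState{{1, "node-a"}, {2, "node-b"}}
	fmt.Println(m.changedPeers(first)) // both peers are new

	second := []peerState{{1, "node-a"}, {2, "node-b-renamed"}}
	fmt.Println(m.changedPeers(second)) // only peer 2 changed
}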
@@ -102,105 +112,6 @@ func (m *Mapper) String() string {
 	return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
 }
 
-// TODO: Optimise
-// As this work continues, the idea is that there will be one Mapper instance
-// per node, attached to the open stream between the control and client.
-// This means that this can hold a state per machine and we can use that to
-// improve the mapresponses sent.
-// We could:
-// - Keep information about the previous mapresponse so we can send a diff
-// - Store hashes
-// - Create a "minifier" that removes info not needed for the node
-
-// fullMapResponse is the internal function for generating a MapResponse
-// for a machine.
-func fullMapResponse(
-	pol *policy.ACLPolicy,
-	machine *types.Machine,
-	peers types.Machines,
-
-	baseDomain string,
-	dnsCfg *tailcfg.DNSConfig,
-	derpMap *tailcfg.DERPMap,
-	logtail bool,
-	randomClientPort bool,
-) (*tailcfg.MapResponse, error) {
-	tailnode, err := tailNode(machine, pol, dnsCfg, baseDomain)
-	if err != nil {
-		return nil, err
-	}
-
-	now := time.Now()
-
-	resp := tailcfg.MapResponse{
-		Node: tailnode,
-
-		DERPMap: derpMap,
-
-		Domain: baseDomain,
-
-		// Do not instruct clients to collect services we do not
-		// support or do anything with them
-		CollectServices: "false",
-
-		ControlTime:  &now,
-		KeepAlive:    false,
-		OnlineChange: db.OnlineMachineMap(peers),
-
-		Debug: &tailcfg.Debug{
-			DisableLogTail:      !logtail,
-			RandomizeClientPort: randomClientPort,
-		},
-	}
-
-	if peers != nil || len(peers) > 0 {
-		rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
-			pol,
-			machine,
-			peers,
-		)
-		if err != nil {
-			return nil, err
-		}
-
-		// Filter out peers that have expired.
-		peers = filterExpiredAndNotReady(peers)
-
-		// If there are filter rules present, see if there are any machines that cannot
-		// access eachother at all and remove them from the peers.
-		if len(rules) > 0 {
-			peers = policy.FilterMachinesByACL(machine, peers, rules)
-		}
-
-		profiles := generateUserProfiles(machine, peers, baseDomain)
-
-		dnsConfig := generateDNSConfig(
-			dnsCfg,
-			baseDomain,
-			machine,
-			peers,
-		)
-
-		tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
-		if err != nil {
-			return nil, err
-		}
-
-		// Peers is always returned sorted by Node.ID.
-		sort.SliceStable(tailPeers, func(x, y int) bool {
-			return tailPeers[x].ID < tailPeers[y].ID
-		})
-
-		resp.Peers = tailPeers
-		resp.DNSConfig = dnsConfig
-		resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
-		resp.UserProfiles = profiles
-		resp.SSHPolicy = sshPolicy
-	}
-
-	return &resp, nil
-}
-
 func generateUserProfiles(
 	machine *types.Machine,
 	peers types.Machines,
@@ -294,6 +205,38 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine *types.Machine) {
 	}
 }
 
+// fullMapResponse creates a complete MapResponse for a node.
+// It is a separate function to make testing easier.
+func (m *Mapper) fullMapResponse(
+	machine *types.Machine,
+	pol *policy.ACLPolicy,
+) (*tailcfg.MapResponse, error) {
+	peers := machineMapToList(m.peers)
+
+	resp, err := m.baseWithConfigMapResponse(machine, pol)
+	if err != nil {
+		return nil, err
+	}
+
+	// TODO(kradalby): Move this into appendPeerChanges?
+	resp.OnlineChange = db.OnlineMachineMap(peers)
+
+	err = appendPeerChanges(
+		resp,
+		pol,
+		machine,
+		peers,
+		peers,
+		m.baseDomain,
+		m.dnsCfg,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp, nil
+}
+
 // FullMapResponse returns a MapResponse for the given machine.
 func (m *Mapper) FullMapResponse(
 	mapRequest tailcfg.MapRequest,
@@ -303,25 +246,16 @@ func (m *Mapper) FullMapResponse(
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	mapResponse, err := fullMapResponse(
-		pol,
-		machine,
-		machineMapToList(m.peers),
-		m.baseDomain,
-		m.dnsCfg,
-		m.derpMap,
-		m.logtail,
-		m.randomClientPort,
-	)
+	resp, err := m.fullMapResponse(machine, pol)
 	if err != nil {
 		return nil, err
 	}
 
 	if m.isNoise {
-		return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
+		return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 	}
 
-	return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
+	return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 }
 
 // LiteMapResponse returns a MapResponse for the given machine.
@@ -332,32 +266,23 @@ func (m *Mapper) LiteMapResponse(
 	machine *types.Machine,
 	pol *policy.ACLPolicy,
 ) ([]byte, error) {
-	mapResponse, err := fullMapResponse(
-		pol,
-		machine,
-		nil,
-		m.baseDomain,
-		m.dnsCfg,
-		m.derpMap,
-		m.logtail,
-		m.randomClientPort,
-	)
+	resp, err := m.baseWithConfigMapResponse(machine, pol)
 	if err != nil {
 		return nil, err
 	}
 
 	if m.isNoise {
-		return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
+		return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 	}
 
-	return m.marshalMapResponse(mapRequest, mapResponse, machine, mapRequest.Compress)
+	return m.marshalMapResponse(mapRequest, resp, machine, mapRequest.Compress)
 }
 
 func (m *Mapper) KeepAliveResponse(
 	mapRequest tailcfg.MapRequest,
 	machine *types.Machine,
 ) ([]byte, error) {
-	resp := m.baseMapResponse(machine)
+	resp := m.baseMapResponse()
 	resp.KeepAlive = true
 
 	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@@ -368,7 +293,7 @@ func (m *Mapper) DERPMapResponse(
 	machine *types.Machine,
 	derpMap tailcfg.DERPMap,
 ) ([]byte, error) {
-	resp := m.baseMapResponse(machine)
+	resp := m.baseMapResponse()
 	resp.DERPMap = &derpMap
 
 	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@@ -383,7 +308,6 @@ func (m *Mapper) PeerChangedResponse(
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	var err error
 	lastSeen := make(map[tailcfg.NodeID]bool)
 
 	// Update our internal map.
@@ -394,37 +318,21 @@ func (m *Mapper) PeerChangedResponse(
 		lastSeen[tailcfg.NodeID(machine.ID)] = true
 	}
 
-	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
+	resp := m.baseMapResponse()
+
+	err := appendPeerChanges(
+		&resp,
 		pol,
 		machine,
 		machineMapToList(m.peers),
+		changed,
+		m.baseDomain,
+		m.dnsCfg,
 	)
 	if err != nil {
 		return nil, err
 	}
 
-	changed = filterExpiredAndNotReady(changed)
-
-	// If there are filter rules present, see if there are any machines that cannot
-	// access eachother at all and remove them from the changed.
-	if len(rules) > 0 {
-		changed = policy.FilterMachinesByACL(machine, changed, rules)
-	}
-
-	tailPeers, err := tailNodes(changed, pol, m.dnsCfg, m.baseDomain)
-	if err != nil {
-		return nil, err
-	}
-
-	// Peers is always returned sorted by Node.ID.
-	sort.SliceStable(tailPeers, func(x, y int) bool {
-		return tailPeers[x].ID < tailPeers[y].ID
-	})
-
-	resp := m.baseMapResponse(machine)
-	resp.PeersChanged = tailPeers
-	resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
-	resp.SSHPolicy = sshPolicy
 	// resp.PeerSeenChange = lastSeen
 
 	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@@ -443,7 +351,7 @@ func (m *Mapper) PeerRemovedResponse(
 		delete(m.peers, uint64(id))
 	}
 
-	resp := m.baseMapResponse(machine)
+	resp := m.baseMapResponse()
 	resp.PeersRemoved = removed
 
 	return m.marshalMapResponse(mapRequest, &resp, machine, mapRequest.Compress)
@@ -497,7 +405,7 @@ func (m *Mapper) marshalMapResponse(
 			panic(err)
 		}
 
-		now := time.Now().Unix()
+		now := time.Now().UnixNano()
 
 		mapResponsePath := path.Join(
 			mPath,
@@ -583,7 +491,9 @@ var zstdEncoderPool = &sync.Pool{
 	},
 }
 
-func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
+// baseMapResponse returns a tailcfg.MapResponse with
+// KeepAlive false and ControlTime set to now.
+func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
 	now := time.Now()
 
 	resp := tailcfg.MapResponse{
@@ -591,14 +501,43 @@ func (m *Mapper) baseMapResponse(_ *types.Machine) tailcfg.MapResponse {
 		ControlTime: &now,
 	}
 
-	// online, err := m.db.ListOnlineMachines(machine)
-	// if err == nil {
-	// 	resp.OnlineChange = online
-	// }
-
 	return resp
 }
 
+// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
+// with the basic configuration from headscale set.
+// It is used in for bigger updates, such as full and lite, not
+// incremental.
+func (m *Mapper) baseWithConfigMapResponse(
+	machine *types.Machine,
+	pol *policy.ACLPolicy,
+) (*tailcfg.MapResponse, error) {
+	resp := m.baseMapResponse()
+
+	tailnode, err := tailNode(machine, pol, m.dnsCfg, m.baseDomain)
+	if err != nil {
+		return nil, err
+	}
+	resp.Node = tailnode
+
+	resp.DERPMap = m.derpMap
+
+	resp.Domain = m.baseDomain
+
+	// Do not instruct clients to collect services we do not
+	// support or do anything with them
+	resp.CollectServices = "false"
+
+	resp.KeepAlive = false
+
+	resp.Debug = &tailcfg.Debug{
+		DisableLogTail:      !m.logtail,
+		RandomizeClientPort: m.randomClientPort,
+	}
+
+	return &resp, nil
+}
+
 func machineMapToList(machines map[uint64]*types.Machine) types.Machines {
 	ret := make(types.Machines, 0)
 
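For orientation, the following is a small self-contained sketch of the two-layer split the hunk above introduces, using simplified stand-in types rather than headscale's tailcfg and types packages: a bare base used by keep-alive, DERP and peer-removed responses, and a config-enriched base shared by the full and lite paths.

package main

import (
	"fmt"
	"time"
)

// mapResponse and mapper are simplified stand-ins, not headscale's types.
type mapResponse struct {
	KeepAlive   bool
	ControlTime *time.Time
	Domain      string
	DERPMap     string
}

type mapper struct {
	baseDomain string
	derpMap    string
}

// The bare base: KeepAlive false and ControlTime set to now.
func (m *mapper) baseMapResponse() mapResponse {
	now := time.Now()

	return mapResponse{KeepAlive: false, ControlTime: &now}
}

// The configured base: the bare base plus static configuration.
func (m *mapper) baseWithConfigMapResponse() mapResponse {
	resp := m.baseMapResponse()
	resp.Domain = m.baseDomain
	resp.DERPMap = m.derpMap

	return resp
}

func main() {
	m := &mapper{baseDomain: "example.net", derpMap: "derp-region-1"}

	keepAlive := m.baseMapResponse() // small updates start from the bare base
	keepAlive.KeepAlive = true

	full := m.baseWithConfigMapResponse() // full and lite responses start here

	fmt.Println(keepAlive.KeepAlive, full.Domain, full.DERPMap)
}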
@@ -617,3 +556,67 @@ func filterExpiredAndNotReady(peers types.Machines) types.Machines {
 		return !item.IsExpired() || len(item.Endpoints) > 0
 	})
 }
+
+// appendPeerChanges mutates a tailcfg.MapResponse with all the
+// necessary changes when peers have changed.
+func appendPeerChanges(
+	resp *tailcfg.MapResponse,
+
+	pol *policy.ACLPolicy,
+	machine *types.Machine,
+	peers types.Machines,
+	changed types.Machines,
+	baseDomain string,
+	dnsCfg *tailcfg.DNSConfig,
+) error {
+	fullChange := len(peers) == len(changed)
+
+	rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
+		pol,
+		machine,
+		peers,
+	)
+	if err != nil {
+		return err
+	}
+
+	// Filter out peers that have expired.
+	changed = filterExpiredAndNotReady(changed)
+
+	// If there are filter rules present, see if there are any machines that cannot
+	// access eachother at all and remove them from the peers.
+	if len(rules) > 0 {
+		changed = policy.FilterMachinesByACL(machine, changed, rules)
+	}
+
+	profiles := generateUserProfiles(machine, changed, baseDomain)
+
+	dnsConfig := generateDNSConfig(
+		dnsCfg,
+		baseDomain,
+		machine,
+		peers,
+	)
+
+	tailPeers, err := tailNodes(changed, pol, dnsCfg, baseDomain)
+	if err != nil {
+		return err
+	}
+
+	// Peers is always returned sorted by Node.ID.
+	sort.SliceStable(tailPeers, func(x, y int) bool {
+		return tailPeers[x].ID < tailPeers[y].ID
+	})
+
+	if fullChange {
+		resp.Peers = tailPeers
+	} else {
+		resp.PeersChanged = tailPeers
+	}
+	resp.DNSConfig = dnsConfig
+	resp.PacketFilter = policy.ReduceFilterRules(machine, rules)
+	resp.UserProfiles = profiles
+	resp.SSHPolicy = sshPolicy
+
+	return nil
+}
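As a rough illustration of the decision appendPeerChanges encodes, here is a self-contained sketch with simplified stand-in types that mirrors only the sort and the full-versus-partial choice: peers land in Peers when every known peer changed (a full update) and in PeersChanged otherwise, sorted by ID in both cases. Filtering and policy handling are left out.

package main

import (
	"fmt"
	"sort"
)

type node struct{ ID int }

type mapResponse struct {
	Peers        []node
	PeersChanged []node
}

// appendPeerChanges mirrors only the sort and the full-versus-partial
// decision from the real function above.
func appendPeerChanges(resp *mapResponse, peers, changed []node) {
	fullChange := len(peers) == len(changed)

	// Peers are always returned sorted by ID.
	sort.SliceStable(changed, func(x, y int) bool {
		return changed[x].ID < changed[y].ID
	})

	if fullChange {
		resp.Peers = changed
	} else {
		resp.PeersChanged = changed
	}
}

func main() {
	all := []node{{3}, {1}, {2}}

	var full mapResponse
	appendPeerChanges(&full, all, all) // every peer changed: a full update
	fmt.Println("full:", full.Peers)

	var incremental mapResponse
	appendPeerChanges(&incremental, all, []node{{2}}) // one peer changed
	fmt.Println("incremental:", incremental.PeersChanged)
}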
@@ -441,9 +441,11 @@ func Test_fullMapResponse(t *testing.T) {
 					},
 				},
 			},
-			UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
-			SSHPolicy:    &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
-			ControlTime:  &time.Time{},
+			UserProfiles: []tailcfg.UserProfile{
+				{LoginName: "mini", DisplayName: "mini"},
+			},
+			SSHPolicy:   &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
+			ControlTime: &time.Time{},
 			Debug: &tailcfg.Debug{
 				DisableLogTail: true,
 			},
@@ -454,17 +456,23 @@ func Test_fullMapResponse(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			got, err := fullMapResponse(
-				tt.pol,
+			mappy := NewMapper(
 				tt.machine,
 				tt.peers,
+				nil,
+				false,
+				tt.derpMap,
 				tt.baseDomain,
 				tt.dnsConfig,
-				tt.derpMap,
 				tt.logtail,
 				tt.randomClientPort,
 			)
 
+			got, err := mappy.fullMapResponse(
+				tt.machine,
+				tt.pol,
+			)
+
 			if (err != nil) != tt.wantErr {
 				t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr)
@@ -55,6 +55,8 @@ func logPollFunc(
 
 // handlePoll is the common code for the legacy and Noise protocols to
 // managed the poll loop.
+//
+//nolint:gocyclo
 func (h *Headscale) handlePoll(
 	writer http.ResponseWriter,
 	ctx context.Context,
@@ -67,6 +69,7 @@ func (h *Headscale) handlePoll(
 	// following updates missing
 	var updateChan chan types.StateUpdate
 	if mapRequest.Stream {
+		h.pollStreamOpenMu.Lock()
 		h.pollNetMapStreamWG.Add(1)
 		defer h.pollNetMapStreamWG.Done()
 
@@ -251,6 +254,8 @@ func (h *Headscale) handlePoll(
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
+	h.pollStreamOpenMu.Unlock()
+
 	for {
 		logInfo("Waiting for update on stream channel")
 		select {
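The two poll.go hunks above take pollStreamOpenMu at the start of the streaming branch and release it just before the update loop. A minimal sketch of that locking pattern follows; it is not the actual handler, and it assumes the intent is to let only one long-poll stream be set up at a time.

package main

import (
	"fmt"
	"sync"
	"time"
)

// server is a stand-in for Headscale with the two fields the hunks touch.
type server struct {
	pollStreamOpenMu   sync.Mutex
	pollNetMapStreamWG sync.WaitGroup
}

func (s *server) handleStream(id int) {
	s.pollStreamOpenMu.Lock() // taken at the top of the streaming branch
	s.pollNetMapStreamWG.Add(1)
	defer s.pollNetMapStreamWG.Done()

	// ... per-stream setup (channels, registration, context) happens here,
	// serialised against other streams being opened ...
	fmt.Println("stream set up:", id)

	s.pollStreamOpenMu.Unlock() // released just before the update loop starts

	// ... the long-running "wait for update" loop would follow ...
	time.Sleep(10 * time.Millisecond)
}

func main() {
	s := &server{}

	var done sync.WaitGroup
	for i := 0; i < 3; i++ {
		done.Add(1)
		go func(id int) {
			defer done.Done()
			s.handleStream(id)
		}(i)
	}
	done.Wait()
}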
@@ -407,9 +407,8 @@ func TestResolveMagicDNS(t *testing.T) {
 	defer scenario.Shutdown()
 
 	spec := map[string]int{
-		// Omit 1.16.2 (-1) because it does not have the FQDN field
-		"magicdns1": len(MustTestVersions) - 1,
-		"magicdns2": len(MustTestVersions) - 1,
+		"magicdns1": len(MustTestVersions),
+		"magicdns2": len(MustTestVersions),
 	}
 
 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns"))
@@ -20,10 +20,11 @@ import (
 )
 
 const (
-	tsicHashLength    = 6
-	defaultPingCount  = 10
-	dockerContextPath = "../."
-	headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
+	tsicHashLength     = 6
+	defaultPingTimeout = 300 * time.Millisecond
+	defaultPingCount   = 10
+	dockerContextPath  = "../."
+	headscaleCertPath  = "/usr/local/share/ca-certificates/headscale.crt"
 )
 
 var (
@@ -591,7 +592,7 @@ func WithPingUntilDirect(direct bool) PingOption {
 // TODO(kradalby): Make multiping, go routine magic.
 func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) error {
 	args := pingArgs{
-		timeout: 300 * time.Millisecond,
+		timeout: defaultPingTimeout,
 		count:   defaultPingCount,
 		direct:  true,
 	}