diff --git a/api.go b/api.go index 72cb92f6..90d9be2b 100644 --- a/api.go +++ b/api.go @@ -242,11 +242,20 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma DisplayName: m.Namespace.Name, } + nodePeers, err := peers.toNodes(true) + if err != nil { + log.Error(). + Str("func", "getMapResponse"). + Err(err). + Msg("Failed to convert peers to Tailscale nodes") + return nil, err + } + resp := tailcfg.MapResponse{ KeepAlive: false, Node: node, - Peers: *peers, - //TODO(kradalby): As per tailscale docs, if DNSConfig is nil, + Peers: nodePeers, + // TODO(kradalby): As per tailscale docs, if DNSConfig is nil, // it means its not updated, maybe we can have some logic // to check and only pass updates when its updates. // This is probably more relevant if we try to implement diff --git a/machine.go b/machine.go index f49c5a89..c87aba81 100644 --- a/machine.go +++ b/machine.go @@ -6,6 +6,7 @@ import ( "fmt" "sort" "strconv" + "strings" "time" "github.com/rs/zerolog/log" @@ -45,11 +46,329 @@ type Machine struct { DeletedAt *time.Time } +type ( + Machines []Machine + MachinesP []*Machine +) + // For the time being this method is rather naive func (m Machine) isAlreadyRegistered() bool { return m.Registered } +func (h *Headscale) getDirectPeers(m *Machine) (MachinesP, error) { + log.Trace(). + Str("func", "getDirectPeers"). + Str("machine", m.Name). + Msg("Finding peers") + + machines := []Machine{} + if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", + m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { + log.Error().Err(err).Msg("Error accessing db") + return nil, err + } + + peers := make(MachinesP, 0) + for _, peer := range machines { + peers = append(peers, &peer) + } + + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) + + log.Trace(). + Str("func", "getDirectPeers"). + Str("machine", m.Name). + Msgf("Found peers: %s", peers.String()) + return peers, nil +} + +func (h *Headscale) getShared(m *Machine) (MachinesP, error) { + log.Trace(). + Str("func", "getShared"). + Str("machine", m.Name). + Msg("Finding shared peers") + + // We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for + sharedMachines := []SharedMachine{} + if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", + m.NamespaceID).Find(&sharedMachines).Error; err != nil { + return nil, err + } + + peers := make(MachinesP, 0) + for _, sharedMachine := range sharedMachines { + peers = append(peers, &sharedMachine.Machine) + } + + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) + + log.Trace(). + Str("func", "getShared"). + Str("machine", m.Name). + Msgf("Found shared peers: %s", peers.String()) + return peers, nil +} + +func (h *Headscale) getPeers(m *Machine) (MachinesP, error) { + direct, err := h.getDirectPeers(m) + if err != nil { + log.Error(). + Str("func", "getPeers"). + Err(err). + Msg("Cannot fetch peers") + return nil, err + } + + shared, err := h.getShared(m) + if err != nil { + log.Error(). + Str("func", "getDirectPeers"). + Err(err). + Msg("Cannot fetch peers") + return nil, err + } + + return append(direct, shared...), nil +} + +// GetMachine finds a Machine by name and namespace and returns the Machine struct +func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { + machines, err := h.ListMachinesInNamespace(namespace) + if err != nil { + return nil, err + } + + for _, m := range *machines { + if m.Name == name { + return &m, nil + } + } + return nil, fmt.Errorf("machine not found") +} + +// GetMachineByID finds a Machine by ID and returns the Machine struct +func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { + m := Machine{} + if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { + return nil, result.Error + } + return &m, nil +} + +// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct +func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { + m := Machine{} + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { + return nil, result.Error + } + return &m, nil +} + +// UpdateMachine takes a Machine struct pointer (typically already loaded from database +// and updates it with the latest data from the database. +func (h *Headscale) UpdateMachine(m *Machine) error { + if result := h.db.Find(m).First(&m); result.Error != nil { + return result.Error + } + return nil +} + +// DeleteMachine softs deletes a Machine from the database +func (h *Headscale) DeleteMachine(m *Machine) error { + m.Registered = false + namespaceID := m.NamespaceID + h.db.Save(&m) // we mark it as unregistered, just in case + if err := h.db.Delete(&m).Error; err != nil { + return err + } + + return h.RequestMapUpdates(namespaceID) +} + +// HardDeleteMachine hard deletes a Machine from the database +func (h *Headscale) HardDeleteMachine(m *Machine) error { + namespaceID := m.NamespaceID + if err := h.db.Unscoped().Delete(&m).Error; err != nil { + return err + } + return h.RequestMapUpdates(namespaceID) +} + +// GetHostInfo returns a Hostinfo struct for the machine +func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { + hostinfo := tailcfg.Hostinfo{} + if len(m.HostInfo) != 0 { + hi, err := m.HostInfo.MarshalJSON() + if err != nil { + return nil, err + } + err = json.Unmarshal(hi, &hostinfo) + if err != nil { + return nil, err + } + } + return &hostinfo, nil +} + +func (h *Headscale) notifyChangesToPeers(m *Machine) { + peers, err := h.getPeers(m) + if err != nil { + log.Error(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Msgf("Error getting peers: %s", err) + return + } + for _, peer := range peers { + log.Info(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", peer.Name). + Str("address", peer.IPAddress). + Msgf("Notifying peer %s (%s)", peer.Name, peer.IPAddress) + err := h.sendRequestOnUpdateChannel(peer) + if err != nil { + log.Info(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", peer.Name). + Msgf("Peer %s does not appear to be polling", peer.Name) + } + log.Trace(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", peer.Name). + Str("address", peer.IPAddress). + Msgf("Notified peer %s (%s)", peer.Name, peer.IPAddress) + } +} + +func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + updateChan = unwrapped + } else { + log.Error(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + return updateChan +} + +func (h *Headscale) closeUpdateChannel(m *Machine) { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + close(unwrapped) + } + } + h.clientsUpdateChannels.Delete(m.ID) +} + +func (h *Headscale) sendRequestOnUpdateChannel(m *Machine) error { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) + if ok { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notifying peer %s", m.Name) + + if update, ok := pUp.(chan struct{}); ok { + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Update channel is %#v", update) + + update <- struct{}{} + + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notified machine %s", m.Name) + } + } else { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Machine %s does not appear to be polling", m.Name) + return errors.New("machine does not seem to be polling") + } + return nil +} + +func (h *Headscale) isOutdated(m *Machine) bool { + err := h.UpdateMachine(m) + if err != nil { + return true + } + + lastChange := h.getLastStateChange(m.Namespace.Name) + log.Trace(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", lastChange). + Msgf("Checking if %s is missing updates", m.Name) + return m.LastSuccessfulUpdate.Before(lastChange) +} + +func (m Machine) String() string { + return m.Name +} + +func (ms Machines) String() string { + temp := make([]string, len(ms)) + + for index, machine := range ms { + temp[index] = machine.Name + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +// TODO(kradalby): Remove when we have generics... +func (ms MachinesP) String() string { + temp := make([]string, len(ms)) + + for index, machine := range ms { + temp[index] = machine.Name + } + + return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) +} + +func (ms MachinesP) toNodes(includeRoutes bool) ([]*tailcfg.Node, error) { + nodes := make([]*tailcfg.Node, len(ms)) + + for index, machine := range ms { + node, err := machine.toNode(includeRoutes) + if err != nil { + return nil, err + } + + nodes[index] = node + } + + return nodes, nil +} + // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { @@ -171,244 +490,3 @@ func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { } return &n, nil } - -func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) { - log.Trace(). - Str("func", "getPeers"). - Str("machine", m.Name). - Msg("Finding peers") - - machines := []Machine{} - if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", - m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") - return nil, err - } - - // We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for - sharedMachines := []SharedMachine{} - if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", - m.NamespaceID).Find(&sharedMachines).Error; err != nil { - return nil, err - } - - peers := []*tailcfg.Node{} - for _, mn := range machines { - peer, err := mn.toNode(true) - if err != nil { - return nil, err - } - peers = append(peers, peer) - } - for _, sharedMachine := range sharedMachines { - peer, err := sharedMachine.Machine.toNode(false) // shared nodes do not expose their routes - if err != nil { - return nil, err - } - peers = append(peers, peer) - } - sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) - - log.Trace(). - Str("func", "getPeers"). - Str("machine", m.Name). - Msgf("Found peers: %s", tailNodesToString(peers)) - return &peers, nil -} - -// GetMachine finds a Machine by name and namespace and returns the Machine struct -func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { - machines, err := h.ListMachinesInNamespace(namespace) - if err != nil { - return nil, err - } - - for _, m := range *machines { - if m.Name == name { - return &m, nil - } - } - return nil, fmt.Errorf("machine not found") -} - -// GetMachineByID finds a Machine by ID and returns the Machine struct -func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { - m := Machine{} - if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { - return nil, result.Error - } - return &m, nil -} - -// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct -func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { - m := Machine{} - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { - return nil, result.Error - } - return &m, nil -} - -// UpdateMachine takes a Machine struct pointer (typically already loaded from database -// and updates it with the latest data from the database. -func (h *Headscale) UpdateMachine(m *Machine) error { - if result := h.db.Find(m).First(&m); result.Error != nil { - return result.Error - } - return nil -} - -// DeleteMachine softs deletes a Machine from the database -func (h *Headscale) DeleteMachine(m *Machine) error { - m.Registered = false - namespaceID := m.NamespaceID - h.db.Save(&m) // we mark it as unregistered, just in case - if err := h.db.Delete(&m).Error; err != nil { - return err - } - - return h.RequestMapUpdates(namespaceID) -} - -// HardDeleteMachine hard deletes a Machine from the database -func (h *Headscale) HardDeleteMachine(m *Machine) error { - namespaceID := m.NamespaceID - if err := h.db.Unscoped().Delete(&m).Error; err != nil { - return err - } - return h.RequestMapUpdates(namespaceID) -} - -// GetHostInfo returns a Hostinfo struct for the machine -func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { - hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() - if err != nil { - return nil, err - } - err = json.Unmarshal(hi, &hostinfo) - if err != nil { - return nil, err - } - } - return &hostinfo, nil -} - -func (h *Headscale) notifyChangesToPeers(m *Machine) { - peers, err := h.getPeers(m) - if err != nil { - log.Error(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Msgf("Error getting peers: %s", err) - return - } - for _, p := range *peers { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", p.Name). - Str("address", p.Addresses[0].String()). - Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - err := h.sendRequestOnUpdateChannel(p) - if err != nil { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", p.Name). - Msgf("Peer %s does not appear to be polling", p.Name) - } - log.Trace(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", p.Name). - Str("address", p.Addresses[0].String()). - Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0]) - } -} - -func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { - var updateChan chan struct{} - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if unwrapped, ok := storedChan.(chan struct{}); ok { - updateChan = unwrapped - } else { - log.Error(). - Str("handler", "openUpdateChannel"). - Str("machine", m.Name). - Msg("Failed to convert update channel to struct{}") - } - } else { - log.Debug(). - Str("handler", "openUpdateChannel"). - Str("machine", m.Name). - Msg("Update channel not found, creating") - - updateChan = make(chan struct{}) - h.clientsUpdateChannels.Store(m.ID, updateChan) - } - return updateChan -} - -func (h *Headscale) closeUpdateChannel(m *Machine) { - h.clientsUpdateChannelMutex.Lock() - defer h.clientsUpdateChannelMutex.Unlock() - - if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { - if unwrapped, ok := storedChan.(chan struct{}); ok { - close(unwrapped) - } - } - h.clientsUpdateChannels.Delete(m.ID) -} - -func (h *Headscale) sendRequestOnUpdateChannel(m *Machine) error { - h.clientsUpdateChannelMutex.Lock() - defer h.clientsUpdateChannelMutex.Unlock() - - pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) - if ok { - log.Info(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Notifying peer %s", m.Name) - - if update, ok := pUp.(chan struct{}); ok { - log.Trace(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Update channel is %#v", update) - - update <- struct{}{} - - log.Trace(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Notified machine %s", m.Name) - } - } else { - log.Info(). - Str("func", "requestUpdate"). - Str("machine", m.Name). - Msgf("Machine %s does not appear to be polling", m.Name) - return errors.New("machine does not seem to be polling") - } - return nil -} - -func (h *Headscale) isOutdated(m *Machine) bool { - err := h.UpdateMachine(m) - if err != nil { - return true - } - - lastChange := h.getLastStateChange(m.Namespace.Name) - log.Trace(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", lastChange). - Msgf("Checking if %s is missing updates", m.Name) - return m.LastSuccessfulUpdate.Before(lastChange) -}