1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

Added communication between Serve and CLI using KV table (helps in #52)

This commit is contained in:
Juan Font Alonso 2021-07-25 17:59:48 +02:00
parent b83ecc3e6e
commit 97f7c90092
5 changed files with 118 additions and 2 deletions

17
app.go
View File

@ -139,6 +139,20 @@ func (h *Headscale) expireEphemeralNodesWorker() {
} }
} }
// WatchForKVUpdates checks the KV DB table for requests to perform tailnet upgrades
// This is a way to communitate the CLI with the headscale server
func (h *Headscale) watchForKVUpdates(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
h.watchForKVUpdatesWorker()
}
}
func (h *Headscale) watchForKVUpdatesWorker() {
h.checkForNamespacesPendingUpdates()
// more functions will come here in the future
}
// Serve launches a GIN server with the Headscale API // Serve launches a GIN server with the Headscale API
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
r := gin.Default() r := gin.Default()
@ -147,6 +161,9 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler) r.POST("/machine/:id", h.RegistrationHandler)
var err error var err error
go h.watchForKVUpdates(5000)
if h.cfg.TLSLetsEncryptHostname != "" { if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Println("WARNING: listening with TLS but ServerURL does not start with https://") log.Println("WARNING: listening with TLS but ServerURL does not start with https://")

2
db.go
View File

@ -79,6 +79,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil return db, nil
} }
// getValue returns the value for the given key in KV
func (h *Headscale) getValue(key string) (string, error) { func (h *Headscale) getValue(key string) (string, error) {
var row KV var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -87,6 +88,7 @@ func (h *Headscale) getValue(key string) (string, error) {
return row.Value, nil return row.Value, nil
} }
// setValue sets value for the given key in KV
func (h *Headscale) setValue(key string, value string) error { func (h *Headscale) setValue(key string, value string) error {
kv := KV{ kv := KV{
Key: key, Key: key,

View File

@ -200,19 +200,22 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false m.Registered = false
namespaceID := m.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case h.db.Save(&m) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil { if err := h.db.Delete(&m).Error; err != nil {
return err return err
} }
return nil
return h.RequestMapUpdates(namespaceID)
} }
// HardDeleteMachine hard deletes a Machine from the database // HardDeleteMachine hard deletes a Machine from the database
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(m *Machine) error {
namespaceID := m.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil { if err := h.db.Unscoped().Delete(&m).Error; err != nil {
return err return err
} }
return nil return h.RequestMapUpdates(namespaceID)
} }
// GetHostInfo returns a Hostinfo struct for the machine // GetHostInfo returns a Hostinfo struct for the machine

View File

@ -1,6 +1,8 @@
package headscale package headscale
import ( import (
"encoding/json"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@ -81,6 +83,15 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
h.db.Save(&m) h.db.Save(&m)
err = h.DeleteMachine(&m) err = h.DeleteMachine(&m)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
v, err := h.getValue("namespaces_pending_updates")
c.Assert(err, check.IsNil)
names := []string{}
err = json.Unmarshal([]byte(v), &names)
c.Assert(err, check.IsNil)
c.Assert(names, check.DeepEquals, []string{n.Name})
h.checkForNamespacesPendingUpdates()
v, _ = h.getValue("namespaces_pending_updates")
c.Assert(v, check.Equals, "")
_, err = h.GetMachine(n.Name, "testmachine") _, err = h.GetMachine(n.Name, "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }

View File

@ -1,7 +1,9 @@
package headscale package headscale
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"log" "log"
"time" "time"
@ -103,6 +105,87 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
return nil return nil
} }
// RequestMapUpdates signals the KV worker to update the maps for this namespace
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
return err
}
v, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
if err != nil {
return err
}
return nil
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
if err != nil {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
if err != nil {
return err
}
return nil
}
names = append(names, namespace.Name)
data, err := json.Marshal(names)
if err != nil {
log.Printf("Could not marshal namespaces_pending_updates: %s", err)
return err
}
return h.setValue("namespaces_pending_updates", string(data))
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == "" {
return
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
if err != nil {
return
}
for _, name := range names {
machines, err := h.ListMachinesInNamespace(name)
if err != nil {
continue
}
for _, m := range *machines {
peers, _ := h.getPeers(m)
h.pollMu.Lock()
for _, p := range *peers {
pUp, ok := h.clientsPolling[uint64(p.ID)]
if ok {
log.Printf("[%s] Notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0])
pUp <- []byte{}
} else {
log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name)
}
}
h.pollMu.Unlock()
}
}
newV, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == newV { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Printf("Could not save to KV: %s", err)
return
}
}
}
func (n *Namespace) toUser() *tailcfg.User { func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{ u := tailcfg.User{
ID: tailcfg.UserID(n.ID), ID: tailcfg.UserID(n.ID),