mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge pull request #63 from juanfont/use-kv-for-updates
Added communication between Serve and CLI using KV table
This commit is contained in:
		
						commit
						6091373b53
					
				
							
								
								
									
										17
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								app.go
									
									
									
									
									
								
							@ -141,6 +141,20 @@ func (h *Headscale) expireEphemeralNodesWorker() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WatchForKVUpdates checks the KV DB table for requests to perform tailnet upgrades
 | 
				
			||||||
 | 
					// This is a way to communitate the CLI with the headscale server
 | 
				
			||||||
 | 
					func (h *Headscale) watchForKVUpdates(milliSeconds int64) {
 | 
				
			||||||
 | 
						ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
 | 
				
			||||||
 | 
						for range ticker.C {
 | 
				
			||||||
 | 
							h.watchForKVUpdatesWorker()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) watchForKVUpdatesWorker() {
 | 
				
			||||||
 | 
						h.checkForNamespacesPendingUpdates()
 | 
				
			||||||
 | 
						// more functions will come here in the future
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Serve launches a GIN server with the Headscale API
 | 
					// Serve launches a GIN server with the Headscale API
 | 
				
			||||||
func (h *Headscale) Serve() error {
 | 
					func (h *Headscale) Serve() error {
 | 
				
			||||||
	r := gin.Default()
 | 
						r := gin.Default()
 | 
				
			||||||
@ -149,6 +163,9 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
	r.POST("/machine/:id/map", h.PollNetMapHandler)
 | 
						r.POST("/machine/:id/map", h.PollNetMapHandler)
 | 
				
			||||||
	r.POST("/machine/:id", h.RegistrationHandler)
 | 
						r.POST("/machine/:id", h.RegistrationHandler)
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go h.watchForKVUpdates(5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.cfg.TLSLetsEncryptHostname != "" {
 | 
						if h.cfg.TLSLetsEncryptHostname != "" {
 | 
				
			||||||
		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
							if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
				
			||||||
			log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
 | 
								log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								db.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								db.go
									
									
									
									
									
								
							@ -79,6 +79,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
 | 
				
			|||||||
	return db, nil
 | 
						return db, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getValue returns the value for the given key in KV
 | 
				
			||||||
func (h *Headscale) getValue(key string) (string, error) {
 | 
					func (h *Headscale) getValue(key string) (string, error) {
 | 
				
			||||||
	var row KV
 | 
						var row KV
 | 
				
			||||||
	if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
						if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
				
			||||||
@ -87,6 +88,7 @@ func (h *Headscale) getValue(key string) (string, error) {
 | 
				
			|||||||
	return row.Value, nil
 | 
						return row.Value, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// setValue sets value for the given key in KV
 | 
				
			||||||
func (h *Headscale) setValue(key string, value string) error {
 | 
					func (h *Headscale) setValue(key string, value string) error {
 | 
				
			||||||
	kv := KV{
 | 
						kv := KV{
 | 
				
			||||||
		Key:   key,
 | 
							Key:   key,
 | 
				
			||||||
 | 
				
			|||||||
@ -200,19 +200,22 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
 | 
				
			|||||||
// DeleteMachine softs deletes a Machine from the database
 | 
					// DeleteMachine softs deletes a Machine from the database
 | 
				
			||||||
func (h *Headscale) DeleteMachine(m *Machine) error {
 | 
					func (h *Headscale) DeleteMachine(m *Machine) error {
 | 
				
			||||||
	m.Registered = false
 | 
						m.Registered = false
 | 
				
			||||||
 | 
						namespaceID := m.NamespaceID
 | 
				
			||||||
	h.db.Save(&m) // we mark it as unregistered, just in case
 | 
						h.db.Save(&m) // we mark it as unregistered, just in case
 | 
				
			||||||
	if err := h.db.Delete(&m).Error; err != nil {
 | 
						if err := h.db.Delete(&m).Error; err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
					
 | 
				
			||||||
 | 
						return h.RequestMapUpdates(namespaceID)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// HardDeleteMachine hard deletes a Machine from the database
 | 
					// HardDeleteMachine hard deletes a Machine from the database
 | 
				
			||||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
 | 
					func (h *Headscale) HardDeleteMachine(m *Machine) error {
 | 
				
			||||||
 | 
						namespaceID := m.NamespaceID
 | 
				
			||||||
	if err := h.db.Unscoped().Delete(&m).Error; err != nil {
 | 
						if err := h.db.Unscoped().Delete(&m).Error; err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return h.RequestMapUpdates(namespaceID)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetHostInfo returns a Hostinfo struct for the machine
 | 
					// GetHostInfo returns a Hostinfo struct for the machine
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,8 @@
 | 
				
			|||||||
package headscale
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gopkg.in/check.v1"
 | 
						"gopkg.in/check.v1"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -81,6 +83,15 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
 | 
				
			|||||||
	h.db.Save(&m)
 | 
						h.db.Save(&m)
 | 
				
			||||||
	err = h.DeleteMachine(&m)
 | 
						err = h.DeleteMachine(&m)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						v, err := h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						names := []string{}
 | 
				
			||||||
 | 
						err = json.Unmarshal([]byte(v), &names)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(names, check.DeepEquals, []string{n.Name})
 | 
				
			||||||
 | 
						h.checkForNamespacesPendingUpdates()
 | 
				
			||||||
 | 
						v, _ = h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
						c.Assert(v, check.Equals, "")
 | 
				
			||||||
	_, err = h.GetMachine(n.Name, "testmachine")
 | 
						_, err = h.GetMachine(n.Name, "testmachine")
 | 
				
			||||||
	c.Assert(err, check.NotNil)
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,9 @@
 | 
				
			|||||||
package headscale
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -103,6 +105,88 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RequestMapUpdates signals the KV worker to update the maps for this namespace
 | 
				
			||||||
 | 
					func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
 | 
				
			||||||
 | 
						namespace := Namespace{}
 | 
				
			||||||
 | 
						if err := h.db.First(&namespace, namespaceID).Error; err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, err := h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
						if err != nil || v == "" {
 | 
				
			||||||
 | 
							err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						names := []string{}
 | 
				
			||||||
 | 
						err = json.Unmarshal([]byte(v), &names)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						names = append(names, namespace.Name)
 | 
				
			||||||
 | 
						data, err := json.Marshal(names)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Printf("Could not marshal namespaces_pending_updates: %s", err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return h.setValue("namespaces_pending_updates", string(data))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) checkForNamespacesPendingUpdates() {
 | 
				
			||||||
 | 
						v, err := h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if v == "" {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						names := []string{}
 | 
				
			||||||
 | 
						err = json.Unmarshal([]byte(v), &names)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, name := range names {
 | 
				
			||||||
 | 
							log.Printf("Sending updates to nodes in namespace %s", name)
 | 
				
			||||||
 | 
							machines, err := h.ListMachinesInNamespace(name)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							for _, m := range *machines {
 | 
				
			||||||
 | 
								peers, _ := h.getPeers(m)
 | 
				
			||||||
 | 
								h.pollMu.Lock()
 | 
				
			||||||
 | 
								for _, p := range *peers {
 | 
				
			||||||
 | 
									pUp, ok := h.clientsPolling[uint64(p.ID)]
 | 
				
			||||||
 | 
									if ok {
 | 
				
			||||||
 | 
										log.Printf("[%s] Notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0])
 | 
				
			||||||
 | 
										pUp <- []byte{}
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								h.pollMu.Unlock()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						newV, err := h.getValue("namespaces_pending_updates")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if v == newV { // only clear when no changes, so we notified everybody
 | 
				
			||||||
 | 
							err = h.setValue("namespaces_pending_updates", "")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Printf("Could not save to KV: %s", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (n *Namespace) toUser() *tailcfg.User {
 | 
					func (n *Namespace) toUser() *tailcfg.User {
 | 
				
			||||||
	u := tailcfg.User{
 | 
						u := tailcfg.User{
 | 
				
			||||||
		ID:            tailcfg.UserID(n.ID),
 | 
							ID:            tailcfg.UserID(n.ID),
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user