1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-18 00:06:09 +01:00

WIP: Client updates. Long polling rewritten

This commit is contained in:
Juan Font Alonso 2021-02-23 21:07:52 +01:00
parent ca6904fc95
commit 06fb7d4587
2 changed files with 105 additions and 40 deletions

8
app.go
View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"sync"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -30,6 +31,9 @@ type Headscale struct {
dbString string dbString string
publicKey *wgcfg.Key publicKey *wgcfg.Key
privateKey *wgcfg.PrivateKey privateKey *wgcfg.PrivateKey
pollMu sync.Mutex
clientsPolling map[uint64]chan []byte // this is by all means a hackity hack
} }
// NewHeadscale returns the Headscale app // NewHeadscale returns the Headscale app
@ -54,6 +58,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
h.clientsPolling = make(map[uint64]chan []byte)
return &h, nil return &h, nil
} }
@ -64,9 +69,6 @@ func (h *Headscale) Serve() error {
r.GET("/register", h.RegisterWebAPI) r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler) r.POST("/machine/:id", h.RegistrationHandler)
// r.LoadHTMLFiles("./frontend/build/index.html")
// r.Use(static.Serve("/", static.LocalFile("./frontend/build", true)))
err := r.Run(h.cfg.Addr) err := r.Run(h.cfg.Addr)
return err return err
} }

View File

@ -57,7 +57,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We do have the updated key! // We do have the updated key!
if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() { if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
if m.Registered { if m.Registered {
log.Println("Registered and we have the updated key! Lets move to map") log.Println("Client is registered and we have the current key. All clear to /map")
resp.AuthURL = "" resp.AuthURL = ""
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
@ -102,50 +102,73 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
log.Println("We dont know anything about the new key. WTF") log.Println("We dont know anything about the new key. WTF")
} }
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) { func (h *Headscale) PollNetMapHandler(c *gin.Context) {
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id") mKeyStr := c.Param("id")
mKey, err := wgcfg.ParseHexKey(mKeyStr) mKey, err := wgcfg.ParseHexKey(mKeyStr)
if err != nil { if err != nil {
log.Printf("Cannot parse client key: %s", err) log.Printf("Cannot parse client key: %s", err)
c.String(http.StatusOK, "Sad!")
return return
} }
req := tailcfg.MapRequest{} req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot decode message: %s", err) log.Printf("Cannot decode message: %s", err)
c.String(http.StatusOK, "Very sad!") return
// return
} }
db, err := h.db() db, err := h.db()
if err != nil { if err != nil {
log.Printf("Cannot open DB: %s", err) log.Printf("Cannot open DB: %s", err)
c.String(http.StatusInternalServerError, ":(")
return return
} }
defer db.Close() defer db.Close()
var m Machine var m Machine
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
log.Printf("Cannot encode message: %s", err) log.Printf("Cannot find machine: %s", err)
c.String(http.StatusOK, "Extremely sad!")
return return
} }
endpoints, _ := json.Marshal(req.Endpoints)
hostinfo, _ := json.Marshal(req.Hostinfo) hostinfo, _ := json.Marshal(req.Hostinfo)
m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)} m.Name = req.Hostinfo.Hostname
m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)} m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)}
m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString() m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString()
now := time.Now().UTC() now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)}
m.LastSeen = &now m.LastSeen = &now
}
db.Save(&m) db.Save(&m)
db.Close() db.Close()
chanStream := make(chan []byte, 1) pollData := make(chan []byte, 1)
go func() { update := make(chan []byte, 1)
defer close(chanStream) cancelKeepAlive := make(chan []byte, 1)
defer close(pollData)
defer close(update)
defer close(cancelKeepAlive)
h.pollMu.Lock()
h.clientsPolling[m.ID] = update
h.pollMu.Unlock()
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(mKey, req, m)
if err != nil { if err != nil {
@ -153,34 +176,73 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
return return
} }
//send initial dump log.Printf("[%s] sending initial map", m.Name)
chanStream <- *data pollData <- *data
for {
data, err := h.getMapKeepAliveResponse(mKey, req, m) // We update our peers if the client is not sending ReadOnly in the MapRequest
if err != nil { // so we don't distribute its initial request (it comes with
c.String(http.StatusInternalServerError, ":(") // empty endpoints to peers)
return if !req.ReadOnly {
} peers, _ := h.getPeers(m)
chanStream <- *data h.pollMu.Lock()
// keep the node entertained for _, p := range *peers {
time.Sleep(time.Second * 180) log.Printf("[%s] notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0])
break if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok {
} pUp <- []byte{}
}()
c.Stream(func(w io.Writer) bool {
if msg, ok := <-chanStream; ok {
log.Printf("🦀 Sending data to %s: %d bytes", c.Request.RemoteAddr, len(msg))
w.Write(msg)
return true
} else { } else {
log.Printf("🦄 Closing connection to %s", c.Request.RemoteAddr) log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name)
c.AbortWithStatus(200) }
}
h.pollMu.Unlock()
}
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
c.Stream(func(w io.Writer) bool {
select {
case data := <-pollData:
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data))
w.Write(data)
return true
case <-update:
log.Printf("[%s] Received a request for update", m.Name)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
fmt.Printf("[%s] 🤮 Cannot get the poll response: %s", m.Name, err)
}
w.Write(*data)
return true
case <-c.Request.Context().Done():
log.Printf("[%s] 😥 The client has closed the connection", m.Name)
h.pollMu.Lock()
cancelKeepAlive <- []byte{}
delete(h.clientsPolling, m.ID)
h.pollMu.Unlock()
return false return false
} }
}) })
}
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) {
for {
select {
case <-cancel:
return
default:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Printf("Error generating the keep alive msg: %s", err)
return
}
pollData <- *data
time.Sleep(60 * time.Second)
}
}
} }
func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
@ -221,7 +283,7 @@ func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Mac
return nil, err return nil, err
} }
} }
// spew.Dump(resp)
// declare the incoming size on the first 4 bytes // declare the incoming size on the first 4 bytes
data := make([]byte, 4) data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(respBody))) binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
@ -289,6 +351,7 @@ func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key
MachineKey: idKey.HexString(), MachineKey: idKey.HexString(),
NodeKey: wgcfg.Key(req.NodeKey).HexString(), NodeKey: wgcfg.Key(req.NodeKey).HexString(),
Expiry: &req.Expiry, Expiry: &req.Expiry,
Name: req.Hostinfo.Hostname,
} }
if err := db.Create(&mNew).Error; err != nil { if err := db.Create(&mNew).Error; err != nil {
log.Printf("Could not create row: %s", err) log.Printf("Could not create row: %s", err)