1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00
juanfont.headscale/handlers.go
2021-02-21 21:34:28 +01:00

310 lines
7.8 KiB
Go

package headscale
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/dialects/postgres"
"github.com/klauspost/compress/zstd"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/wgcfg"
)
// KeyHandler replies with the hex string form of h.publicKey as UTF-8 plain
// text.
func (h *Headscale) KeyHandler(c *gin.Context) {
	pubKeyHex := h.publicKey.HexString()
	c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(pubKeyHex))
}
// RegistrationHandler processes a registration request for the machine whose
// hex-encoded machine key is passed as the "id" route parameter.
//
// The request body is decoded into a tailcfg.RegisterRequest and the machine
// is looked up by its machine key:
//
//   - unknown machine: the request is handed to handleNewServer;
//   - stored node key equals req.NodeKey: an empty AuthURL is returned when
//     the machine is registered, otherwise an AuthURL pointing at /register;
//   - stored node key equals req.OldNodeKey: the new node key is persisted
//     (key rotation) and an empty AuthURL is returned;
//   - anything else: the request is rejected with 401.
//
// Every failure path writes an explicit error response so the client never
// receives an empty 200.
func (h *Headscale) RegistrationHandler(c *gin.Context) {
	body, err := ioutil.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Cannot read request body: %s", err)
		c.String(http.StatusInternalServerError, "Sad!")
		return
	}

	mKeyStr := c.Param("id")
	mKey, err := wgcfg.ParseHexKey(mKeyStr)
	if err != nil {
		log.Printf("Cannot parse machine key: %s", err)
		c.String(http.StatusInternalServerError, "Sad!")
		return
	}

	req := tailcfg.RegisterRequest{}
	if err := decode(body, &req, &mKey, h.privateKey); err != nil {
		log.Printf("Cannot decode message: %s", err)
		c.String(http.StatusInternalServerError, "Very sad!")
		return
	}

	db, err := h.db()
	if err != nil {
		log.Printf("Cannot open DB: %s", err)
		c.String(http.StatusInternalServerError, ":(")
		return
	}
	defer db.Close()

	// reply encodes resp for this machine and writes it to the client; it is
	// shared by every successful branch below.
	reply := func(resp tailcfg.RegisterResponse) {
		respBody, err := encode(resp, &mKey, h.privateKey)
		if err != nil {
			log.Printf("Cannot encode message: %s", err)
			c.String(http.StatusInternalServerError, "Extremely sad!")
			return
		}
		c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
	}

	var m Machine
	lookup := db.First(&m, "machine_key = ?", mKey.HexString())
	if lookup.RecordNotFound() {
		log.Println("New Machine!")
		h.handleNewServer(c, db, mKey, req)
		return
	}
	// Any other lookup failure leaves m zero-valued; do not compare keys
	// against it.
	if lookup.Error != nil {
		log.Printf("Cannot look up machine: %s", lookup.Error)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	// We do have the updated key!
	if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
		if m.Registered {
			log.Println("Registered and we have the updated key! Lets move to map")
			reply(tailcfg.RegisterResponse{AuthURL: ""})
			return
		}

		log.Println("Hey! Not registered. Not asking for key rotation. Send a passive-agressive authurl to register")
		reply(tailcfg.RegisterResponse{
			AuthURL: fmt.Sprintf("%s/register?key=%s",
				h.cfg.ServerURL, mKey.HexString()),
		})
		return
	}

	// We dont have the updated key in the DB. Lets try with the old one.
	if m.NodeKey == wgcfg.Key(req.OldNodeKey).HexString() {
		log.Println("Key rotation!")
		m.NodeKey = wgcfg.Key(req.NodeKey).HexString()
		// Only acknowledge the rotation once the new key is actually stored.
		if err := db.Save(&m).Error; err != nil {
			log.Printf("Cannot save machine: %s", err)
			c.String(http.StatusInternalServerError, ":(")
			return
		}
		reply(tailcfg.RegisterResponse{AuthURL: ""})
		return
	}

	// Neither the current nor the previous node key matches what is stored.
	log.Println("We dont know anything about the new key. WTF")
	c.String(http.StatusUnauthorized, "Unknown node key")
}
// PollNetMapHandler serves a map poll for the machine whose hex-encoded
// machine key is passed as the "id" route parameter.
//
// It decodes the body into a tailcfg.MapRequest, stores the reported
// endpoints, host info, disco key and last-seen time on the machine row, and
// then streams to the client: the full map response, one keep-alive
// response, and, after 180 seconds, the end of the stream.
//
// The status codes of the pre-existing error paths (including 200 for
// client-side failures) are kept as they were.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
	body, err := ioutil.ReadAll(c.Request.Body)
	if err != nil {
		log.Printf("Cannot read request body: %s", err)
		c.String(http.StatusOK, "Very sad!")
		return
	}

	mKeyStr := c.Param("id")
	mKey, err := wgcfg.ParseHexKey(mKeyStr)
	if err != nil {
		log.Printf("Cannot parse client key: %s", err)
		c.String(http.StatusOK, "Sad!")
		return
	}

	req := tailcfg.MapRequest{}
	if err := decode(body, &req, &mKey, h.privateKey); err != nil {
		log.Printf("Cannot decode message: %s", err)
		c.String(http.StatusOK, "Very sad!")
		// Stop here: req is not usable and a body has already been written.
		return
	}

	endpoints, err := json.Marshal(req.Endpoints)
	if err != nil {
		log.Printf("Cannot marshal endpoints: %s", err)
		c.String(http.StatusInternalServerError, ":(")
		return
	}
	hostinfo, err := json.Marshal(req.Hostinfo)
	if err != nil {
		log.Printf("Cannot marshal hostinfo: %s", err)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	db, err := h.db()
	if err != nil {
		log.Printf("Cannot open DB: %s", err)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	// The DB handle is closed explicitly on every path rather than deferred,
	// so it is not held open for the lifetime of the stream below.
	var m Machine
	lookup := db.First(&m, "machine_key = ?", mKey.HexString())
	if lookup.RecordNotFound() {
		db.Close()
		log.Printf("Cannot find machine with key %s", mKey.HexString())
		c.String(http.StatusOK, "Extremely sad!")
		return
	}
	if lookup.Error != nil {
		db.Close()
		log.Printf("Cannot look up machine: %s", lookup.Error)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)}
	m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)}
	m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString()
	now := time.Now().UTC()
	m.LastSeen = &now

	saveErr := db.Save(&m).Error
	db.Close()
	if saveErr != nil {
		log.Printf("Cannot save machine: %s", saveErr)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	// Build the initial map on the handler goroutine so a failure can be
	// reported through the gin context without racing the streaming writer.
	data, err := h.getMapResponse(mKey, req, m)
	if err != nil {
		c.String(http.StatusInternalServerError, ":(")
		return
	}
	initial := *data

	// ctx is cancelled when the client goes away or the handler returns, which
	// lets the producer goroutine exit instead of blocking forever.
	ctx := c.Request.Context()
	chanStream := make(chan []byte, 1)
	go func() {
		defer close(chanStream)

		// send initial dump
		select {
		case chanStream <- initial:
		case <-ctx.Done():
			return
		}

		keepAlive, err := h.getMapKeepAliveResponse(mKey, req, m)
		if err != nil {
			// The response is already being streamed; only log here.
			log.Printf("Cannot build keep-alive response: %s", err)
			return
		}
		select {
		case chanStream <- *keepAlive:
		case <-ctx.Done():
			return
		}

		// keep the node entertained
		timer := time.NewTimer(180 * time.Second)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
		}
	}()

	c.Stream(func(w io.Writer) bool {
		msg, ok := <-chanStream
		if !ok {
			log.Printf("🦄 Closing connection to %s", c.Request.RemoteAddr)
			c.AbortWithStatus(http.StatusOK)
			return false
		}

		log.Printf("🦀 Sending data to %s: %d bytes", c.Request.RemoteAddr, len(msg))
		if _, err := w.Write(msg); err != nil {
			log.Printf("Cannot write to %s: %s", c.Request.RemoteAddr, err)
			return false
		}
		return true
	})
}
// getMapResponse builds the full (non keep-alive) tailcfg.MapResponse for
// machine m and returns it ready to be written to the wire.
//
// The response carries m converted to a node, the peers returned by
// h.getPeers and the DERP map from the configuration. When req.Compress is
// "zstd" the JSON is zstd-compressed and passed to encodeMsg; otherwise the
// response is passed to encode. The returned slice is the encoded body
// prefixed with its length as a 4-byte little-endian integer.
func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
	node, err := m.toNode()
	if err != nil {
		log.Printf("Cannot convert to node: %s", err)
		return nil, err
	}
	peers, err := h.getPeers(m)
	if err != nil {
		log.Printf("Cannot fetch peers: %s", err)
		return nil, err
	}

	resp := tailcfg.MapResponse{
		KeepAlive:    false,
		Node:         node,
		Peers:        *peers,
		DNS:          []netaddr.IP{},
		SearchPaths:  []string{},
		Domain:       "foobar@example.com",
		PacketFilter: tailcfg.FilterAllowAll,
		DERPMap:      h.cfg.DerpMap,
		UserProfiles: []tailcfg.UserProfile{},
		Roles:        []tailcfg.Role{},
	}

	var respBody []byte
	if req.Compress == "zstd" {
		src, err := json.Marshal(resp)
		if err != nil {
			return nil, fmt.Errorf("marshalling map response: %w", err)
		}
		encoder, err := zstd.NewWriter(nil)
		if err != nil {
			return nil, fmt.Errorf("creating zstd encoder: %w", err)
		}
		srcCompressed := encoder.EncodeAll(src, nil)
		respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
		if err != nil {
			return nil, err
		}
	} else {
		respBody, err = encode(resp, &mKey, h.privateKey)
		if err != nil {
			return nil, err
		}
	}

	// declare the incoming size on the first 4 bytes
	data := make([]byte, 4, 4+len(respBody))
	binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
	data = append(data, respBody...)
	return &data, nil
}
// getMapKeepAliveResponse builds a tailcfg.MapResponse that only has
// KeepAlive set and returns it ready to be written to the wire.
//
// When req.Compress is "zstd" the JSON is zstd-compressed and passed to
// encodeMsg; otherwise the response is passed to encode. The returned slice
// is the encoded body prefixed with its length as a 4-byte little-endian
// integer. The m parameter is not used.
func (h *Headscale) getMapKeepAliveResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
	resp := tailcfg.MapResponse{
		KeepAlive: true,
	}

	var respBody []byte
	if req.Compress == "zstd" {
		src, err := json.Marshal(resp)
		if err != nil {
			return nil, fmt.Errorf("marshalling keep-alive response: %w", err)
		}
		encoder, err := zstd.NewWriter(nil)
		if err != nil {
			return nil, fmt.Errorf("creating zstd encoder: %w", err)
		}
		srcCompressed := encoder.EncodeAll(src, nil)
		respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
		if err != nil {
			return nil, err
		}
	} else {
		var err error
		respBody, err = encode(resp, &mKey, h.privateKey)
		if err != nil {
			return nil, err
		}
	}

	// declare the incoming size on the first 4 bytes
	data := make([]byte, 4, 4+len(respBody))
	binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
	data = append(data, respBody...)
	return &data, nil
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI.
//
// The machine key is taken from the "key" query parameter. Because the value
// is echoed into an HTML page, it must parse as a hex key; the page renders
// the parsed key's HexString rather than the raw query text, so arbitrary
// markup from the URL is never written into the response. A missing or
// unparsable key yields 400 "Wrong params".
func (h *Headscale) RegisterWebAPI(c *gin.Context) {
	mKeyStr := c.Query("key")
	if mKeyStr == "" {
		c.String(http.StatusBadRequest, "Wrong params")
		return
	}

	mKey, err := wgcfg.ParseHexKey(mKeyStr)
	if err != nil {
		c.String(http.StatusBadRequest, "Wrong params")
		return
	}

	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
<p>
Run the command below in the headscale server to add this machine to your network:
</p>
<p>
<code>
<b>headscale register %s</b>
</code>
</p>
</body>
</html>
`, mKey.HexString())))
}
// handleNewServer stores a row for a machine seen for the first time and
// answers the registration request with an AuthURL pointing at /register.
//
// The new row records idKey as the machine key, req.NodeKey as the node key
// and req.Expiry as the expiry. If the row cannot be created, or the response
// cannot be encoded, a 500 is written so the client never receives an empty
// success response.
func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key, req tailcfg.RegisterRequest) {
	mNew := Machine{
		MachineKey: idKey.HexString(),
		NodeKey:    wgcfg.Key(req.NodeKey).HexString(),
		Expiry:     &req.Expiry,
	}
	if err := db.Create(&mNew).Error; err != nil {
		log.Printf("Could not create row: %s", err)
		c.String(http.StatusInternalServerError, ":(")
		return
	}

	resp := tailcfg.RegisterResponse{
		AuthURL: fmt.Sprintf("%s/register?key=%s",
			h.cfg.ServerURL, idKey.HexString()),
	}
	respBody, err := encode(resp, &idKey, h.privateKey)
	if err != nil {
		log.Printf("Cannot encode message: %s", err)
		c.String(http.StatusInternalServerError, "Extremely sad!")
		return
	}
	c.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}