mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge pull request #21 from juanfont/working-preauth
Support for pre auth keys
This commit is contained in:
		
						commit
						8ca940ad30
					
				
							
								
								
									
										54
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								api.go
									
									
									
									
									
								
							@ -33,6 +33,8 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// spew.Dump(c.Params)
 | 
			
		||||
 | 
			
		||||
	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
 | 
			
		||||
	<html>
 | 
			
		||||
	<body>
 | 
			
		||||
@ -71,6 +73,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
 | 
			
		||||
		c.String(http.StatusInternalServerError, "Very sad!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db, err := h.db()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Printf("Cannot open DB: %s", err)
 | 
			
		||||
@ -93,6 +96,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
 | 
			
		||||
			log.Println("Client is registered and we have the current key. All clear to /map")
 | 
			
		||||
			resp.AuthURL = ""
 | 
			
		||||
			resp.User = *m.Namespace.toUser()
 | 
			
		||||
			resp.MachineAuthorized = true
 | 
			
		||||
			respBody, err := encode(resp, &mKey, h.privateKey)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Printf("Cannot encode message: %s", err)
 | 
			
		||||
@ -135,6 +139,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Println("We dont know anything about the new key. WTF")
 | 
			
		||||
	// spew.Dump(req)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PollNetMapHandler takes care of /machine/:id/map
 | 
			
		||||
@ -359,20 +364,59 @@ func (h *Headscale) getMapKeepAliveResponse(mKey wgcfg.Key, req tailcfg.MapReque
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key, req tailcfg.RegisterRequest) {
 | 
			
		||||
	mNew := Machine{
 | 
			
		||||
	m := Machine{
 | 
			
		||||
		MachineKey: idKey.HexString(),
 | 
			
		||||
		NodeKey:    wgcfg.Key(req.NodeKey).HexString(),
 | 
			
		||||
		Expiry:     &req.Expiry,
 | 
			
		||||
		Name:       req.Hostinfo.Hostname,
 | 
			
		||||
	}
 | 
			
		||||
	if err := db.Create(&mNew).Error; err != nil {
 | 
			
		||||
	if err := db.Create(&m).Error; err != nil {
 | 
			
		||||
		log.Printf("Could not create row: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp := tailcfg.RegisterResponse{
 | 
			
		||||
		AuthURL: fmt.Sprintf("%s/register?key=%s",
 | 
			
		||||
			h.cfg.ServerURL, idKey.HexString()),
 | 
			
		||||
 | 
			
		||||
	resp := tailcfg.RegisterResponse{}
 | 
			
		||||
 | 
			
		||||
	if req.Auth.AuthKey != "" {
 | 
			
		||||
		pak, err := h.checkKeyValidity(req.Auth.AuthKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.MachineAuthorized = false
 | 
			
		||||
			respBody, err := encode(resp, &idKey, h.privateKey)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Printf("Cannot encode message: %s", err)
 | 
			
		||||
				c.String(http.StatusInternalServerError, "")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			c.Data(200, "application/json; charset=utf-8", respBody)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		ip, err := h.getAvailableIP()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Println(err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		m.IPAddress = ip.String()
 | 
			
		||||
		m.NamespaceID = pak.NamespaceID
 | 
			
		||||
		m.AuthKeyID = uint(pak.ID)
 | 
			
		||||
		m.RegisterMethod = "authKey"
 | 
			
		||||
		m.Registered = true
 | 
			
		||||
		db.Save(&m)
 | 
			
		||||
 | 
			
		||||
		resp.MachineAuthorized = true
 | 
			
		||||
		resp.User = *pak.Namespace.toUser()
 | 
			
		||||
		respBody, err := encode(resp, &idKey, h.privateKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Printf("Cannot encode message: %s", err)
 | 
			
		||||
			c.String(http.StatusInternalServerError, "Extremely sad!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Data(200, "application/json; charset=utf-8", respBody)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
 | 
			
		||||
		h.cfg.ServerURL, idKey.HexString())
 | 
			
		||||
 | 
			
		||||
	respBody, err := encode(resp, &idKey, h.privateKey)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								cli.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								cli.go
									
									
									
									
									
								
							@ -43,6 +43,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) error {
 | 
			
		||||
	m.IPAddress = ip.String()
 | 
			
		||||
	m.NamespaceID = ns.ID
 | 
			
		||||
	m.Registered = true
 | 
			
		||||
	m.RegisterMethod = "cli"
 | 
			
		||||
	db.Save(&m)
 | 
			
		||||
	fmt.Println("Machine registered 🎉")
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
@ -26,6 +26,10 @@ type Machine struct {
 | 
			
		||||
	Namespace   Namespace
 | 
			
		||||
 | 
			
		||||
	Registered     bool // temp
 | 
			
		||||
	RegisterMethod string
 | 
			
		||||
	AuthKeyID      uint
 | 
			
		||||
	AuthKey        *PreAuthKey
 | 
			
		||||
 | 
			
		||||
	LastSeen *time.Time
 | 
			
		||||
	Expiry   *time.Time
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,10 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const errorAuthKeyNotFound = Error("AuthKey not found")
 | 
			
		||||
const errorAuthKeyExpired = Error("AuthKey expired")
 | 
			
		||||
const errorAuthKeyNotReusableAlreadyUsed = Error("AuthKey not reusable already used")
 | 
			
		||||
 | 
			
		||||
// PreAuthKey describes a pre-authorization key usable in a particular namespace
 | 
			
		||||
type PreAuthKey struct {
 | 
			
		||||
	ID          uint64 `gorm:"primary_key"`
 | 
			
		||||
@ -72,6 +76,41 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error)
 | 
			
		||||
	return &keys, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
 | 
			
		||||
// If returns no error and a PreAuthKey, it can be used
 | 
			
		||||
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
 | 
			
		||||
	db, err := h.db()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
 | 
			
		||||
	pak := PreAuthKey{}
 | 
			
		||||
	if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() {
 | 
			
		||||
		return nil, errorAuthKeyNotFound
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
 | 
			
		||||
		return nil, errorAuthKeyExpired
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if pak.Reusable { // we don't need to check if has been used before
 | 
			
		||||
		return &pak, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	machines := []Machine{}
 | 
			
		||||
	if err := db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(machines) != 0 {
 | 
			
		||||
		return nil, errorAuthKeyNotReusableAlreadyUsed
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// missing here validation on current usage
 | 
			
		||||
	return &pak, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Headscale) generateKey() (string, error) {
 | 
			
		||||
	size := 24
 | 
			
		||||
	bytes := make([]byte, size)
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver
 | 
			
		||||
 | 
			
		||||
@ -48,6 +49,7 @@ func (s *Suite) TearDownSuite(c *check.C) {
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
 | 
			
		||||
	_, err := h.CreatePreAuthKey("bogus", true, nil)
 | 
			
		||||
 | 
			
		||||
	c.Assert(err, check.NotNil)
 | 
			
		||||
 | 
			
		||||
	n, err := h.CreateNamespace("test")
 | 
			
		||||
@ -73,3 +75,106 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
 | 
			
		||||
	// Make sure the Namespace association is populated
 | 
			
		||||
	c.Assert((*keys)[0].Namespace.Name, check.Equals, n.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
 | 
			
		||||
	n, err := h.CreateNamespace("test2")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	pak, err := h.CreatePreAuthKey(n.Name, true, &now)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	p, err := h.checkKeyValidity(pak.Key)
 | 
			
		||||
	c.Assert(err, check.Equals, errorAuthKeyExpired)
 | 
			
		||||
	c.Assert(p, check.IsNil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
 | 
			
		||||
	p, err := h.checkKeyValidity("potatoKey")
 | 
			
		||||
	c.Assert(err, check.Equals, errorAuthKeyNotFound)
 | 
			
		||||
	c.Assert(p, check.IsNil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestValidateKeyOk(c *check.C) {
 | 
			
		||||
	n, err := h.CreateNamespace("test3")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	pak, err := h.CreatePreAuthKey(n.Name, true, nil)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	p, err := h.checkKeyValidity(pak.Key)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(p.ID, check.Equals, pak.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
 | 
			
		||||
	n, err := h.CreateNamespace("test4")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	pak, err := h.CreatePreAuthKey(n.Name, false, nil)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	db, err := h.db()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
	m := Machine{
 | 
			
		||||
		ID:             0,
 | 
			
		||||
		MachineKey:     "foo",
 | 
			
		||||
		NodeKey:        "bar",
 | 
			
		||||
		DiscoKey:       "faa",
 | 
			
		||||
		Name:           "testest",
 | 
			
		||||
		NamespaceID:    n.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: "authKey",
 | 
			
		||||
		AuthKeyID:      uint(pak.ID),
 | 
			
		||||
	}
 | 
			
		||||
	db.Save(&m)
 | 
			
		||||
 | 
			
		||||
	p, err := h.checkKeyValidity(pak.Key)
 | 
			
		||||
	c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed)
 | 
			
		||||
	c.Assert(p, check.IsNil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
 | 
			
		||||
	n, err := h.CreateNamespace("test5")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	pak, err := h.CreatePreAuthKey(n.Name, true, nil)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	db, err := h.db()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer db.Close()
 | 
			
		||||
	m := Machine{
 | 
			
		||||
		ID:             1,
 | 
			
		||||
		MachineKey:     "foo",
 | 
			
		||||
		NodeKey:        "bar",
 | 
			
		||||
		DiscoKey:       "faa",
 | 
			
		||||
		Name:           "testest",
 | 
			
		||||
		NamespaceID:    n.ID,
 | 
			
		||||
		Registered:     true,
 | 
			
		||||
		RegisterMethod: "authKey",
 | 
			
		||||
		AuthKeyID:      uint(pak.ID),
 | 
			
		||||
	}
 | 
			
		||||
	db.Save(&m)
 | 
			
		||||
 | 
			
		||||
	p, err := h.checkKeyValidity(pak.Key)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(p.ID, check.Equals, pak.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
 | 
			
		||||
	n, err := h.CreateNamespace("test6")
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	pak, err := h.CreatePreAuthKey(n.Name, false, nil)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
 | 
			
		||||
	p, err := h.checkKeyValidity(pak.Key)
 | 
			
		||||
	c.Assert(err, check.IsNil)
 | 
			
		||||
	c.Assert(p.ID, check.Equals, pak.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										5
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								utils.go
									
									
									
									
									
								
							@ -21,6 +21,11 @@ import (
 | 
			
		||||
	"tailscale.com/wgengine/wgcfg"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
 | 
			
		||||
type Error string
 | 
			
		||||
 | 
			
		||||
func (e Error) Error() string { return string(e) }
 | 
			
		||||
 | 
			
		||||
func decode(msg []byte, v interface{}, pubKey *wgcfg.Key, privKey *wgcfg.PrivateKey) error {
 | 
			
		||||
	return decodeMsg(msg, v, pubKey, privKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user