mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge pull request #21 from juanfont/working-preauth
Support for pre auth keys
This commit is contained in:
		
						commit
						8ca940ad30
					
				
							
								
								
									
										54
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										54
									
								
								api.go
									
									
									
									
									
								
							| @ -33,6 +33,8 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// spew.Dump(c.Params)
 | ||||
| 
 | ||||
| 	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` | ||||
| 	<html> | ||||
| 	<body> | ||||
| @ -71,6 +73,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { | ||||
| 		c.String(http.StatusInternalServerError, "Very sad!") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	db, err := h.db() | ||||
| 	if err != nil { | ||||
| 		log.Printf("Cannot open DB: %s", err) | ||||
| @ -93,6 +96,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { | ||||
| 			log.Println("Client is registered and we have the current key. All clear to /map") | ||||
| 			resp.AuthURL = "" | ||||
| 			resp.User = *m.Namespace.toUser() | ||||
| 			resp.MachineAuthorized = true | ||||
| 			respBody, err := encode(resp, &mKey, h.privateKey) | ||||
| 			if err != nil { | ||||
| 				log.Printf("Cannot encode message: %s", err) | ||||
| @ -135,6 +139,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { | ||||
| 	} | ||||
| 
 | ||||
| 	log.Println("We dont know anything about the new key. WTF") | ||||
| 	// spew.Dump(req)
 | ||||
| } | ||||
| 
 | ||||
| // PollNetMapHandler takes care of /machine/:id/map
 | ||||
| @ -359,21 +364,60 @@ func (h *Headscale) getMapKeepAliveResponse(mKey wgcfg.Key, req tailcfg.MapReque | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key, req tailcfg.RegisterRequest) { | ||||
| 	mNew := Machine{ | ||||
| 	m := Machine{ | ||||
| 		MachineKey: idKey.HexString(), | ||||
| 		NodeKey:    wgcfg.Key(req.NodeKey).HexString(), | ||||
| 		Expiry:     &req.Expiry, | ||||
| 		Name:       req.Hostinfo.Hostname, | ||||
| 	} | ||||
| 	if err := db.Create(&mNew).Error; err != nil { | ||||
| 	if err := db.Create(&m).Error; err != nil { | ||||
| 		log.Printf("Could not create row: %s", err) | ||||
| 		return | ||||
| 	} | ||||
| 	resp := tailcfg.RegisterResponse{ | ||||
| 		AuthURL: fmt.Sprintf("%s/register?key=%s", | ||||
| 			h.cfg.ServerURL, idKey.HexString()), | ||||
| 
 | ||||
| 	resp := tailcfg.RegisterResponse{} | ||||
| 
 | ||||
| 	if req.Auth.AuthKey != "" { | ||||
| 		pak, err := h.checkKeyValidity(req.Auth.AuthKey) | ||||
| 		if err != nil { | ||||
| 			resp.MachineAuthorized = false | ||||
| 			respBody, err := encode(resp, &idKey, h.privateKey) | ||||
| 			if err != nil { | ||||
| 				log.Printf("Cannot encode message: %s", err) | ||||
| 				c.String(http.StatusInternalServerError, "") | ||||
| 				return | ||||
| 			} | ||||
| 			c.Data(200, "application/json; charset=utf-8", respBody) | ||||
| 			return | ||||
| 		} | ||||
| 		ip, err := h.getAvailableIP() | ||||
| 		if err != nil { | ||||
| 			log.Println(err) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		m.IPAddress = ip.String() | ||||
| 		m.NamespaceID = pak.NamespaceID | ||||
| 		m.AuthKeyID = uint(pak.ID) | ||||
| 		m.RegisterMethod = "authKey" | ||||
| 		m.Registered = true | ||||
| 		db.Save(&m) | ||||
| 
 | ||||
| 		resp.MachineAuthorized = true | ||||
| 		resp.User = *pak.Namespace.toUser() | ||||
| 		respBody, err := encode(resp, &idKey, h.privateKey) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Cannot encode message: %s", err) | ||||
| 			c.String(http.StatusInternalServerError, "Extremely sad!") | ||||
| 			return | ||||
| 		} | ||||
| 		c.Data(200, "application/json; charset=utf-8", respBody) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	resp.AuthURL = fmt.Sprintf("%s/register?key=%s", | ||||
| 		h.cfg.ServerURL, idKey.HexString()) | ||||
| 
 | ||||
| 	respBody, err := encode(resp, &idKey, h.privateKey) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Cannot encode message: %s", err) | ||||
|  | ||||
							
								
								
									
										1
									
								
								cli.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								cli.go
									
									
									
									
									
								
							| @ -43,6 +43,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) error { | ||||
| 	m.IPAddress = ip.String() | ||||
| 	m.NamespaceID = ns.ID | ||||
| 	m.Registered = true | ||||
| 	m.RegisterMethod = "cli" | ||||
| 	db.Save(&m) | ||||
| 	fmt.Println("Machine registered 🎉") | ||||
| 	return nil | ||||
|  | ||||
							
								
								
									
										10
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								machine.go
									
									
									
									
									
								
							| @ -25,9 +25,13 @@ type Machine struct { | ||||
| 	NamespaceID uint | ||||
| 	Namespace   Namespace | ||||
| 
 | ||||
| 	Registered bool // temp
 | ||||
| 	LastSeen   *time.Time | ||||
| 	Expiry     *time.Time | ||||
| 	Registered     bool // temp
 | ||||
| 	RegisterMethod string | ||||
| 	AuthKeyID      uint | ||||
| 	AuthKey        *PreAuthKey | ||||
| 
 | ||||
| 	LastSeen *time.Time | ||||
| 	Expiry   *time.Time | ||||
| 
 | ||||
| 	HostInfo      postgres.Jsonb | ||||
| 	Endpoints     postgres.Jsonb | ||||
|  | ||||
| @ -7,6 +7,10 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| const errorAuthKeyNotFound = Error("AuthKey not found") | ||||
| const errorAuthKeyExpired = Error("AuthKey expired") | ||||
| const errorAuthKeyNotReusableAlreadyUsed = Error("AuthKey not reusable already used") | ||||
| 
 | ||||
| // PreAuthKey describes a pre-authorization key usable in a particular namespace
 | ||||
| type PreAuthKey struct { | ||||
| 	ID          uint64 `gorm:"primary_key"` | ||||
| @ -72,6 +76,41 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) | ||||
| 	return &keys, nil | ||||
| } | ||||
| 
 | ||||
| // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
 | ||||
| // If returns no error and a PreAuthKey, it can be used
 | ||||
| func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { | ||||
| 	db, err := h.db() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 
 | ||||
| 	pak := PreAuthKey{} | ||||
| 	if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() { | ||||
| 		return nil, errorAuthKeyNotFound | ||||
| 	} | ||||
| 
 | ||||
| 	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { | ||||
| 		return nil, errorAuthKeyExpired | ||||
| 	} | ||||
| 
 | ||||
| 	if pak.Reusable { // we don't need to check if has been used before
 | ||||
| 		return &pak, nil | ||||
| 	} | ||||
| 
 | ||||
| 	machines := []Machine{} | ||||
| 	if err := db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if len(machines) != 0 { | ||||
| 		return nil, errorAuthKeyNotReusableAlreadyUsed | ||||
| 	} | ||||
| 
 | ||||
| 	// missing here validation on current usage
 | ||||
| 	return &pak, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) generateKey() (string, error) { | ||||
| 	size := 24 | ||||
| 	bytes := make([]byte, size) | ||||
|  | ||||
| @ -5,6 +5,7 @@ import ( | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	_ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver
 | ||||
| 
 | ||||
| @ -48,6 +49,7 @@ func (s *Suite) TearDownSuite(c *check.C) { | ||||
| 
 | ||||
| func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||
| 	_, err := h.CreatePreAuthKey("bogus", true, nil) | ||||
| 
 | ||||
| 	c.Assert(err, check.NotNil) | ||||
| 
 | ||||
| 	n, err := h.CreateNamespace("test") | ||||
| @ -73,3 +75,106 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { | ||||
| 	// Make sure the Namespace association is populated
 | ||||
| 	c.Assert((*keys)[0].Namespace.Name, check.Equals, n.Name) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestExpiredPreAuthKey(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test2") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, true, &now) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	p, err := h.checkKeyValidity(pak.Key) | ||||
| 	c.Assert(err, check.Equals, errorAuthKeyExpired) | ||||
| 	c.Assert(p, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { | ||||
| 	p, err := h.checkKeyValidity("potatoKey") | ||||
| 	c.Assert(err, check.Equals, errorAuthKeyNotFound) | ||||
| 	c.Assert(p, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestValidateKeyOk(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test3") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, true, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	p, err := h.checkKeyValidity(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(p.ID, check.Equals, pak.ID) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestAlreadyUsedKey(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test4") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, false, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	db, err := h.db() | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	m := Machine{ | ||||
| 		ID:             0, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Name:           "testest", | ||||
| 		NamespaceID:    n.ID, | ||||
| 		Registered:     true, | ||||
| 		RegisterMethod: "authKey", | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.Save(&m) | ||||
| 
 | ||||
| 	p, err := h.checkKeyValidity(pak.Key) | ||||
| 	c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed) | ||||
| 	c.Assert(p, check.IsNil) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestReusableBeingUsedKey(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test5") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, true, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	db, err := h.db() | ||||
| 	if err != nil { | ||||
| 		c.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
| 	m := Machine{ | ||||
| 		ID:             1, | ||||
| 		MachineKey:     "foo", | ||||
| 		NodeKey:        "bar", | ||||
| 		DiscoKey:       "faa", | ||||
| 		Name:           "testest", | ||||
| 		NamespaceID:    n.ID, | ||||
| 		Registered:     true, | ||||
| 		RegisterMethod: "authKey", | ||||
| 		AuthKeyID:      uint(pak.ID), | ||||
| 	} | ||||
| 	db.Save(&m) | ||||
| 
 | ||||
| 	p, err := h.checkKeyValidity(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(p.ID, check.Equals, pak.ID) | ||||
| } | ||||
| 
 | ||||
| func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { | ||||
| 	n, err := h.CreateNamespace("test6") | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	pak, err := h.CreatePreAuthKey(n.Name, false, nil) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 
 | ||||
| 	p, err := h.checkKeyValidity(pak.Key) | ||||
| 	c.Assert(err, check.IsNil) | ||||
| 	c.Assert(p.ID, check.Equals, pak.ID) | ||||
| } | ||||
|  | ||||
							
								
								
									
										5
									
								
								utils.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								utils.go
									
									
									
									
									
								
							| @ -21,6 +21,11 @@ import ( | ||||
| 	"tailscale.com/wgengine/wgcfg" | ||||
| ) | ||||
| 
 | ||||
| // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
 | ||||
| type Error string | ||||
| 
 | ||||
| func (e Error) Error() string { return string(e) } | ||||
| 
 | ||||
| func decode(msg []byte, v interface{}, pubKey *wgcfg.Key, privKey *wgcfg.PrivateKey) error { | ||||
| 	return decodeMsg(msg, v, pubKey, privKey) | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user