mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			216 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			216 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package db
 | 
						|
 | 
						|
import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"gorm.io/gorm"
	"tailscale.com/types/ptr"
	"tailscale.com/util/set"
)
 | 
						|
 | 
						|
// Sentinel errors produced by the pre-auth key helpers in this file. Some are
// returned wrapped (via %w), so compare against them with errors.Is.
var (
	// ErrPreAuthKeyNotFound is returned when no stored key matches the lookup.
	ErrPreAuthKeyNotFound = errors.New("AuthKey not found")

	// ErrPreAuthKeyExpired is returned when a key's expiration lies in the past.
	ErrPreAuthKeyExpired = errors.New("AuthKey expired")

	// ErrSingleUseAuthKeyHasBeenUsed is returned when a non-reusable key has
	// already been consumed.
	ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")

	// ErrUserMismatch is returned when a key belongs to a different user than
	// the one requested.
	ErrUserMismatch = errors.New("user mismatch")

	// ErrPreAuthKeyACLTagInvalid is wrapped when a requested ACL tag is malformed.
	ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
)
 | 
						|
 | 
						|
func (hsdb *HSDatabase) CreatePreAuthKey(
 | 
						|
	// TODO(kradalby): Should be ID, not name
 | 
						|
	userName string,
 | 
						|
	reusable bool,
 | 
						|
	preApproved bool,
 | 
						|
	ephemeral bool,
 | 
						|
	expiration *time.Time,
 | 
						|
	aclTags []string,
 | 
						|
) (*types.PreAuthKey, error) {
 | 
						|
	return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
 | 
						|
		return CreatePreAuthKey(tx, userName, reusable, preApproved, ephemeral, expiration, aclTags)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
 | 
						|
func CreatePreAuthKey(
 | 
						|
	tx *gorm.DB,
 | 
						|
	// TODO(kradalby): Should be ID, not name
 | 
						|
	userName string,
 | 
						|
	reusable bool,
 | 
						|
	preApproved bool,
 | 
						|
	ephemeral bool,
 | 
						|
	expiration *time.Time,
 | 
						|
	aclTags []string,
 | 
						|
) (*types.PreAuthKey, error) {
 | 
						|
	user, err := GetUserByUsername(tx, userName)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Remove duplicates
 | 
						|
	aclTags = set.SetOf(aclTags).Slice()
 | 
						|
 | 
						|
	// TODO(kradalby): factor out and create a reusable tag validation,
 | 
						|
	// check if there is one in Tailscale's lib.
 | 
						|
	for _, tag := range aclTags {
 | 
						|
		if !strings.HasPrefix(tag, "tag:") {
 | 
						|
			return nil, fmt.Errorf(
 | 
						|
				"%w: '%s' did not begin with 'tag:'",
 | 
						|
				ErrPreAuthKeyACLTagInvalid,
 | 
						|
				tag,
 | 
						|
			)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	now := time.Now().UTC()
 | 
						|
	kstr, err := generateKey()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	key := types.PreAuthKey{
 | 
						|
		Key:         kstr,
 | 
						|
		UserID:      user.ID,
 | 
						|
		User:        *user,
 | 
						|
		Reusable:    reusable,
 | 
						|
		PreApproved: preApproved,
 | 
						|
		Ephemeral:   ephemeral,
 | 
						|
		CreatedAt:   &now,
 | 
						|
		Expiration:  expiration,
 | 
						|
		Tags:        aclTags,
 | 
						|
	}
 | 
						|
 | 
						|
	if err := tx.Save(&key).Error; err != nil {
 | 
						|
		return nil, fmt.Errorf("failed to create key in the database: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return &key, nil
 | 
						|
}
 | 
						|
 | 
						|
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
 | 
						|
	return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
 | 
						|
		return ListPreAuthKeys(rx, userName)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
 | 
						|
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
 | 
						|
	user, err := GetUserByUsername(tx, userName)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	keys := []types.PreAuthKey{}
 | 
						|
	if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return keys, nil
 | 
						|
}
 | 
						|
 | 
						|
// GetPreAuthKey returns a PreAuthKey for a given key.
 | 
						|
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
 | 
						|
	pak, err := ValidatePreAuthKey(tx, key)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if pak.User.Name != user {
 | 
						|
		return nil, ErrUserMismatch
 | 
						|
	}
 | 
						|
 | 
						|
	return pak, nil
 | 
						|
}
 | 
						|
 | 
						|
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
 | 
						|
// does not exist.
 | 
						|
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
 | 
						|
	return tx.Transaction(func(db *gorm.DB) error {
 | 
						|
		if result := db.Unscoped().Delete(pak); result.Error != nil {
 | 
						|
			return result.Error
 | 
						|
		}
 | 
						|
 | 
						|
		return nil
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
 | 
						|
	return hsdb.Write(func(tx *gorm.DB) error {
 | 
						|
		return ExpirePreAuthKey(tx, k)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
 | 
						|
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
 | 
						|
	if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// UsePreAuthKey marks a PreAuthKey as used.
 | 
						|
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
 | 
						|
	k.Used = true
 | 
						|
	if err := tx.Save(k).Error; err != nil {
 | 
						|
		return fmt.Errorf("failed to update key used status in the database: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
 | 
						|
	return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
 | 
						|
		return ValidatePreAuthKey(rx, k)
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
 | 
						|
// If returns no error and a PreAuthKey, it can be used.
 | 
						|
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
 | 
						|
	pak := types.PreAuthKey{}
 | 
						|
	if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
 | 
						|
		result.Error,
 | 
						|
		gorm.ErrRecordNotFound,
 | 
						|
	) {
 | 
						|
		return nil, ErrPreAuthKeyNotFound
 | 
						|
	}
 | 
						|
 | 
						|
	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
 | 
						|
		return nil, ErrPreAuthKeyExpired
 | 
						|
	}
 | 
						|
 | 
						|
	if pak.Reusable { // we don't need to check if has been used before
 | 
						|
		return &pak, nil
 | 
						|
	}
 | 
						|
 | 
						|
	nodes := types.Nodes{}
 | 
						|
	if err := tx.
 | 
						|
		Preload("AuthKey").
 | 
						|
		Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
 | 
						|
		Find(&nodes).Error; err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if len(nodes) != 0 || pak.Used {
 | 
						|
		return nil, ErrSingleUseAuthKeyHasBeenUsed
 | 
						|
	}
 | 
						|
 | 
						|
	return &pak, nil
 | 
						|
}
 | 
						|
 | 
						|
// generateKey returns a fresh random key: 24 bytes read from crypto/rand,
// hex-encoded into a 48-character lowercase string. A failure to read
// randomness is returned unchanged.
func generateKey() (string, error) {
	var buf [24]byte

	if _, err := rand.Read(buf[:]); err != nil {
		return "", err
	}

	return hex.EncodeToString(buf[:]), nil
}
 |