mirror of
https://github.com/juanfont/headscale.git
synced 2025-06-15 01:15:23 +02:00
temp
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
0152597c50
commit
47d0f2d4c9
@ -163,6 +163,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
return nil, fmt.Errorf("loading ACL policy: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): There is an circular dependency here, maybe we should
|
||||
// look at some sort of dependency injection?
|
||||
// https://github.com/uber-go/dig
|
||||
// or
|
||||
// https://github.com/uber-go/fx
|
||||
// Maybe overkill?
|
||||
app.db.SetPolicyManager(app.polMan)
|
||||
|
||||
var authProvider AuthProvider
|
||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
|
@ -153,27 +153,6 @@ func (h *Headscale) waitForFollowup(
|
||||
return nil, NewHTTPError(http.StatusNotFound, "followup registration not found", nil)
|
||||
}
|
||||
|
||||
// canUsePreAuthKey checks if a pre auth key can be used.
|
||||
func canUsePreAuthKey(pak *types.PreAuthKey) error {
|
||||
if pak == nil {
|
||||
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
|
||||
}
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
|
||||
}
|
||||
|
||||
// we don't need to check if has been used before
|
||||
if pak.Reusable {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pak.Used {
|
||||
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
@ -183,32 +162,28 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = canUsePreAuthKey(pak)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: regReq.Hostinfo.Hostname,
|
||||
UserID: ptr.To(pak.User.ID),
|
||||
User: ptr.To(pak.User),
|
||||
MachineKey: machineKey,
|
||||
NodeKey: regReq.NodeKey,
|
||||
Hostinfo: regReq.Hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
|
||||
// TODO(kradalby): This should not be set on the node,
|
||||
// they should be looked up through the key, which is
|
||||
// attached to the node.
|
||||
Tags: pak.Proto().GetAclTags(),
|
||||
AuthKey: pak,
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
if pak.IsTagged() {
|
||||
nodeToRegister.Tags = pak.Tags
|
||||
} else {
|
||||
nodeToRegister.UserID = pak.UserID
|
||||
nodeToRegister.User = pak.User
|
||||
}
|
||||
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
}
|
||||
@ -257,7 +232,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := policy.AutoApproveRoutes(h.polMan, node)
|
||||
if err := h.db.DB.Save(node).Error; err != nil {
|
||||
if err := h.db.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
|
||||
// GetAPIKey returns a ApiKey for a given key.
|
||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
var key types.APIKey
|
||||
if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
@ -75,7 +75,7 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||
|
||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
var key types.APIKey
|
||||
if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/go-gormigrate/gormigrate/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -44,6 +45,7 @@ type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
cfg *types.DatabaseConfig
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
polMan policy.PolicyManager
|
||||
|
||||
baseDomain string
|
||||
}
|
||||
@ -766,6 +768,10 @@ AND auth_key_id NOT IN (
|
||||
return &db, err
|
||||
}
|
||||
|
||||
func (db *HSDatabase) SetPolicyManager(pol policy.PolicyManager) {
|
||||
db.polMan = pol
|
||||
}
|
||||
|
||||
func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
// TODO(kradalby): Integrate this with zerolog
|
||||
var dbLogger logger.Interface
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -97,6 +98,23 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNode saves a node to the database.
|
||||
// It performs checks to validate if the conforms to certain restrictions:
|
||||
// - A node must be either tagged or owned by a user, not both.
|
||||
func (hsdb *HSDatabase) SaveNode(node *types.Node) error {
|
||||
if node.IsTagged() && node.UserID != nil {
|
||||
return fmt.Errorf("node %q is tagged and has a user ID, has to be either tagged or owned by user", node.Hostname)
|
||||
}
|
||||
|
||||
if !node.IsTagged() && node.UserID == nil {
|
||||
return fmt.Errorf("node %q is not tagged and has no user ID, has to be either tagged or owned by user", node.Hostname)
|
||||
}
|
||||
|
||||
slices.Sort(node.Tags)
|
||||
node.Tags = slices.Compact(node.Tags)
|
||||
return hsdb.DB.Save(node).Error
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return getNode(rx, uid, name)
|
||||
@ -196,18 +214,26 @@ func (hsdb *HSDatabase) SetTags(
|
||||
|
||||
// SetTags takes a NodeID and update the forced tags.
|
||||
// It will overwrite any tags with the new list.
|
||||
// If the node has a UserID, it will be unset as a node
|
||||
// can only have a UserID or tags, not both.
|
||||
func SetTags(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
tags []string,
|
||||
) error {
|
||||
// If no tags are provided, return an error.
|
||||
// Tailscale does not support removing all tags from a node.
|
||||
// A node needs to have either a User owner, or be tagged, and
|
||||
// it is not supported to remove all tags and "return it to a user".
|
||||
if len(tags) == 0 {
|
||||
// if no tags are provided, we remove all forced tags
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error; err != nil {
|
||||
return fmt.Errorf("removing tags: %w", err)
|
||||
}
|
||||
return types.ErrCannotRemoveAllTags
|
||||
}
|
||||
|
||||
return nil
|
||||
// If the node has a UserID, we need to remove it.
|
||||
// This is because a node can only have a UserID or tags, not both.
|
||||
// We need to set the UserID to nil.
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("user_id", nil).Error; err != nil {
|
||||
return fmt.Errorf("removing user from tagged node: %w", err)
|
||||
}
|
||||
|
||||
slices.Sort(tags)
|
||||
@ -224,7 +250,8 @@ func SetTags(
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
// SetApprovedRoutes takes a NodeID and a list of routes and updates the
|
||||
// approved routes for the node.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
@ -339,6 +366,30 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
})
|
||||
}
|
||||
|
||||
func checkTags(polMan policy.PolicyManager, node *types.Node, reqTags []string) ([]string, error) {
|
||||
if len(reqTags) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var tags []string
|
||||
var invalidTags []string
|
||||
for _, tag := range reqTags {
|
||||
if polMan.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
} else {
|
||||
invalidTags = append(invalidTags, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if len(invalidTags) > 0 {
|
||||
return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags)
|
||||
}
|
||||
|
||||
slices.Sort(tags)
|
||||
tags = slices.Compact(tags)
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
// HandleNodeFromAuthPath is called from the OIDC or CLI auth path
|
||||
// with a registrationID to register or reauthenticate a node.
|
||||
// If the node found in the registration cache is not already registered,
|
||||
@ -352,8 +403,9 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, types.ChangeSet, error) {
|
||||
var newNode bool
|
||||
cs := types.ChangeSet{}
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
@ -381,8 +433,17 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
reg.Node.UserID = &user.ID
|
||||
reg.Node.User = user
|
||||
if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 {
|
||||
tags, err := checkTags(hsdb.polMan, ®.Node, reqTags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reg.Node.Tags = tags
|
||||
} else {
|
||||
reg.Node.UserID = &user.ID
|
||||
reg.Node.User = user
|
||||
}
|
||||
|
||||
reg.Node.RegisterMethod = registrationMethod
|
||||
|
||||
if nodeExpiry != nil {
|
||||
@ -406,14 +467,27 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
newNode = true
|
||||
cs.New = true
|
||||
return node, err
|
||||
} else {
|
||||
if reqTags := reg.Node.RequestTags(); len(reqTags) > 0 {
|
||||
tags, err := checkTags(hsdb.polMan, ®.Node, reqTags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = SetTags(tx, node.ID, tags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.Tags = true
|
||||
}
|
||||
|
||||
// If the node is already registered, this is a refresh.
|
||||
err := NodeSetExpiry(tx, node.ID, *nodeExpiry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.Expiry = true
|
||||
return node, nil
|
||||
}
|
||||
}
|
||||
@ -421,7 +495,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, newNode, err
|
||||
return node, cs, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
|
@ -10,12 +10,10 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@ -281,7 +279,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := policy.AutoApproveRoutes(api.h.polMan, node)
|
||||
if err := api.h.db.DB.Save(node).Error; err != nil {
|
||||
if err := api.h.db.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
@ -315,15 +313,27 @@ func (api headscaleV1APIServer) SetTags(
|
||||
ctx context.Context,
|
||||
request *v1.SetTagsRequest,
|
||||
) (*v1.SetTagsResponse, error) {
|
||||
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tags []string
|
||||
var invalidTags []string
|
||||
for _, tag := range request.GetTags() {
|
||||
err := validateTag(tag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if api.h.polMan.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
} else {
|
||||
invalidTags = append(invalidTags, tag)
|
||||
}
|
||||
}
|
||||
|
||||
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
if len(invalidTags) > 0 {
|
||||
return nil, fmt.Errorf(`requested tags %v are invalid or not defined in policy`, invalidTags)
|
||||
}
|
||||
|
||||
node, err = db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
err := db.SetTags(tx, types.NodeID(request.GetNodeId()), tags)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -395,19 +405,6 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
return &v1.SetApprovedRoutesResponse{Node: proto}, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) error {
|
||||
if strings.Index(tag, "tag:") != 0 {
|
||||
return errors.New("tag must start with the string 'tag:'")
|
||||
}
|
||||
if strings.ToLower(tag) != tag {
|
||||
return errors.New("tag should be lowercase")
|
||||
}
|
||||
if len(strings.Fields(tag)) > 1 {
|
||||
return errors.New("tag should not contains space")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) DeleteNode(
|
||||
ctx context.Context,
|
||||
request *v1.DeleteNodeRequest,
|
||||
@ -546,13 +543,8 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
if polMan.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
resp.ValidTags = lo.Uniq(append(tags, node.Tags...))
|
||||
// TODO(kradalby): Rename ValidTags, there is only Tags
|
||||
resp.ValidTags = node.Tags
|
||||
resp.SubnetRoutes = util.PrefixesToString(append(pr.PrimaryRoutes(node.ID), node.ExitRoutes()...))
|
||||
response[index] = resp
|
||||
}
|
||||
|
@ -509,7 +509,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
return false, err
|
||||
}
|
||||
|
||||
node, newNode, err := a.db.HandleNodeFromAuthPath(
|
||||
node, cs, err := a.db.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
types.UserID(user.ID),
|
||||
&expiry,
|
||||
@ -540,7 +540,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := policy.AutoApproveRoutes(a.polMan, node)
|
||||
if err := a.db.DB.Save(node).Error; err != nil {
|
||||
if err := a.db.SaveNode(node); err != nil {
|
||||
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
@ -556,7 +556,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
}
|
||||
|
||||
return newNode, nil
|
||||
return cs.New, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby):
|
||||
|
@ -495,7 +495,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||
// the hostname change.
|
||||
m.node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
|
||||
|
||||
if err := m.h.db.DB.Save(m.node).Error; err != nil {
|
||||
if err := m.h.db.SaveNode(m.node); err != nil {
|
||||
m.errf(err, "Failed to persist/update node in the database")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
|
@ -196,3 +196,14 @@ type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
}
|
||||
|
||||
// TODO(kradalby): Not sure if this is a good idea,
|
||||
// but ran into this problem in HandleNodeFromAuthPath
|
||||
// describing what has changed in the node...
|
||||
// ChangeSet described changes that has happend to a node
|
||||
type ChangeSet struct {
|
||||
NodeID NodeID
|
||||
New bool
|
||||
Tags bool
|
||||
Expiry bool
|
||||
}
|
||||
|
@ -183,11 +183,11 @@ func (node *Node) IsUserOwned() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
var ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node")
|
||||
|
||||
// IsTagged reports if a device is tagged
|
||||
// and therefore should not be treated as a
|
||||
// user owned device.
|
||||
// Currently, this function only handles tags set
|
||||
// via CLI ("forced tags" and preauthkeys)
|
||||
func (node *Node) IsTagged() bool {
|
||||
if node.Tags == nil {
|
||||
return false
|
||||
|
Loading…
Reference in New Issue
Block a user