mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			873 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			873 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // nolint
 | |
| package hscontrol
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"os"
 | |
| 	"sort"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/rs/zerolog/log"
 | |
| 	"google.golang.org/grpc/codes"
 | |
| 	"google.golang.org/grpc/status"
 | |
| 	"google.golang.org/protobuf/types/known/timestamppb"
 | |
| 	"gorm.io/gorm"
 | |
| 	"tailscale.com/tailcfg"
 | |
| 	"tailscale.com/types/key"
 | |
| 
 | |
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | |
| 	"github.com/juanfont/headscale/hscontrol/db"
 | |
| 	"github.com/juanfont/headscale/hscontrol/policy"
 | |
| 	"github.com/juanfont/headscale/hscontrol/types"
 | |
| 	"github.com/juanfont/headscale/hscontrol/util"
 | |
| )
 | |
| 
 | |
| type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
 | |
| 	v1.UnimplementedHeadscaleServiceServer
 | |
| 	h *Headscale
 | |
| }
 | |
| 
 | |
| func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer {
 | |
| 	return headscaleV1APIServer{
 | |
| 		h: h,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) GetUser(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.GetUserRequest,
 | |
| ) (*v1.GetUserResponse, error) {
 | |
| 	user, err := api.h.db.GetUserByName(request.GetName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.GetUserResponse{User: user.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) CreateUser(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.CreateUserRequest,
 | |
| ) (*v1.CreateUserResponse, error) {
 | |
| 	user, err := api.h.db.CreateUser(request.GetName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.CreateUserResponse{User: user.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) RenameUser(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.RenameUserRequest,
 | |
| ) (*v1.RenameUserResponse, error) {
 | |
| 	oldUser, err := api.h.db.GetUserByName(request.GetOldName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	newUser, err := api.h.db.GetUserByName(request.GetNewName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.RenameUserResponse{User: newUser.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) DeleteUser(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DeleteUserRequest,
 | |
| ) (*v1.DeleteUserResponse, error) {
 | |
| 	user, err := api.h.db.GetUserByName(request.GetName())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = api.h.db.DestroyUser(types.UserID(user.ID))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.DeleteUserResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ListUsers(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ListUsersRequest,
 | |
| ) (*v1.ListUsersResponse, error) {
 | |
| 	users, err := api.h.db.ListUsers()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	response := make([]*v1.User, len(users))
 | |
| 	for index, user := range users {
 | |
| 		response[index] = user.Proto()
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(response, func(i, j int) bool {
 | |
| 		return response[i].Id < response[j].Id
 | |
| 	})
 | |
| 
 | |
| 	log.Trace().Caller().Interface("users", response).Msg("")
 | |
| 
 | |
| 	return &v1.ListUsersResponse{Users: response}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) CreatePreAuthKey(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.CreatePreAuthKeyRequest,
 | |
| ) (*v1.CreatePreAuthKeyResponse, error) {
 | |
| 	var expiration time.Time
 | |
| 	if request.GetExpiration() != nil {
 | |
| 		expiration = request.GetExpiration().AsTime()
 | |
| 	}
 | |
| 
 | |
| 	for _, tag := range request.AclTags {
 | |
| 		err := validateTag(tag)
 | |
| 		if err != nil {
 | |
| 			return &v1.CreatePreAuthKeyResponse{
 | |
| 				PreAuthKey: nil,
 | |
| 			}, status.Error(codes.InvalidArgument, err.Error())
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	preAuthKey, err := api.h.db.CreatePreAuthKey(
 | |
| 		types.UserID(user.ID),
 | |
| 		request.GetReusable(),
 | |
| 		request.GetEphemeral(),
 | |
| 		&expiration,
 | |
| 		request.AclTags,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ExpirePreAuthKey(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ExpirePreAuthKeyRequest,
 | |
| ) (*v1.ExpirePreAuthKeyResponse, error) {
 | |
| 	err := api.h.db.Write(func(tx *gorm.DB) error {
 | |
| 		preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return db.ExpirePreAuthKey(tx, preAuthKey)
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.ExpirePreAuthKeyResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ListPreAuthKeys(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ListPreAuthKeysRequest,
 | |
| ) (*v1.ListPreAuthKeysResponse, error) {
 | |
| 	user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	response := make([]*v1.PreAuthKey, len(preAuthKeys))
 | |
| 	for index, key := range preAuthKeys {
 | |
| 		response[index] = key.Proto()
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(response, func(i, j int) bool {
 | |
| 		return response[i].Id < response[j].Id
 | |
| 	})
 | |
| 
 | |
| 	return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) RegisterNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.RegisterNodeRequest,
 | |
| ) (*v1.RegisterNodeResponse, error) {
 | |
| 	log.Trace().
 | |
| 		Str("user", request.GetUser()).
 | |
| 		Str("machine_key", request.GetKey()).
 | |
| 		Msg("Registering node")
 | |
| 
 | |
| 	var mkey key.MachinePublic
 | |
| 	err := mkey.UnmarshalText([]byte(request.GetKey()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ipv4, ipv6, err := api.h.ipAlloc.Next()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("looking up user: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	node, err := api.h.db.RegisterNodeFromAuthCallback(
 | |
| 		mkey,
 | |
| 		types.UserID(user.ID),
 | |
| 		nil,
 | |
| 		util.RegisterMethodCLI,
 | |
| 		ipv4, ipv6,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) GetNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.GetNodeRequest,
 | |
| ) (*v1.GetNodeResponse, error) {
 | |
| 	node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	resp := node.Proto()
 | |
| 
 | |
| 	// Populate the online field based on
 | |
| 	// currently connected nodes.
 | |
| 	resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
 | |
| 
 | |
| 	return &v1.GetNodeResponse{Node: resp}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) SetTags(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.SetTagsRequest,
 | |
| ) (*v1.SetTagsResponse, error) {
 | |
| 	for _, tag := range request.GetTags() {
 | |
| 		err := validateTag(tag)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | |
| 		err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return &v1.SetTagsResponse{
 | |
| 			Node: nil,
 | |
| 		}, status.Error(codes.InvalidArgument, err.Error())
 | |
| 	}
 | |
| 
 | |
| 	ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
 | |
| 	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
 | |
| 		Type:        types.StatePeerChanged,
 | |
| 		ChangeNodes: []types.NodeID{node.ID},
 | |
| 		Message:     "called from api.SetTags",
 | |
| 	}, node.ID)
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Strs("tags", request.GetTags()).
 | |
| 		Msg("Changing tags of node")
 | |
| 
 | |
| 	return &v1.SetTagsResponse{Node: node.Proto()}, nil
 | |
| }
 | |
| 
 | |
// validateTag checks that tag is acceptable as an ACL tag.
//
// A tag is valid when it
//   - starts with the literal prefix "tag:",
//   - contains no upper-case characters (it equals its own lower-casing), and
//   - contains no whitespace anywhere, including trailing whitespace.
//
// It returns nil for a valid tag and a descriptive error for the first rule
// that is violated, checked in the order listed above.
func validateTag(tag string) error {
	if !strings.HasPrefix(tag, "tag:") {
		return errors.New("tag must start with the string 'tag:'")
	}
	if strings.ToLower(tag) != tag {
		return errors.New("tag should be lowercase")
	}

	// strings.Fields trims leading and trailing whitespace, so merely
	// counting fields would accept "tag:foo ". A whitespace-free tag splits
	// into exactly one field that is identical to the tag itself.
	fields := strings.Fields(tag)
	if len(fields) != 1 || fields[0] != tag {
		return errors.New("tag should not contains space")
	}

	return nil
}
 | |
| 
 | |
| func (api headscaleV1APIServer) DeleteNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DeleteNodeRequest,
 | |
| ) (*v1.DeleteNodeResponse, error) {
 | |
| 	node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	changedNodes, err := api.h.db.DeleteNode(
 | |
| 		node,
 | |
| 		api.h.nodeNotifier.LikelyConnectedMap(),
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
 | |
| 	api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | |
| 		Type:    types.StatePeerRemoved,
 | |
| 		Removed: []types.NodeID{node.ID},
 | |
| 	})
 | |
| 
 | |
| 	if changedNodes != nil {
 | |
| 		api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | |
| 			Type:        types.StatePeerChanged,
 | |
| 			ChangeNodes: changedNodes,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	return &v1.DeleteNodeResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ExpireNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ExpireNodeRequest,
 | |
| ) (*v1.ExpireNodeResponse, error) {
 | |
| 	now := time.Now()
 | |
| 
 | |
| 	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | |
| 		db.NodeSetExpiry(
 | |
| 			tx,
 | |
| 			types.NodeID(request.GetNodeId()),
 | |
| 			now,
 | |
| 		)
 | |
| 
 | |
| 		return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
 | |
| 	api.h.nodeNotifier.NotifyByNodeID(
 | |
| 		ctx,
 | |
| 		types.StateUpdate{
 | |
| 			Type:        types.StateSelfUpdate,
 | |
| 			ChangeNodes: []types.NodeID{node.ID},
 | |
| 		},
 | |
| 		node.ID)
 | |
| 
 | |
| 	ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
 | |
| 	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Time("expiry", *node.Expiry).
 | |
| 		Msg("node expired")
 | |
| 
 | |
| 	return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) RenameNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.RenameNodeRequest,
 | |
| ) (*v1.RenameNodeResponse, error) {
 | |
| 	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
 | |
| 		err := db.RenameNode(
 | |
| 			tx,
 | |
| 			types.NodeID(request.GetNodeId()),
 | |
| 			request.GetNewName(),
 | |
| 		)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
 | |
| 	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
 | |
| 		Type:        types.StatePeerChanged,
 | |
| 		ChangeNodes: []types.NodeID{node.ID},
 | |
| 		Message:     "called from api.RenameNode",
 | |
| 	}, node.ID)
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Str("node", node.Hostname).
 | |
| 		Str("new_name", request.GetNewName()).
 | |
| 		Msg("node renamed")
 | |
| 
 | |
| 	return &v1.RenameNodeResponse{Node: node.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ListNodes(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ListNodesRequest,
 | |
| ) (*v1.ListNodesResponse, error) {
 | |
| 	// TODO(kradalby): it looks like this can be simplified a lot,
 | |
| 	// the filtering of nodes by user, vs nodes as a whole can
 | |
| 	// probably be done once.
 | |
| 	// TODO(kradalby): This should be done in one tx.
 | |
| 
 | |
| 	isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
 | |
| 	if request.GetUser() != "" {
 | |
| 		user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
 | |
| 			return db.ListNodesByUser(rx, types.UserID(user.ID))
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		response := make([]*v1.Node, len(nodes))
 | |
| 		for index, node := range nodes {
 | |
| 			resp := node.Proto()
 | |
| 
 | |
| 			// Populate the online field based on
 | |
| 			// currently connected nodes.
 | |
| 			if val, ok := isLikelyConnected.Load(node.ID); ok && val {
 | |
| 				resp.Online = true
 | |
| 			}
 | |
| 
 | |
| 			response[index] = resp
 | |
| 		}
 | |
| 
 | |
| 		return &v1.ListNodesResponse{Nodes: response}, nil
 | |
| 	}
 | |
| 
 | |
| 	nodes, err := api.h.db.ListNodes()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(nodes, func(i, j int) bool {
 | |
| 		return nodes[i].ID < nodes[j].ID
 | |
| 	})
 | |
| 
 | |
| 	response := make([]*v1.Node, len(nodes))
 | |
| 	for index, node := range nodes {
 | |
| 		resp := node.Proto()
 | |
| 
 | |
| 		// Populate the online field based on
 | |
| 		// currently connected nodes.
 | |
| 		if val, ok := isLikelyConnected.Load(node.ID); ok && val {
 | |
| 			resp.Online = true
 | |
| 		}
 | |
| 
 | |
| 		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
 | |
| 			node,
 | |
| 		)
 | |
| 		resp.InvalidTags = invalidTags
 | |
| 		resp.ValidTags = validTags
 | |
| 		response[index] = resp
 | |
| 	}
 | |
| 
 | |
| 	return &v1.ListNodesResponse{Nodes: response}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) MoveNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.MoveNodeRequest,
 | |
| ) (*v1.MoveNodeResponse, error) {
 | |
| 	// TODO(kradalby): This should be done in one tx.
 | |
| 	node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.MoveNodeResponse{Node: node.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) BackfillNodeIPs(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.BackfillNodeIPsRequest,
 | |
| ) (*v1.BackfillNodeIPsResponse, error) {
 | |
| 	log.Trace().Msg("Backfill called")
 | |
| 
 | |
| 	if !request.Confirmed {
 | |
| 		return nil, errors.New("not confirmed, aborting")
 | |
| 	}
 | |
| 
 | |
| 	changes, err := api.h.db.BackfillNodeIPs(api.h.ipAlloc)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.BackfillNodeIPsResponse{Changes: changes}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) GetRoutes(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.GetRoutesRequest,
 | |
| ) (*v1.GetRoutesResponse, error) {
 | |
| 	routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) {
 | |
| 		return db.GetRoutes(rx)
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.GetRoutesResponse{
 | |
| 		Routes: types.Routes(routes).Proto(),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) EnableRoute(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.EnableRouteRequest,
 | |
| ) (*v1.EnableRouteResponse, error) {
 | |
| 	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
 | |
| 		return db.EnableRoute(tx, request.GetRouteId())
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if update != nil {
 | |
| 		ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
 | |
| 		api.h.nodeNotifier.NotifyAll(
 | |
| 			ctx, *update)
 | |
| 	}
 | |
| 
 | |
| 	return &v1.EnableRouteResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) DisableRoute(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DisableRouteRequest,
 | |
| ) (*v1.DisableRouteResponse, error) {
 | |
| 	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | |
| 		return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap())
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if update != nil {
 | |
| 		ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
 | |
| 		api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | |
| 			Type:        types.StatePeerChanged,
 | |
| 			ChangeNodes: update,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	return &v1.DisableRouteResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) GetNodeRoutes(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.GetNodeRoutesRequest,
 | |
| ) (*v1.GetNodeRoutesResponse, error) {
 | |
| 	node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	routes, err := api.h.db.GetNodeRoutes(node)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.GetNodeRoutesResponse{
 | |
| 		Routes: types.Routes(routes).Proto(),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) DeleteRoute(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DeleteRouteRequest,
 | |
| ) (*v1.DeleteRouteResponse, error) {
 | |
| 	isConnected := api.h.nodeNotifier.LikelyConnectedMap()
 | |
| 	update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
 | |
| 		return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if update != nil {
 | |
| 		ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
 | |
| 		api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | |
| 			Type:        types.StatePeerChanged,
 | |
| 			ChangeNodes: update,
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	return &v1.DeleteRouteResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) CreateApiKey(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.CreateApiKeyRequest,
 | |
| ) (*v1.CreateApiKeyResponse, error) {
 | |
| 	var expiration time.Time
 | |
| 	if request.GetExpiration() != nil {
 | |
| 		expiration = request.GetExpiration().AsTime()
 | |
| 	}
 | |
| 
 | |
| 	apiKey, _, err := api.h.db.CreateAPIKey(
 | |
| 		&expiration,
 | |
| 	)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ExpireApiKey(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ExpireApiKeyRequest,
 | |
| ) (*v1.ExpireApiKeyResponse, error) {
 | |
| 	var apiKey *types.APIKey
 | |
| 	var err error
 | |
| 
 | |
| 	apiKey, err = api.h.db.GetAPIKey(request.Prefix)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	err = api.h.db.ExpireAPIKey(apiKey)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.ExpireApiKeyResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) ListApiKeys(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.ListApiKeysRequest,
 | |
| ) (*v1.ListApiKeysResponse, error) {
 | |
| 	apiKeys, err := api.h.db.ListAPIKeys()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	response := make([]*v1.ApiKey, len(apiKeys))
 | |
| 	for index, key := range apiKeys {
 | |
| 		response[index] = key.Proto()
 | |
| 	}
 | |
| 
 | |
| 	sort.Slice(response, func(i, j int) bool {
 | |
| 		return response[i].Id < response[j].Id
 | |
| 	})
 | |
| 
 | |
| 	return &v1.ListApiKeysResponse{ApiKeys: response}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) DeleteApiKey(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DeleteApiKeyRequest,
 | |
| ) (*v1.DeleteApiKeyResponse, error) {
 | |
| 	var (
 | |
| 		apiKey *types.APIKey
 | |
| 		err    error
 | |
| 	)
 | |
| 
 | |
| 	apiKey, err = api.h.db.GetAPIKey(request.Prefix)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if err := api.h.db.DestroyAPIKey(*apiKey); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &v1.DeleteApiKeyResponse{}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) GetPolicy(
 | |
| 	_ context.Context,
 | |
| 	_ *v1.GetPolicyRequest,
 | |
| ) (*v1.GetPolicyResponse, error) {
 | |
| 	switch api.h.cfg.Policy.Mode {
 | |
| 	case types.PolicyModeDB:
 | |
| 		p, err := api.h.db.GetPolicy()
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("loading ACL from database: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		return &v1.GetPolicyResponse{
 | |
| 			Policy:    p.Data,
 | |
| 			UpdatedAt: timestamppb.New(p.UpdatedAt),
 | |
| 		}, nil
 | |
| 	case types.PolicyModeFile:
 | |
| 		// Read the file and return the contents as-is.
 | |
| 		absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
 | |
| 		f, err := os.Open(absPath)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
 | |
| 		}
 | |
| 
 | |
| 		defer f.Close()
 | |
| 
 | |
| 		b, err := io.ReadAll(f)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("reading policy from file: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		return &v1.GetPolicyResponse{Policy: string(b)}, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) SetPolicy(
 | |
| 	_ context.Context,
 | |
| 	request *v1.SetPolicyRequest,
 | |
| ) (*v1.SetPolicyResponse, error) {
 | |
| 	if api.h.cfg.Policy.Mode != types.PolicyModeDB {
 | |
| 		return nil, types.ErrPolicyUpdateIsDisabled
 | |
| 	}
 | |
| 
 | |
| 	p := request.GetPolicy()
 | |
| 
 | |
| 	pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("loading ACL policy file: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	// Validate and reject configuration that would error when applied
 | |
| 	// when creating a map response. This requires nodes, so there is still
 | |
| 	// a scenario where they might be allowed if the server has no nodes
 | |
| 	// yet, but it should help for the general case and for hot reloading
 | |
| 	// configurations.
 | |
| 	nodes, err := api.h.db.ListNodes()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | |
| 	}
 | |
| 	users, err := api.h.db.ListUsers()
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = pol.CompileFilterRules(users, nodes)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("verifying policy rules: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	if len(nodes) > 0 {
 | |
| 		_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("verifying SSH rules: %w", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	updated, err := api.h.db.SetPolicy(p)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	api.h.ACLPolicy = pol
 | |
| 
 | |
| 	ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 | |
| 	api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 | |
| 		Type: types.StateFullUpdate,
 | |
| 	})
 | |
| 
 | |
| 	response := &v1.SetPolicyResponse{
 | |
| 		Policy:    updated.Data,
 | |
| 		UpdatedAt: timestamppb.New(updated.UpdatedAt),
 | |
| 	}
 | |
| 
 | |
| 	return response, nil
 | |
| }
 | |
| 
 | |
| // The following service calls are for testing and debugging
 | |
| func (api headscaleV1APIServer) DebugCreateNode(
 | |
| 	ctx context.Context,
 | |
| 	request *v1.DebugCreateNodeRequest,
 | |
| ) (*v1.DebugCreateNodeResponse, error) {
 | |
| 	user, err := api.h.db.GetUserByName(request.GetUser())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	routes, err := util.StringToIPPrefix(request.GetRoutes())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	log.Trace().
 | |
| 		Caller().
 | |
| 		Interface("route-prefix", routes).
 | |
| 		Interface("route-str", request.GetRoutes()).
 | |
| 		Msg("")
 | |
| 
 | |
| 	hostinfo := tailcfg.Hostinfo{
 | |
| 		RoutableIPs: routes,
 | |
| 		OS:          "TestOS",
 | |
| 		Hostname:    "DebugTestNode",
 | |
| 	}
 | |
| 
 | |
| 	var mkey key.MachinePublic
 | |
| 	err = mkey.UnmarshalText([]byte(request.GetKey()))
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	nodeKey := key.NewNode()
 | |
| 
 | |
| 	newNode := types.Node{
 | |
| 		MachineKey: mkey,
 | |
| 		NodeKey:    nodeKey.Public(),
 | |
| 		Hostname:   request.GetName(),
 | |
| 		User:       *user,
 | |
| 
 | |
| 		Expiry:   &time.Time{},
 | |
| 		LastSeen: &time.Time{},
 | |
| 
 | |
| 		Hostinfo: &hostinfo,
 | |
| 	}
 | |
| 
 | |
| 	log.Debug().
 | |
| 		Str("machine_key", mkey.ShortString()).
 | |
| 		Msg("adding debug machine via CLI, appending to registration cache")
 | |
| 
 | |
| 	api.h.registrationCache.Set(
 | |
| 		mkey.String(),
 | |
| 		newNode,
 | |
| 	)
 | |
| 
 | |
| 	return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
 | |
| }
 | |
| 
 | |
| func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
 |