mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			826 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			826 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// nolint
 | 
						|
package hscontrol
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/netip"
 | 
						|
	"os"
 | 
						|
	"slices"
 | 
						|
	"sort"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/puzpuzpuz/xsync/v4"
 | 
						|
	"github.com/rs/zerolog/log"
 | 
						|
	"github.com/samber/lo"
 | 
						|
	"google.golang.org/grpc/codes"
 | 
						|
	"google.golang.org/grpc/status"
 | 
						|
	"google.golang.org/protobuf/types/known/timestamppb"
 | 
						|
	"gorm.io/gorm"
 | 
						|
	"tailscale.com/net/tsaddr"
 | 
						|
	"tailscale.com/tailcfg"
 | 
						|
	"tailscale.com/types/key"
 | 
						|
 | 
						|
	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
						|
	"github.com/juanfont/headscale/hscontrol/state"
 | 
						|
	"github.com/juanfont/headscale/hscontrol/types"
 | 
						|
	"github.com/juanfont/headscale/hscontrol/util"
 | 
						|
)
 | 
						|
 | 
						|
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
 | 
						|
	v1.UnimplementedHeadscaleServiceServer
 | 
						|
	h *Headscale
 | 
						|
}
 | 
						|
 | 
						|
func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer {
 | 
						|
	return headscaleV1APIServer{
 | 
						|
		h: h,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) CreateUser(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.CreateUserRequest,
 | 
						|
) (*v1.CreateUserResponse, error) {
 | 
						|
	newUser := types.User{
 | 
						|
		Name:          request.GetName(),
 | 
						|
		DisplayName:   request.GetDisplayName(),
 | 
						|
		Email:         request.GetEmail(),
 | 
						|
		ProfilePicURL: request.GetPictureUrl(),
 | 
						|
	}
 | 
						|
	user, policyChanged, err := api.h.state.CreateUser(newUser)
 | 
						|
	if err != nil {
 | 
						|
		return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.CreateUserResponse{User: user.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) RenameUser(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.RenameUserRequest,
 | 
						|
) (*v1.RenameUserResponse, error) {
 | 
						|
	oldUser, err := api.h.state.GetUserByID(types.UserID(request.GetOldId()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	_, policyChanged, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	newUser, err := api.h.state.GetUserByName(request.GetNewName())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.RenameUserResponse{User: newUser.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) DeleteUser(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.DeleteUserRequest,
 | 
						|
) (*v1.DeleteUserResponse, error) {
 | 
						|
	user, err := api.h.state.GetUserByID(types.UserID(request.GetId()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	err = api.h.state.DeleteUser(types.UserID(user.ID))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.DeleteUserResponse{}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ListUsers(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ListUsersRequest,
 | 
						|
) (*v1.ListUsersResponse, error) {
 | 
						|
	var err error
 | 
						|
	var users []types.User
 | 
						|
 | 
						|
	switch {
 | 
						|
	case request.GetName() != "":
 | 
						|
		users, err = api.h.state.ListUsersWithFilter(&types.User{Name: request.GetName()})
 | 
						|
	case request.GetEmail() != "":
 | 
						|
		users, err = api.h.state.ListUsersWithFilter(&types.User{Email: request.GetEmail()})
 | 
						|
	case request.GetId() != 0:
 | 
						|
		users, err = api.h.state.ListUsersWithFilter(&types.User{Model: gorm.Model{ID: uint(request.GetId())}})
 | 
						|
	default:
 | 
						|
		users, err = api.h.state.ListAllUsers()
 | 
						|
	}
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	response := make([]*v1.User, len(users))
 | 
						|
	for index, user := range users {
 | 
						|
		response[index] = user.Proto()
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(response, func(i, j int) bool {
 | 
						|
		return response[i].Id < response[j].Id
 | 
						|
	})
 | 
						|
 | 
						|
	return &v1.ListUsersResponse{Users: response}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) CreatePreAuthKey(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.CreatePreAuthKeyRequest,
 | 
						|
) (*v1.CreatePreAuthKeyResponse, error) {
 | 
						|
	var expiration time.Time
 | 
						|
	if request.GetExpiration() != nil {
 | 
						|
		expiration = request.GetExpiration().AsTime()
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tag := range request.AclTags {
 | 
						|
		err := validateTag(tag)
 | 
						|
		if err != nil {
 | 
						|
			return &v1.CreatePreAuthKeyResponse{
 | 
						|
				PreAuthKey: nil,
 | 
						|
			}, status.Error(codes.InvalidArgument, err.Error())
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	preAuthKey, err := api.h.state.CreatePreAuthKey(
 | 
						|
		types.UserID(user.ID),
 | 
						|
		request.GetReusable(),
 | 
						|
		request.GetEphemeral(),
 | 
						|
		&expiration,
 | 
						|
		request.AclTags,
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ExpirePreAuthKey(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ExpirePreAuthKeyRequest,
 | 
						|
) (*v1.ExpirePreAuthKeyResponse, error) {
 | 
						|
	preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if uint64(preAuthKey.User.ID) != request.GetUser() {
 | 
						|
		return nil, fmt.Errorf("preauth key does not belong to user")
 | 
						|
	}
 | 
						|
 | 
						|
	err = api.h.state.ExpirePreAuthKey(preAuthKey)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.ExpirePreAuthKeyResponse{}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ListPreAuthKeys(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ListPreAuthKeysRequest,
 | 
						|
) (*v1.ListPreAuthKeysResponse, error) {
 | 
						|
	user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	response := make([]*v1.PreAuthKey, len(preAuthKeys))
 | 
						|
	for index, key := range preAuthKeys {
 | 
						|
		response[index] = key.Proto()
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(response, func(i, j int) bool {
 | 
						|
		return response[i].Id < response[j].Id
 | 
						|
	})
 | 
						|
 | 
						|
	return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) RegisterNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.RegisterNodeRequest,
 | 
						|
) (*v1.RegisterNodeResponse, error) {
 | 
						|
	log.Trace().
 | 
						|
		Str("user", request.GetUser()).
 | 
						|
		Str("registration_id", request.GetKey()).
 | 
						|
		Msg("Registering node")
 | 
						|
 | 
						|
	registrationId, err := types.RegistrationIDFromString(request.GetKey())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	user, err := api.h.state.GetUserByName(request.GetUser())
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("looking up user: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	node, _, err := api.h.state.HandleNodeFromAuthPath(
 | 
						|
		registrationId,
 | 
						|
		types.UserID(user.ID),
 | 
						|
		nil,
 | 
						|
		util.RegisterMethodCLI,
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// This is a bit of a back and forth, but we have a bit of a chicken and egg
 | 
						|
	// dependency here.
 | 
						|
	// Because the way the policy manager works, we need to have the node
 | 
						|
	// in the database, then add it to the policy manager and then we can
 | 
						|
	// approve the route. This means we get this dance where the node is
 | 
						|
	// first added to the database, then we add it to the policy manager via
 | 
						|
	// SaveNode (which automatically updates the policy manager) and then we can auto approve the routes.
 | 
						|
	// As that only approves the struct object, we need to save it again and
 | 
						|
	// ensure we send an update.
 | 
						|
	// This works, but might be another good candidate for doing some sort of
 | 
						|
	// eventbus.
 | 
						|
	routesChanged := api.h.state.AutoApproveRoutes(node)
 | 
						|
	_, policyChanged, err := api.h.state.SaveNode(node)
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed (from SaveNode or route changes)
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	if routesChanged {
 | 
						|
		ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) GetNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.GetNodeRequest,
 | 
						|
) (*v1.GetNodeResponse, error) {
 | 
						|
	node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	resp := node.Proto()
 | 
						|
 | 
						|
	// Populate the online field based on
 | 
						|
	// currently connected nodes.
 | 
						|
	resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
 | 
						|
 | 
						|
	return &v1.GetNodeResponse{Node: resp}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) SetTags(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.SetTagsRequest,
 | 
						|
) (*v1.SetTagsResponse, error) {
 | 
						|
	for _, tag := range request.GetTags() {
 | 
						|
		err := validateTag(tag)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
 | 
						|
	if err != nil {
 | 
						|
		return &v1.SetTagsResponse{
 | 
						|
			Node: nil,
 | 
						|
		}, status.Error(codes.InvalidArgument, err.Error())
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
 | 
						|
 | 
						|
	log.Trace().
 | 
						|
		Str("node", node.Hostname).
 | 
						|
		Strs("tags", request.GetTags()).
 | 
						|
		Msg("Changing tags of node")
 | 
						|
 | 
						|
	return &v1.SetTagsResponse{Node: node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) SetApprovedRoutes(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.SetApprovedRoutesRequest,
 | 
						|
) (*v1.SetApprovedRoutesResponse, error) {
 | 
						|
	var routes []netip.Prefix
 | 
						|
	for _, route := range request.GetRoutes() {
 | 
						|
		prefix, err := netip.ParsePrefix(route)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("parsing route: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		// If the prefix is an exit route, add both. The client expect both
 | 
						|
		// to annotate the node as an exit node.
 | 
						|
		if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() {
 | 
						|
			routes = append(routes, tsaddr.AllIPv4(), tsaddr.AllIPv6())
 | 
						|
		} else {
 | 
						|
			routes = append(routes, prefix)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	tsaddr.SortPrefixes(routes)
 | 
						|
	routes = slices.Compact(routes)
 | 
						|
 | 
						|
	node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
 | 
						|
	if err != nil {
 | 
						|
		return nil, status.Error(codes.InvalidArgument, err.Error())
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
 | 
						|
		ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	} else {
 | 
						|
		ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
 | 
						|
	}
 | 
						|
 | 
						|
	proto := node.Proto()
 | 
						|
	proto.SubnetRoutes = util.PrefixesToString(api.h.state.GetNodePrimaryRoutes(node.ID))
 | 
						|
 | 
						|
	return &v1.SetApprovedRoutesResponse{Node: proto}, nil
 | 
						|
}
 | 
						|
 | 
						|
// validateTag checks that tag is a syntactically acceptable ACL tag.
//
// A valid tag starts with the literal prefix "tag:", is entirely lowercase
// and contains no whitespace anywhere, including at the end. It returns nil
// for a valid tag and otherwise an error naming the first rule that failed,
// checked in that order.
func validateTag(tag string) error {
	if !strings.HasPrefix(tag, "tag:") {
		return errors.New("tag must start with the string 'tag:'")
	}

	if strings.ToLower(tag) != tag {
		return errors.New("tag should be lowercase")
	}

	// strings.Fields splits around whitespace and drops leading/trailing
	// whitespace, so the tag is whitespace-free only when it comes back as
	// exactly one field identical to the input. Counting fields alone would
	// let trailing whitespace such as "tag:web " through.
	if fields := strings.Fields(tag); len(fields) != 1 || fields[0] != tag {
		return errors.New("tag should not contains space")
	}

	return nil
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) DeleteNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.DeleteNodeRequest,
 | 
						|
) (*v1.DeleteNodeResponse, error) {
 | 
						|
	node, err := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	policyChanged, err := api.h.state.DeleteNode(node)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
 | 
						|
 | 
						|
	return &v1.DeleteNodeResponse{}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ExpireNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ExpireNodeRequest,
 | 
						|
) (*v1.ExpireNodeResponse, error) {
 | 
						|
	now := time.Now()
 | 
						|
 | 
						|
	node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyByNodeID(
 | 
						|
		ctx,
 | 
						|
		types.UpdateSelf(node.ID),
 | 
						|
		node.ID)
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
 | 
						|
 | 
						|
	log.Trace().
 | 
						|
		Str("node", node.Hostname).
 | 
						|
		Time("expiry", *node.Expiry).
 | 
						|
		Msg("node expired")
 | 
						|
 | 
						|
	return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) RenameNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.RenameNodeRequest,
 | 
						|
) (*v1.RenameNodeResponse, error) {
 | 
						|
	node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
 | 
						|
 | 
						|
	log.Trace().
 | 
						|
		Str("node", node.Hostname).
 | 
						|
		Str("new_name", request.GetNewName()).
 | 
						|
		Msg("node renamed")
 | 
						|
 | 
						|
	return &v1.RenameNodeResponse{Node: node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ListNodes(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ListNodesRequest,
 | 
						|
) (*v1.ListNodesResponse, error) {
 | 
						|
	// TODO(kradalby): it looks like this can be simplified a lot,
 | 
						|
	// the filtering of nodes by user, vs nodes as a whole can
 | 
						|
	// probably be done once.
 | 
						|
	// TODO(kradalby): This should be done in one tx.
 | 
						|
 | 
						|
	isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
 | 
						|
	if request.GetUser() != "" {
 | 
						|
		user, err := api.h.state.GetUserByName(request.GetUser())
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		nodes, err := api.h.state.ListNodesByUser(types.UserID(user.ID))
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		response := nodesToProto(api.h.state, isLikelyConnected, nodes)
 | 
						|
		return &v1.ListNodesResponse{Nodes: response}, nil
 | 
						|
	}
 | 
						|
 | 
						|
	nodes, err := api.h.state.ListNodes()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(nodes, func(i, j int) bool {
 | 
						|
		return nodes[i].ID < nodes[j].ID
 | 
						|
	})
 | 
						|
 | 
						|
	response := nodesToProto(api.h.state, isLikelyConnected, nodes)
 | 
						|
	return &v1.ListNodesResponse{Nodes: response}, nil
 | 
						|
}
 | 
						|
 | 
						|
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
 | 
						|
	response := make([]*v1.Node, len(nodes))
 | 
						|
	for index, node := range nodes {
 | 
						|
		resp := node.Proto()
 | 
						|
 | 
						|
		// Populate the online field based on
 | 
						|
		// currently connected nodes.
 | 
						|
		if val, ok := isLikelyConnected.Load(node.ID); ok && val {
 | 
						|
			resp.Online = true
 | 
						|
		}
 | 
						|
 | 
						|
		var tags []string
 | 
						|
		for _, tag := range node.RequestTags() {
 | 
						|
			if state.NodeCanHaveTag(node.View(), tag) {
 | 
						|
				tags = append(tags, tag)
 | 
						|
			}
 | 
						|
		}
 | 
						|
		resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
 | 
						|
		resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID), node.ExitRoutes()...))
 | 
						|
		response[index] = resp
 | 
						|
	}
 | 
						|
 | 
						|
	return response
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) MoveNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.MoveNodeRequest,
 | 
						|
) (*v1.MoveNodeResponse, error) {
 | 
						|
	node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Send policy update notifications if needed
 | 
						|
	if policyChanged {
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyByNodeID(
 | 
						|
		ctx,
 | 
						|
		types.UpdateSelf(node.ID),
 | 
						|
		node.ID)
 | 
						|
	ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
 | 
						|
	api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
 | 
						|
 | 
						|
	return &v1.MoveNodeResponse{Node: node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) BackfillNodeIPs(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.BackfillNodeIPsRequest,
 | 
						|
) (*v1.BackfillNodeIPsResponse, error) {
 | 
						|
	log.Trace().Msg("Backfill called")
 | 
						|
 | 
						|
	if !request.Confirmed {
 | 
						|
		return nil, errors.New("not confirmed, aborting")
 | 
						|
	}
 | 
						|
 | 
						|
	changes, err := api.h.state.BackfillNodeIPs()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.BackfillNodeIPsResponse{Changes: changes}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) CreateApiKey(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.CreateApiKeyRequest,
 | 
						|
) (*v1.CreateApiKeyResponse, error) {
 | 
						|
	var expiration time.Time
 | 
						|
	if request.GetExpiration() != nil {
 | 
						|
		expiration = request.GetExpiration().AsTime()
 | 
						|
	}
 | 
						|
 | 
						|
	apiKey, _, err := api.h.state.CreateAPIKey(&expiration)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ExpireApiKey(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ExpireApiKeyRequest,
 | 
						|
) (*v1.ExpireApiKeyResponse, error) {
 | 
						|
	var apiKey *types.APIKey
 | 
						|
	var err error
 | 
						|
 | 
						|
	apiKey, err = api.h.state.GetAPIKey(request.Prefix)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	err = api.h.state.ExpireAPIKey(apiKey)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.ExpireApiKeyResponse{}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) ListApiKeys(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.ListApiKeysRequest,
 | 
						|
) (*v1.ListApiKeysResponse, error) {
 | 
						|
	apiKeys, err := api.h.state.ListAPIKeys()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	response := make([]*v1.ApiKey, len(apiKeys))
 | 
						|
	for index, key := range apiKeys {
 | 
						|
		response[index] = key.Proto()
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(response, func(i, j int) bool {
 | 
						|
		return response[i].Id < response[j].Id
 | 
						|
	})
 | 
						|
 | 
						|
	return &v1.ListApiKeysResponse{ApiKeys: response}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) DeleteApiKey(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.DeleteApiKeyRequest,
 | 
						|
) (*v1.DeleteApiKeyResponse, error) {
 | 
						|
	var (
 | 
						|
		apiKey *types.APIKey
 | 
						|
		err    error
 | 
						|
	)
 | 
						|
 | 
						|
	apiKey, err = api.h.state.GetAPIKey(request.Prefix)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := api.h.state.DestroyAPIKey(*apiKey); err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &v1.DeleteApiKeyResponse{}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) GetPolicy(
 | 
						|
	_ context.Context,
 | 
						|
	_ *v1.GetPolicyRequest,
 | 
						|
) (*v1.GetPolicyResponse, error) {
 | 
						|
	switch api.h.cfg.Policy.Mode {
 | 
						|
	case types.PolicyModeDB:
 | 
						|
		p, err := api.h.state.GetPolicy()
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("loading ACL from database: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		return &v1.GetPolicyResponse{
 | 
						|
			Policy:    p.Data,
 | 
						|
			UpdatedAt: timestamppb.New(p.UpdatedAt),
 | 
						|
		}, nil
 | 
						|
	case types.PolicyModeFile:
 | 
						|
		// Read the file and return the contents as-is.
 | 
						|
		absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path)
 | 
						|
		f, err := os.Open(absPath)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err)
 | 
						|
		}
 | 
						|
 | 
						|
		defer f.Close()
 | 
						|
 | 
						|
		b, err := io.ReadAll(f)
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("reading policy from file: %w", err)
 | 
						|
		}
 | 
						|
 | 
						|
		return &v1.GetPolicyResponse{Policy: string(b)}, nil
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode)
 | 
						|
}
 | 
						|
 | 
						|
func (api headscaleV1APIServer) SetPolicy(
 | 
						|
	_ context.Context,
 | 
						|
	request *v1.SetPolicyRequest,
 | 
						|
) (*v1.SetPolicyResponse, error) {
 | 
						|
	if api.h.cfg.Policy.Mode != types.PolicyModeDB {
 | 
						|
		return nil, types.ErrPolicyUpdateIsDisabled
 | 
						|
	}
 | 
						|
 | 
						|
	p := request.GetPolicy()
 | 
						|
 | 
						|
	// Validate and reject configuration that would error when applied
 | 
						|
	// when creating a map response. This requires nodes, so there is still
 | 
						|
	// a scenario where they might be allowed if the server has no nodes
 | 
						|
	// yet, but it should help for the general case and for hot reloading
 | 
						|
	// configurations.
 | 
						|
	nodes, err := api.h.state.ListNodes()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
 | 
						|
	}
 | 
						|
	changed, err := api.h.state.SetPolicy([]byte(p))
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("setting policy: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if len(nodes) > 0 {
 | 
						|
		_, err = api.h.state.SSHPolicy(nodes[0].View())
 | 
						|
		if err != nil {
 | 
						|
			return nil, fmt.Errorf("verifying SSH rules: %w", err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	updated, err := api.h.state.SetPolicyInDB(p)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	// Only send update if the packet filter has changed.
 | 
						|
	if changed {
 | 
						|
		err = api.h.state.AutoApproveNodes()
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 | 
						|
		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
 | 
						|
	}
 | 
						|
 | 
						|
	response := &v1.SetPolicyResponse{
 | 
						|
		Policy:    updated.Data,
 | 
						|
		UpdatedAt: timestamppb.New(updated.UpdatedAt),
 | 
						|
	}
 | 
						|
 | 
						|
	return response, nil
 | 
						|
}
 | 
						|
 | 
						|
// The following service calls are for testing and debugging
 | 
						|
func (api headscaleV1APIServer) DebugCreateNode(
 | 
						|
	ctx context.Context,
 | 
						|
	request *v1.DebugCreateNodeRequest,
 | 
						|
) (*v1.DebugCreateNodeResponse, error) {
 | 
						|
	user, err := api.h.state.GetUserByName(request.GetUser())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	routes, err := util.StringToIPPrefix(request.GetRoutes())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	log.Trace().
 | 
						|
		Caller().
 | 
						|
		Interface("route-prefix", routes).
 | 
						|
		Interface("route-str", request.GetRoutes()).
 | 
						|
		Msg("")
 | 
						|
 | 
						|
	hostinfo := tailcfg.Hostinfo{
 | 
						|
		RoutableIPs: routes,
 | 
						|
		OS:          "TestOS",
 | 
						|
		Hostname:    "DebugTestNode",
 | 
						|
	}
 | 
						|
 | 
						|
	registrationId, err := types.RegistrationIDFromString(request.GetKey())
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	newNode := types.RegisterNode{
 | 
						|
		Node: types.Node{
 | 
						|
			NodeKey:    key.NewNode().Public(),
 | 
						|
			MachineKey: key.NewMachine().Public(),
 | 
						|
			Hostname:   request.GetName(),
 | 
						|
			User:       *user,
 | 
						|
 | 
						|
			Expiry:   &time.Time{},
 | 
						|
			LastSeen: &time.Time{},
 | 
						|
 | 
						|
			Hostinfo: &hostinfo,
 | 
						|
		},
 | 
						|
		Registered: make(chan *types.Node),
 | 
						|
	}
 | 
						|
 | 
						|
	log.Debug().
 | 
						|
		Str("registration_id", registrationId.String()).
 | 
						|
		Msg("adding debug machine via CLI, appending to registration cache")
 | 
						|
 | 
						|
	api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
 | 
						|
 | 
						|
	return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
 | 
						|
}
 | 
						|
 | 
						|
// mustEmbedUnimplementedHeadscaleServiceServer is an intentionally empty
// marker method declared on headscaleV1APIServer.
// NOTE(review): presumably required by the generated
// v1.HeadscaleServiceServer interface — confirm against the generated code.
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
 |