1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-07 20:04:00 +01:00
This commit is contained in:
Kristoffer Dalby 2026-02-20 10:53:37 +00:00 committed by GitHub
commit d37106fe80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
51 changed files with 2520 additions and 650 deletions

View File

@ -253,6 +253,10 @@ jobs:
- TestSSHIsBlockedInACL
- TestSSHUserOnlyIsolation
- TestSSHAutogroupSelf
- TestSSHOneUserToOneCheckModeCLI
- TestSSHOneUserToOneCheckModeOIDC
- TestSSHCheckModeUnapprovedTimeout
- TestSSHCheckModeCheckPeriodCLI
- TestTagsAuthKeyWithTagRequestDifferentTag
- TestTagsAuthKeyWithTagNoAdvertiseFlag
- TestTagsAuthKeyWithTagCannotAddViaCLI

View File

@ -48,7 +48,7 @@ repos:
# golangci-lint for Go code quality
- id: golangci-lint
name: golangci-lint
entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
entry: nix develop --command -- golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
language: system
types: [go]
pass_filenames: false

View File

@ -11,6 +11,19 @@ to understand how the packet filter should be generated. We discovered a few dif
overall our implementation was very close.
[#3036](https://github.com/juanfont/headscale/pull/3036)
### SSH check action
SSH rules with `"action": "check"` are now supported. When a client initiates an SSH connection to a node
with a `check` action policy, the user is prompted to authenticate via OIDC or CLI approval before access
is granted.
A new `headscale auth` CLI command group supports the approval flow:
- `headscale auth approve --auth-id <id>` approves a pending authentication request (SSH check or web auth)
- `headscale auth register --auth-id <id> --user <user>` registers a node (replaces deprecated `headscale nodes register`)
[#1850](https://github.com/juanfont/headscale/pull/1850)
### BREAKING
- **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036)
@ -26,6 +39,8 @@ overall our implementation was very close.
- **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
- Previously, `proto:icmp` included both ICMPv4 and ICMPv6
- Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6
- **CLI**: `headscale nodes register` is deprecated in favour of `headscale auth register --auth-id <id> --user <user>` [#1850](https://github.com/juanfont/headscale/pull/1850)
- The old command continues to work but will be removed in a future release
### Changes
@ -35,6 +50,11 @@ overall our implementation was very close.
- **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036)
- Remove deprecated `--namespace` flag from `nodes list`, `nodes register`, and `debug create-node` commands (use `--user` instead) [#3093](https://github.com/juanfont/headscale/pull/3093)
- Remove deprecated `namespace`/`ns` command aliases for `users` and `machine`/`machines` aliases for `nodes` [#3093](https://github.com/juanfont/headscale/pull/3093)
- Add SSH `check` action support with OIDC and CLI-based approval flows [#1850](https://github.com/juanfont/headscale/pull/1850)
- Add `headscale auth register` and `headscale auth approve` CLI commands [#1850](https://github.com/juanfont/headscale/pull/1850)
- Deprecate `headscale nodes register --key` in favour of `headscale auth register --auth-id` [#1850](https://github.com/juanfont/headscale/pull/1850)
- Generalise auth templates into reusable `AuthSuccess` and `AuthWeb` components [#1850](https://github.com/juanfont/headscale/pull/1850)
- Unify auth pipeline with `AuthVerdict` type, supporting registration, reauthentication, and SSH checks [#1850](https://github.com/juanfont/headscale/pull/1850)
## 0.28.0 (2026-02-04)

70
cmd/headscale/cli/auth.go Normal file
View File

@ -0,0 +1,70 @@
package cli
import (
"context"
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
)
func init() {
rootCmd.AddCommand(authCmd)
authRegisterCmd.Flags().StringP("user", "u", "", "User")
authRegisterCmd.Flags().String("auth-id", "", "Auth ID")
mustMarkRequired(authRegisterCmd, "user", "auth-id")
authCmd.AddCommand(authRegisterCmd)
authApproveCmd.Flags().String("auth-id", "", "Auth ID")
mustMarkRequired(authApproveCmd, "auth-id")
authCmd.AddCommand(authApproveCmd)
}
var authCmd = &cobra.Command{
Use: "auth",
Short: "Manage node authentication and approval",
}
var authRegisterCmd = &cobra.Command{
Use: "register",
Short: "Register a node to your network",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
user, _ := cmd.Flags().GetString("user")
authID, _ := cmd.Flags().GetString("auth-id")
request := &v1.AuthRegisterRequest{
AuthId: authID,
User: user,
}
response, err := client.AuthRegister(ctx, request)
if err != nil {
return fmt.Errorf("registering node: %w", err)
}
return printOutput(
cmd,
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()))
}),
}
var authApproveCmd = &cobra.Command{
Use: "approve",
Short: "Approve a pending authentication request",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
authID, _ := cmd.Flags().GetString("auth-id")
request := &v1.AuthApproveRequest{
AuthId: authID,
}
response, err := client.AuthApprove(ctx, request)
if err != nil {
return fmt.Errorf("approving auth request: %w", err)
}
return printOutput(cmd, response, "Auth request approved")
}),
}

View File

@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{
name, _ := cmd.Flags().GetString("name")
registrationID, _ := cmd.Flags().GetString("key")
_, err := types.RegistrationIDFromString(registrationID)
_, err := types.AuthIDFromString(registrationID)
if err != nil {
return fmt.Errorf("parsing machine key: %w", err)
}

View File

@ -63,8 +63,9 @@ var nodeCmd = &cobra.Command{
}
var registerNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Use: "register",
Short: "Registers a node to your network",
Deprecated: "use 'headscale auth register --auth-id <id> --user <user>' instead",
RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error {
user, _ := cmd.Flags().GetString("user")
registrationID, _ := cmd.Flags().GetString("key")

View File

@ -27,7 +27,7 @@
let
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
buildGo = pkgs.buildGo126Module;
vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU=";
in
{
headscale = buildGo {
@ -135,11 +135,6 @@
};
};
# The package uses buildGo125Module, not the convention.
# goreleaser = prev.goreleaser.override {
# buildGoModule = buildGo;
# };
gotestsum = prev.gotestsum.override {
buildGoModule = buildGo;
};
@ -152,9 +147,9 @@
buildGoModule = buildGo;
};
# gopls = prev.gopls.override {
# buildGoModule = buildGo;
# };
gopls = prev.gopls.override {
buildGoLatestModule = buildGo;
};
};
}
// flake-utils.lib.eachDefaultSystem

View File

@ -0,0 +1,266 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc (unknown)
// source: headscale/v1/auth.proto
package v1
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type AuthRegisterRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"`
AuthId string `protobuf:"bytes,2,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRegisterRequest) Reset() {
*x = AuthRegisterRequest{}
mi := &file_headscale_v1_auth_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRegisterRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRegisterRequest) ProtoMessage() {}
func (x *AuthRegisterRequest) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRegisterRequest.ProtoReflect.Descriptor instead.
func (*AuthRegisterRequest) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{0}
}
func (x *AuthRegisterRequest) GetUser() string {
if x != nil {
return x.User
}
return ""
}
func (x *AuthRegisterRequest) GetAuthId() string {
if x != nil {
return x.AuthId
}
return ""
}
type AuthRegisterResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthRegisterResponse) Reset() {
*x = AuthRegisterResponse{}
mi := &file_headscale_v1_auth_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthRegisterResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthRegisterResponse) ProtoMessage() {}
func (x *AuthRegisterResponse) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthRegisterResponse.ProtoReflect.Descriptor instead.
func (*AuthRegisterResponse) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{1}
}
func (x *AuthRegisterResponse) GetNode() *Node {
if x != nil {
return x.Node
}
return nil
}
type AuthApproveRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
AuthId string `protobuf:"bytes,1,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthApproveRequest) Reset() {
*x = AuthApproveRequest{}
mi := &file_headscale_v1_auth_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthApproveRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthApproveRequest) ProtoMessage() {}
func (x *AuthApproveRequest) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthApproveRequest.ProtoReflect.Descriptor instead.
func (*AuthApproveRequest) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{2}
}
func (x *AuthApproveRequest) GetAuthId() string {
if x != nil {
return x.AuthId
}
return ""
}
type AuthApproveResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthApproveResponse) Reset() {
*x = AuthApproveResponse{}
mi := &file_headscale_v1_auth_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthApproveResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthApproveResponse) ProtoMessage() {}
func (x *AuthApproveResponse) ProtoReflect() protoreflect.Message {
mi := &file_headscale_v1_auth_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthApproveResponse.ProtoReflect.Descriptor instead.
func (*AuthApproveResponse) Descriptor() ([]byte, []int) {
return file_headscale_v1_auth_proto_rawDescGZIP(), []int{3}
}
var File_headscale_v1_auth_proto protoreflect.FileDescriptor
const file_headscale_v1_auth_proto_rawDesc = "" +
"\n" +
"\x17headscale/v1/auth.proto\x12\fheadscale.v1\x1a\x17headscale/v1/node.proto\"B\n" +
"\x13AuthRegisterRequest\x12\x12\n" +
"\x04user\x18\x01 \x01(\tR\x04user\x12\x17\n" +
"\aauth_id\x18\x02 \x01(\tR\x06authId\">\n" +
"\x14AuthRegisterResponse\x12&\n" +
"\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"-\n" +
"\x12AuthApproveRequest\x12\x17\n" +
"\aauth_id\x18\x01 \x01(\tR\x06authId\"\x15\n" +
"\x13AuthApproveResponseB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3"
var (
file_headscale_v1_auth_proto_rawDescOnce sync.Once
file_headscale_v1_auth_proto_rawDescData []byte
)
func file_headscale_v1_auth_proto_rawDescGZIP() []byte {
file_headscale_v1_auth_proto_rawDescOnce.Do(func() {
file_headscale_v1_auth_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)))
})
return file_headscale_v1_auth_proto_rawDescData
}
var file_headscale_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_headscale_v1_auth_proto_goTypes = []any{
(*AuthRegisterRequest)(nil), // 0: headscale.v1.AuthRegisterRequest
(*AuthRegisterResponse)(nil), // 1: headscale.v1.AuthRegisterResponse
(*AuthApproveRequest)(nil), // 2: headscale.v1.AuthApproveRequest
(*AuthApproveResponse)(nil), // 3: headscale.v1.AuthApproveResponse
(*Node)(nil), // 4: headscale.v1.Node
}
var file_headscale_v1_auth_proto_depIdxs = []int32{
4, // 0: headscale.v1.AuthRegisterResponse.node:type_name -> headscale.v1.Node
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_headscale_v1_auth_proto_init() }
func file_headscale_v1_auth_proto_init() {
if File_headscale_v1_auth_proto != nil {
return
}
file_headscale_v1_node_proto_init()
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)),
NumEnums: 0,
NumMessages: 4,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_headscale_v1_auth_proto_goTypes,
DependencyIndexes: file_headscale_v1_auth_proto_depIdxs,
MessageInfos: file_headscale_v1_auth_proto_msgTypes,
}.Build()
File_headscale_v1_auth_proto = out.File
file_headscale_v1_auth_proto_goTypes = nil
file_headscale_v1_auth_proto_depIdxs = nil
}

View File

@ -106,10 +106,10 @@ var File_headscale_v1_headscale_proto protoreflect.FileDescriptor
const file_headscale_v1_headscale_proto_rawDesc = "" +
"\n" +
"\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" +
"\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x17headscale/v1/auth.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" +
"\rHealthRequest\"E\n" +
"\x0eHealthResponse\x123\n" +
"\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" +
"\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\xfa\x18\n" +
"\x10HeadscaleService\x12h\n" +
"\n" +
"CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" +
@ -134,7 +134,9 @@ const file_headscale_v1_headscale_proto_rawDesc = "" +
"\n" +
"RenameNode\x12\x1f.headscale.v1.RenameNodeRequest\x1a .headscale.v1.RenameNodeResponse\"0\x82\xd3\xe4\x93\x02*\"(/api/v1/node/{node_id}/rename/{new_name}\x12b\n" +
"\tListNodes\x12\x1e.headscale.v1.ListNodesRequest\x1a\x1f.headscale.v1.ListNodesResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/node\x12\x80\x01\n" +
"\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12p\n" +
"\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12w\n" +
"\fAuthRegister\x12!.headscale.v1.AuthRegisterRequest\x1a\".headscale.v1.AuthRegisterResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/auth/register\x12s\n" +
"\vAuthApprove\x12 .headscale.v1.AuthApproveRequest\x1a!.headscale.v1.AuthApproveResponse\"\x1f\x82\xd3\xe4\x93\x02\x19:\x01*\"\x14/api/v1/auth/approve\x12p\n" +
"\fCreateApiKey\x12!.headscale.v1.CreateApiKeyRequest\x1a\".headscale.v1.CreateApiKeyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\"\x0e/api/v1/apikey\x12w\n" +
"\fExpireApiKey\x12!.headscale.v1.ExpireApiKeyRequest\x1a\".headscale.v1.ExpireApiKeyResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/apikey/expire\x12j\n" +
"\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" +
@ -177,36 +179,40 @@ var file_headscale_v1_headscale_proto_goTypes = []any{
(*RenameNodeRequest)(nil), // 17: headscale.v1.RenameNodeRequest
(*ListNodesRequest)(nil), // 18: headscale.v1.ListNodesRequest
(*BackfillNodeIPsRequest)(nil), // 19: headscale.v1.BackfillNodeIPsRequest
(*CreateApiKeyRequest)(nil), // 20: headscale.v1.CreateApiKeyRequest
(*ExpireApiKeyRequest)(nil), // 21: headscale.v1.ExpireApiKeyRequest
(*ListApiKeysRequest)(nil), // 22: headscale.v1.ListApiKeysRequest
(*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest
(*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest
(*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest
(*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse
(*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse
(*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse
(*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse
(*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse
(*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse
(*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse
(*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse
(*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse
(*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse
(*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse
(*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse
(*RegisterNodeResponse)(nil), // 38: headscale.v1.RegisterNodeResponse
(*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse
(*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse
(*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse
(*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse
(*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse
(*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse
(*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse
(*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse
(*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse
(*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse
(*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse
(*AuthRegisterRequest)(nil), // 20: headscale.v1.AuthRegisterRequest
(*AuthApproveRequest)(nil), // 21: headscale.v1.AuthApproveRequest
(*CreateApiKeyRequest)(nil), // 22: headscale.v1.CreateApiKeyRequest
(*ExpireApiKeyRequest)(nil), // 23: headscale.v1.ExpireApiKeyRequest
(*ListApiKeysRequest)(nil), // 24: headscale.v1.ListApiKeysRequest
(*DeleteApiKeyRequest)(nil), // 25: headscale.v1.DeleteApiKeyRequest
(*GetPolicyRequest)(nil), // 26: headscale.v1.GetPolicyRequest
(*SetPolicyRequest)(nil), // 27: headscale.v1.SetPolicyRequest
(*CreateUserResponse)(nil), // 28: headscale.v1.CreateUserResponse
(*RenameUserResponse)(nil), // 29: headscale.v1.RenameUserResponse
(*DeleteUserResponse)(nil), // 30: headscale.v1.DeleteUserResponse
(*ListUsersResponse)(nil), // 31: headscale.v1.ListUsersResponse
(*CreatePreAuthKeyResponse)(nil), // 32: headscale.v1.CreatePreAuthKeyResponse
(*ExpirePreAuthKeyResponse)(nil), // 33: headscale.v1.ExpirePreAuthKeyResponse
(*DeletePreAuthKeyResponse)(nil), // 34: headscale.v1.DeletePreAuthKeyResponse
(*ListPreAuthKeysResponse)(nil), // 35: headscale.v1.ListPreAuthKeysResponse
(*DebugCreateNodeResponse)(nil), // 36: headscale.v1.DebugCreateNodeResponse
(*GetNodeResponse)(nil), // 37: headscale.v1.GetNodeResponse
(*SetTagsResponse)(nil), // 38: headscale.v1.SetTagsResponse
(*SetApprovedRoutesResponse)(nil), // 39: headscale.v1.SetApprovedRoutesResponse
(*RegisterNodeResponse)(nil), // 40: headscale.v1.RegisterNodeResponse
(*DeleteNodeResponse)(nil), // 41: headscale.v1.DeleteNodeResponse
(*ExpireNodeResponse)(nil), // 42: headscale.v1.ExpireNodeResponse
(*RenameNodeResponse)(nil), // 43: headscale.v1.RenameNodeResponse
(*ListNodesResponse)(nil), // 44: headscale.v1.ListNodesResponse
(*BackfillNodeIPsResponse)(nil), // 45: headscale.v1.BackfillNodeIPsResponse
(*AuthRegisterResponse)(nil), // 46: headscale.v1.AuthRegisterResponse
(*AuthApproveResponse)(nil), // 47: headscale.v1.AuthApproveResponse
(*CreateApiKeyResponse)(nil), // 48: headscale.v1.CreateApiKeyResponse
(*ExpireApiKeyResponse)(nil), // 49: headscale.v1.ExpireApiKeyResponse
(*ListApiKeysResponse)(nil), // 50: headscale.v1.ListApiKeysResponse
(*DeleteApiKeyResponse)(nil), // 51: headscale.v1.DeleteApiKeyResponse
(*GetPolicyResponse)(nil), // 52: headscale.v1.GetPolicyResponse
(*SetPolicyResponse)(nil), // 53: headscale.v1.SetPolicyResponse
}
var file_headscale_v1_headscale_proto_depIdxs = []int32{
2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest
@ -227,40 +233,44 @@ var file_headscale_v1_headscale_proto_depIdxs = []int32{
17, // 15: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest
18, // 16: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest
19, // 17: headscale.v1.HeadscaleService.BackfillNodeIPs:input_type -> headscale.v1.BackfillNodeIPsRequest
20, // 18: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest
21, // 19: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest
22, // 20: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest
23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest
24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest
25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest
0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest
26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse
27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse
28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse
29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse
30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse
31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse
32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse
33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse
34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse
35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse
36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse
37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse
38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse
39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse
40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse
41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse
42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse
43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse
44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse
45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse
46, // 45: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse
47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse
48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse
49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse
1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse
25, // [25:50] is the sub-list for method output_type
0, // [0:25] is the sub-list for method input_type
20, // 18: headscale.v1.HeadscaleService.AuthRegister:input_type -> headscale.v1.AuthRegisterRequest
21, // 19: headscale.v1.HeadscaleService.AuthApprove:input_type -> headscale.v1.AuthApproveRequest
22, // 20: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest
23, // 21: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest
24, // 22: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest
25, // 23: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest
26, // 24: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest
27, // 25: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest
0, // 26: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest
28, // 27: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse
29, // 28: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse
30, // 29: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse
31, // 30: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse
32, // 31: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse
33, // 32: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse
34, // 33: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse
35, // 34: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse
36, // 35: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse
37, // 36: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse
38, // 37: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse
39, // 38: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse
40, // 39: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse
41, // 40: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse
42, // 41: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse
43, // 42: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse
44, // 43: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse
45, // 44: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse
46, // 45: headscale.v1.HeadscaleService.AuthRegister:output_type -> headscale.v1.AuthRegisterResponse
47, // 46: headscale.v1.HeadscaleService.AuthApprove:output_type -> headscale.v1.AuthApproveResponse
48, // 47: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse
49, // 48: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse
50, // 49: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse
51, // 50: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse
52, // 51: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse
53, // 52: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse
1, // 53: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse
27, // [27:54] is the sub-list for method output_type
0, // [0:27] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -275,6 +285,7 @@ func file_headscale_v1_headscale_proto_init() {
file_headscale_v1_preauthkey_proto_init()
file_headscale_v1_node_proto_init()
file_headscale_v1_apikey_proto_init()
file_headscale_v1_auth_proto_init()
file_headscale_v1_policy_proto_init()
type x struct{}
out := protoimpl.TypeBuilder{

View File

@ -709,6 +709,60 @@ func local_request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marsh
return msg, metadata, err
}
func request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRegisterRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.AuthRegister(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthRegisterRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.AuthRegister(ctx, &protoReq)
return msg, metadata, err
}
func request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthApproveRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.AuthApprove(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq AuthApproveRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.AuthApprove(ctx, &protoReq)
return msg, metadata, err
}
func request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq CreateApiKeyRequest
@ -1272,6 +1326,46 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser
}
forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@ -1758,6 +1852,40 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser
}
forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
@ -1899,6 +2027,8 @@ var (
pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, ""))
pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, ""))
pattern_HeadscaleService_BackfillNodeIPs_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "backfillips"}, ""))
pattern_HeadscaleService_AuthRegister_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "register"}, ""))
pattern_HeadscaleService_AuthApprove_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "approve"}, ""))
pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, ""))
pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, ""))
pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, ""))
@ -1927,6 +2057,8 @@ var (
forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_BackfillNodeIPs_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_AuthRegister_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_AuthApprove_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage
forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage

View File

@ -37,6 +37,8 @@ const (
HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode"
HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes"
HeadscaleService_BackfillNodeIPs_FullMethodName = "/headscale.v1.HeadscaleService/BackfillNodeIPs"
HeadscaleService_AuthRegister_FullMethodName = "/headscale.v1.HeadscaleService/AuthRegister"
HeadscaleService_AuthApprove_FullMethodName = "/headscale.v1.HeadscaleService/AuthApprove"
HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey"
HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey"
HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys"
@ -71,6 +73,9 @@ type HeadscaleServiceClient interface {
RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error)
ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error)
BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error)
// --- Auth start ---
AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error)
AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error)
// --- ApiKeys start ---
CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error)
ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error)
@ -271,6 +276,26 @@ func (c *headscaleServiceClient) BackfillNodeIPs(ctx context.Context, in *Backfi
return out, nil
}
func (c *headscaleServiceClient) AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthRegisterResponse)
err := c.cc.Invoke(ctx, HeadscaleService_AuthRegister_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *headscaleServiceClient) AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthApproveResponse)
err := c.cc.Invoke(ctx, HeadscaleService_AuthApprove_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(CreateApiKeyResponse)
@ -366,6 +391,9 @@ type HeadscaleServiceServer interface {
RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error)
ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error)
BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error)
// --- Auth start ---
AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error)
AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error)
// --- ApiKeys start ---
CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error)
ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error)
@ -440,6 +468,12 @@ func (UnimplementedHeadscaleServiceServer) ListNodes(context.Context, *ListNodes
func (UnimplementedHeadscaleServiceServer) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) {
return nil, status.Error(codes.Unimplemented, "method BackfillNodeIPs not implemented")
}
func (UnimplementedHeadscaleServiceServer) AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AuthRegister not implemented")
}
func (UnimplementedHeadscaleServiceServer) AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AuthApprove not implemented")
}
func (UnimplementedHeadscaleServiceServer) CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) {
return nil, status.Error(codes.Unimplemented, "method CreateApiKey not implemented")
}
@ -806,6 +840,42 @@ func _HeadscaleService_BackfillNodeIPs_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_AuthRegister_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthRegisterRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HeadscaleServiceServer).AuthRegister(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HeadscaleService_AuthRegister_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HeadscaleServiceServer).AuthRegister(ctx, req.(*AuthRegisterRequest))
}
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_AuthApprove_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthApproveRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(HeadscaleServiceServer).AuthApprove(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: HeadscaleService_AuthApprove_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(HeadscaleServiceServer).AuthApprove(ctx, req.(*AuthApproveRequest))
}
return interceptor(ctx, in, info, handler)
}
func _HeadscaleService_CreateApiKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CreateApiKeyRequest)
if err := dec(in); err != nil {
@ -1011,6 +1081,14 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{
MethodName: "BackfillNodeIPs",
Handler: _HeadscaleService_BackfillNodeIPs_Handler,
},
{
MethodName: "AuthRegister",
Handler: _HeadscaleService_AuthRegister_Handler,
},
{
MethodName: "AuthApprove",
Handler: _HeadscaleService_AuthApprove_Handler,
},
{
MethodName: "CreateApiKey",
Handler: _HeadscaleService_CreateApiKey_Handler,

View File

@ -0,0 +1,44 @@
{
"swagger": "2.0",
"info": {
"title": "headscale/v1/auth.proto",
"version": "version not set"
},
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"paths": {},
"definitions": {
"protobufAny": {
"type": "object",
"properties": {
"@type": {
"type": "string"
}
},
"additionalProperties": {}
},
"rpcStatus": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"format": "int32"
},
"message": {
"type": "string"
},
"details": {
"type": "array",
"items": {
"type": "object",
"$ref": "#/definitions/protobufAny"
}
}
}
}
}
}

View File

@ -138,6 +138,71 @@
]
}
},
"/api/v1/auth/approve": {
"post": {
"operationId": "HeadscaleService_AuthApprove",
"responses": {
"200": {
"description": "A successful response.",
"schema": {
"$ref": "#/definitions/v1AuthApproveResponse"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/v1AuthApproveRequest"
}
}
],
"tags": [
"HeadscaleService"
]
}
},
"/api/v1/auth/register": {
"post": {
"summary": "--- Auth start ---",
"operationId": "HeadscaleService_AuthRegister",
"responses": {
"200": {
"description": "A successful response.",
"schema": {
"$ref": "#/definitions/v1AuthRegisterResponse"
}
},
"default": {
"description": "An unexpected error response.",
"schema": {
"$ref": "#/definitions/rpcStatus"
}
}
},
"parameters": [
{
"name": "body",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/v1AuthRegisterRequest"
}
}
],
"tags": [
"HeadscaleService"
]
}
},
"/api/v1/debug/node": {
"post": {
"summary": "--- Node start ---",
@ -888,6 +953,36 @@
}
}
},
"v1AuthApproveRequest": {
"type": "object",
"properties": {
"authId": {
"type": "string"
}
}
},
"v1AuthApproveResponse": {
"type": "object"
},
"v1AuthRegisterRequest": {
"type": "object",
"properties": {
"user": {
"type": "string"
},
"authId": {
"type": "string"
}
}
},
"v1AuthRegisterResponse": {
"type": "object",
"properties": {
"node": {
"$ref": "#/definitions/v1Node"
}
}
},
"v1BackfillNodeIPsResponse": {
"type": "object",
"properties": {

2
go.mod
View File

@ -14,6 +14,8 @@ require (
github.com/docker/docker v28.5.2+incompatible
github.com/fsnotify/fsnotify v1.9.0
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/metrics v0.1.1
github.com/go-gormigrate/gormigrate/v2 v2.1.5
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e
github.com/gofrs/uuid/v5 v5.4.0

4
go.sum
View File

@ -181,6 +181,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk=
github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4=
github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8=
github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M=
github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY=

View File

@ -20,7 +20,9 @@ import (
"github.com/cenkalti/backoff/v5"
"github.com/davecgh/go-spew/spew"
"github.com/gorilla/mux"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -457,50 +459,58 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
return os.Remove(h.cfg.UnixSocket)
}
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router := mux.NewRouter()
router.Use(prometheusMiddleware)
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).
Methods(http.MethodPost, http.MethodGet)
r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet)
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
Methods(http.MethodGet)
r.Get("/robots.txt", h.RobotsHandler)
r.Get("/health", h.HealthHandler)
r.Get("/version", h.VersionHandler)
r.Get("/key", h.KeyHandler)
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
}
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
r.Get("/apple", h.AppleConfigMessage)
r.Get("/apple/{platform}", h.ApplePlatformConfig)
r.Get("/windows", h.WindowsConfigMessage)
// TODO(kristoffer): move swagger into a package
router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
Methods(http.MethodGet)
r.Get("/swagger", headscale.SwaggerUI)
r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1)
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
r.Post("/verify", h.VerifyHandler)
if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
r.HandleFunc("/derp", h.DERPServer.DERPHandler)
r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler)
r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap()))
}
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(h.httpAuthenticationMiddleware)
apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP)
router.HandleFunc("/favicon.ico", FaviconHandler)
router.PathPrefix("/").HandlerFunc(BlankHandler)
r.Route("/api", func(r chi.Router) {
r.Use(h.httpAuthenticationMiddleware)
r.HandleFunc("/v1/*", grpcMux.ServeHTTP)
})
r.Get("/favicon.ico", FaviconHandler)
r.Get("/", BlankHandler)
return router
return r
}
// Serve launches the HTTP and gRPC server service Headscale and the API.

View File

@ -20,7 +20,9 @@ import (
type AuthProvider interface {
RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthURL(regID types.RegistrationID) string
AuthHandler(w http.ResponseWriter, r *http.Request)
RegisterURL(authID types.AuthID) string
AuthURL(authID types.AuthID) string
}
func (h *Headscale) handleRegister(
@ -261,22 +263,24 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
}
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
}
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
select {
case <-ctx.Done():
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered:
if node == nil {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
case verdict := <-reg.WaitForAuth():
if verdict.Accept() {
if !verdict.Node.Valid() {
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
return nodeToRegisterResponse(verdict.Node), nil
}
}
}
@ -291,14 +295,14 @@ func (h *Headscale) reqToNewRegisterResponse(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID()
newAuthID, err := types.NewAuthID()
if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
req.Hostinfo.View(),
machineKey.String(),
req.NodeKey.String(),
)
@ -307,25 +311,25 @@ func (h *Headscale) reqToNewRegisterResponse(
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
log.Info().Msgf("new followup node registration using key: %s", newRegID)
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID),
AuthURL: h.authProvider.RegisterURL(newAuthID),
}, nil
}
@ -376,13 +380,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
// Send both changes. Empty changes are ignored by Change().
h.Change(changed, routesChange)
// TODO(kradalby): I think this is covered above, but we need to validate that.
// // If policy changed due to node registration, send a separate policy change
// if policyChanged {
// policyChange := change.PolicyChange()
// h.Change(policyChange)
// }
resp := &tailcfg.RegisterResponse{
MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(),
@ -404,14 +401,14 @@ func (h *Headscale) handleRegisterInteractive(
req tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
registrationId, err := types.NewRegistrationID()
authID, err := types.NewAuthID()
if err != nil {
return nil, fmt.Errorf("generating registration ID: %w", err)
}
// Ensure we have a valid hostname
hostname := util.EnsureHostname(
req.Hostinfo,
req.Hostinfo.View(),
machineKey.String(),
req.NodeKey.String(),
)
@ -434,28 +431,28 @@ func (h *Headscale) handleRegisterInteractive(
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
nodeToRegister := types.Node{
Hostname: hostname,
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: new(time.Now()),
}
h.state.SetRegistrationCacheEntry(
registrationId,
nodeToRegister,
if !req.Expiry.IsZero() {
nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
h.state.SetAuthCacheEntry(
authID,
authRegReq,
)
log.Info().Msgf("starting node registration using key: %s", registrationId)
log.Info().Msgf("starting node registration using key: %s", authID)
return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(registrationId),
AuthURL: h.authProvider.RegisterURL(authID),
}, nil
}

View File

@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 1: Create user-owned node WITH expiry set
clientExpiry := time.Now().Add(24 * time.Hour)
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "personal-to-tagged",
@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry,
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 2: Re-auth with tags (Personal → Tagged conversion)
nodeKey2 := key.NewNode()
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "personal-to-tagged",
@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client still sends expiry
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",
@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Create tagged node (expiry should be nil)
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "tagged-to-personal",
@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
RequestTags: []string{"tag:server"}, // Tagged node
},
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth",
@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
nodeKey2 := key.NewNode()
clientExpiry := time.Now().Add(48 * time.Hour)
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(),
Hostname: "tagged-to-personal",
@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
},
Expiry: &clientExpiry, // Client requests expiry
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth",

View File

@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_success",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-success-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-success-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Simulate successful registration - send to buffered channel
// The channel is buffered (size 1), so this can complete immediately
// and handleRegister will receive the value when it starts waiting
// Simulate successful registration
// handleRegister will receive the value when it starts waiting
go func() {
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node
nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()})
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_timeout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "followup-timeout-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Don't send anything on channel - will timeout
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "followup-timeout-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Don't call FinishRegistration - will timeout
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
},
@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "followup_registration_node_nil_response",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
registered := make(chan *types.Node, 1)
nodeToRegister := types.RegisterNode{
Node: types.Node{
Hostname: "nil-response-node",
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
Hostname: "nil-response-node",
})
app.state.SetAuthCacheEntry(regID, nodeToRegister)
// Simulate registration that returns nil (cache expired during auth)
// The channel is buffered (size 1), so this can complete immediately
// Simulate registration that returns empty NodeView (cache expired during auth)
go func() {
registered <- nil // Nil indicates cache expiry
nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry
}()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
// Generate a registration ID that doesn't exist in cache
// This simulates an expired/missing cache entry
regID, err := types.NewRegistrationID()
regID, err := types.NewAuthID()
if err != nil {
return "", err
}
@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
// Extract and validate the new registration ID exists in cache
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
newRegID, err := types.AuthIDFromString(newRegIDStr)
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
// Verify new registration entry exists in cache
_, found := app.state.GetRegistrationCacheEntry(newRegID)
_, found := app.state.GetAuthCacheEntry(newRegID)
assert.True(t, found, "new registration should exist in cache")
},
},
@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify cache entry exists
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, found, "registration cache entry should exist initially")
assert.NotNil(t, cacheEntry)
@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
// Cache entry should still exist after auth error (for retry scenarios)
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
},
},
@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
// Both cache entries should exist simultaneously
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist")
@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err)
// Verify both exist
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
_, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist")
@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
}
// First registration should still be in cache (not completed)
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
_, stillFound := app.state.GetAuthCacheEntry(regID1)
assert.True(t, stillFound, "first registration should still be pending")
},
},
@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
var (
initialResp *tailcfg.RegisterResponse
authURL string
registrationID types.RegistrationID
registrationID types.AuthID
finalResp *tailcfg.RegisterResponse
err error
)
@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if step.expectCacheEntry {
// Verify registration cache entry was created
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
require.True(t, found, "registration cache entry should exist")
require.NotNil(t, cacheEntry, "cache entry should not be nil")
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
}
case stepTypeAuthCompletion:
@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
// Check cache cleanup expectation for this step
if step.expectCacheEntry == false && registrationID != "" {
// Verify cache entry was cleaned up
_, found := app.state.GetRegistrationCacheEntry(registrationID)
_, found := app.state.GetAuthCacheEntry(registrationID)
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
}
}
@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
}
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
// AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/"
@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
idStr := authURL[idx+len(registerPrefix):]
return types.RegistrationIDFromString(idStr)
return types.AuthIDFromString(idStr)
}
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
@ -2962,7 +2948,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Scenario:
// 1. Node registers with user1 via pre-auth key
// 2. Node logs out (expires)
// 3. Admin runs: headscale nodes register --user user2 --key <key>
// 3. Admin runs: headscale auth register --auth-id <id> --user user2
//
// Expected behavior:
// - User1's original node should STILL EXIST (expired)
@ -3041,7 +3027,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
require.NotEmpty(t, regID, "Should have valid registration ID")
// Step 4: Admin completes authentication via CLI
// This simulates: headscale nodes register --user user2 --key <key>
// This simulates: headscale auth register --auth-id <id> --user user2
node, _, err := app.state.HandleNodeFromAuthPath(
regID,
types.UserID(user2.ID), // Register to user2, not user1!
@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
nodeKey := key.NewNode()
// Simulate a registration cache entry (as would be created during web auth)
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "webauth-tags-node",
@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete the web auth - should fail because tag is unauthorized
_, _, err := app.state.HandleNodeFromAuthPath(
@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
nodeKey1 := key.NewNode()
// Step 1: Initial registration with tags
registrationID1 := types.MustRegistrationID()
regEntry1 := types.NewRegisterNode(types.Node{
registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(),
Hostname: "reauth-untag-node",
@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{"tag:valid-owned", "tag:second"},
},
})
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
// Complete initial registration with tags
node, _, err := app.state.HandleNodeFromAuthPath(
@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
// Step 2: Reauth with EMPTY tags to untag
nodeKey2 := key.NewNode() // New node key for reauth
registrationID2 := types.MustRegistrationID()
regEntry2 := types.NewRegisterNode(types.Node{
registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "reauth-untag-node",
@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
nodeKey2 := key.NewNode() // New node key for reauth
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "authkey-tagged-node",
@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@ -3956,10 +3942,10 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
require.NotNil(t, alice, "Alice user should be created")
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
// This is what happens when running: headscale nodes register --user alice --key ...
// This is what happens when running: headscale auth register --auth-id <id> --user alice
nodeKey2 := key.NewNode()
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
registrationID := types.MustAuthID()
regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key as the tagged node
NodeKey: nodeKey2.Public(),
Hostname: "tagged-orphan-node",
@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
RequestTags: []string{}, // Empty - transition to user-owned
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
app.state.SetAuthCacheEntry(registrationID, regEntry)
// This should NOT panic - before the fix, this would panic with:
// panic: runtime error: invalid memory address or nil pointer dereference

View File

@ -47,7 +47,7 @@ const (
type HSDatabase struct {
DB *gorm.DB
cfg *types.Config
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
}
// NewHeadscaleDatabase creates a new database connection and runs migrations.
@ -56,7 +56,7 @@ type HSDatabase struct {
//nolint:gocyclo // complex database initialization with many migrations
func NewHeadscaleDatabase(
cfg *types.Config,
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg.Database)
if err != nil {

View File

@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
}
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
}
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {

View File

@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
Str(zf.RegistrationKey, registrationKey).
Msg("registering node")
registrationId, err := types.RegistrationIDFromString(request.GetKey())
registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil {
return nil, err
}
@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: request.GetName(),
}
registrationId, err := types.RegistrationIDFromString(request.GetKey())
registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil {
return nil, err
}
newNode := types.NewRegisterNode(
types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
newNode := types.Node{
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: user,
Expiry: &time.Time{},
LastSeen: &time.Time{},
Expiry: &time.Time{},
LastSeen: &time.Time{},
Hostinfo: &hostinfo,
},
)
Hostinfo: &hostinfo,
}
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
authRegReq := types.NewRegisterAuthRequest(newNode)
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
}
func (api headscaleV1APIServer) Health(
@ -829,4 +828,38 @@ func (api headscaleV1APIServer) Health(
return response, healthErr
}
func (api headscaleV1APIServer) AuthRegister(
ctx context.Context,
request *v1.AuthRegisterRequest,
) (*v1.AuthRegisterResponse, error) {
resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{
Key: request.GetAuthId(),
User: request.GetUser(),
})
if err != nil {
return nil, err
}
return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil
}
func (api headscaleV1APIServer) AuthApprove(
ctx context.Context,
request *v1.AuthApproveRequest,
) (*v1.AuthApproveResponse, error) {
authID, err := types.AuthIDFromString(request.GetAuthId())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err)
}
authReq, ok := api.h.state.GetAuthCacheEntry(authID)
if !ok {
return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID)
}
authReq.FinishAuth(types.AuthVerdict{})
return &v1.AuthApproveResponse{}, nil
}
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

View File

@ -11,7 +11,6 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/assets"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@ -245,11 +244,58 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
registrationId.String())
authID.String())
}
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderWeb) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
authID, err := authIDFromRequest(req)
if err != nil {
httpError(writer, err)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write([]byte(templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id "+authID.String(),
).Render()))
if err != nil {
log.Error().Err(err).Msg("failed to write auth response")
}
}
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
raw, err := urlParam[string](req, "auth_id")
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.AuthIDFromString(raw)
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
return registrationId, nil
}
// RegisterHandler shows a simple message in the browser to point to the CLI
@ -261,22 +307,20 @@ func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
registrationId, err := authIDFromRequest(req)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
httpError(writer, err)
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
_, err = writer.Write([]byte(templates.AuthWeb(
"Node registration",
"Run the command below in the headscale server to add this node to your network:",
fmt.Sprintf("headscale auth register --auth-id %s --user USERNAME", registrationId.String()),
).Render()))
if err != nil {
log.Error().Err(err).Msg("failed to write register response")
}

View File

@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
}
// emptyCache creates an empty registration cache for testing.
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
}
// Test configuration constants.

View File

@ -7,8 +7,11 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"github.com/gorilla/mux"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/metrics"
"github.com/juanfont/headscale/hscontrol/capver"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
@ -22,6 +25,12 @@ import (
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
// ErrMissingURLParameter is returned when a required URL parameter is not provided.
var ErrMissingURLParameter = errors.New("missing URL parameter")
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
@ -69,7 +78,7 @@ func (h *Headscale) NoiseUpgradeHandler(
return
}
noiseServer := noiseServer{
ns := noiseServer{
headscale: h,
challenge: key.NewChallenge(),
}
@ -79,42 +88,89 @@ func (h *Headscale) NoiseUpgradeHandler(
writer,
req,
*h.noisePrivateKey,
noiseServer.earlyNoise,
ns.earlyNoise,
)
if err != nil {
httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
return
}
noiseServer.conn = noiseConn
noiseServer.machineKey = noiseServer.conn.Peer()
noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion()
ns.conn = noiseConn
ns.machineKey = ns.conn.Peer()
ns.protocolVersion = ns.conn.ProtocolVersion()
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter()
router.Use(prometheusMiddleware)
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
Methods(http.MethodPost)
r := chi.NewRouter()
r.Use(metrics.Collector(metrics.CollectorOpts{
Host: false,
Proto: true,
Skip: func(r *http.Request) bool {
return r.Method != http.MethodOptions
},
}))
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
// Endpoints outside of the register endpoint must use getAndValidateNode to
// get the node to ensure that the MachineKey matches the Node setting up the
// connection.
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
r.Handle("/metrics", metrics.Handler())
noiseServer.httpBaseConfig = &http.Server{
Handler: router,
r.Route("/machine", func(r chi.Router) {
r.Post("/register", ns.RegistrationHandler)
r.Post("/map", ns.PollNetMapHandler)
// SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected.
r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler)
// Not implemented yet
//
// /whoami is a debug endpoint to validate that the client can communicate over the connection,
// not clear if there is a specific response, it looks like it is just logged.
// https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138
r.Get("/whoami", ns.NotImplementedHandler)
// client sends a [tailcfg.SetDNSRequest] to this endpoints and expect
// the server to create or update this DNS record "somewhere".
// It is typically a TXT record for an ACME challenge.
r.Post("/set-dns", ns.NotImplementedHandler)
// A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes.
// We currently do not support device attributes.
r.Patch("/set-device-attr", ns.NotImplementedHandler)
// A [tailcfg.AuditLogRequest] to send audit log entries to the server.
// The server is expected to store them "somewhere".
// We currently do not support device attributes.
r.Post("/audit-log", ns.NotImplementedHandler)
// handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest].
r.Post("/id-token", ns.NotImplementedHandler)
// Asks the server if a feature is available and receive information about how to enable it.
// Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse].
r.Post("/feature/query", ns.NotImplementedHandler)
r.Post("/update-health", ns.NotImplementedHandler)
r.Route("/webclient", func(r chi.Router) {})
r.Post("/c2n", ns.NotImplementedHandler)
})
ns.httpBaseConfig = &http.Server{
Handler: r,
ReadHeaderTimeout: types.HTTPTimeout,
}
noiseServer.http2Server = &http2.Server{}
ns.http2Server = &http2.Server{}
noiseServer.http2Server.ServeConn(
ns.http2Server.ServeConn(
noiseConn,
&http2.ServeConnOpts{
BaseConfig: noiseServer.httpBaseConfig,
BaseConfig: ns.httpBaseConfig,
},
)
}
@ -189,7 +245,143 @@ func rejectUnsupported(
return false
}
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) {
d, _ := io.ReadAll(req.Body)
log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit")
http.Error(writer, "Not implemented yet", http.StatusNotImplemented)
}
func urlParam[T any](req *http.Request, key string) (T, error) {
var zero T
param := chi.URLParam(req, key)
if param == "" {
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
}
var value T
switch any(value).(type) {
case string:
v, ok := any(param).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
default:
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
return value, nil
}
// SSHActionHandler handles the /ssh-action endpoint, it returns a [tailcfg.SSHActionHandler]
// to the client with the verdict of an SSH access request.
func (ns *noiseServer) SSHActionHandler(writer http.ResponseWriter, req *http.Request) {
srcNodeID, _ := urlParam[types.NodeID](req, "src_node_id")
dstNodeID, _ := urlParam[types.NodeID](req, "dst_node_id")
sshUser := req.URL.Query().Get("ssh_user")
localUser := req.URL.Query().Get("local_user")
// Set if this is a follow up request.
authIDStr := req.URL.Query().Get("auth_id")
log.Trace().Caller().
Str("path", req.URL.String()).
Uint64("src_node_id", srcNodeID.Uint64()).
Uint64("dst_node_id", dstNodeID.Uint64()).
Str("ssh_user", sshUser).
Str("local_user", localUser).
Str("auth_id", authIDStr).
Msg("got SSH action request")
var action tailcfg.SSHAction
action.AllowAgentForwarding = true
action.AllowLocalPortForwarding = true
action.AllowRemotePortForwarding = true
if authIDStr == "" {
holdURL, err := url.Parse(ns.headscale.cfg.ServerURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER")
if err != nil {
log.Error().Caller().Err(err).Msg("failed to parse SSH action URL")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
authID, err := types.NewAuthID()
if err != nil {
log.Error().Caller().Err(err).Msg("failed to generate auth ID for SSH action")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest())
authURL := ns.headscale.authProvider.AuthURL(authID)
q := holdURL.Query()
q.Set("auth_id", authID.String())
holdURL.RawQuery = q.Encode()
action.HoldAndDelegate = holdURL.String()
// TODO(kradalby): here we can also send a very tiny mapresponse
// "popping" the url and opening it for the user.
action.Message = fmt.Sprintf(`# Headscale SSH requires an additional check.
# To authenticate, visit: %s
# Authentication checked with Headscale SSH.
`, authURL)
} else {
authID, err := types.AuthIDFromString(authIDStr)
if err != nil {
log.Error().Caller().Err(err).Str("auth_id", authIDStr).Msg("invalid auth_id in SSH action request")
http.Error(writer, "Invalid auth_id", http.StatusBadRequest)
return
}
log.Trace().Caller().Str("auth_id", authID.String()).Msg("SSH action follow-up request with auth_id")
auth, ok := ns.headscale.state.GetAuthCacheEntry(authID)
if !ok {
log.Error().Caller().Str("auth_id", authID.String()).Msg("no auth session found for auth_id in SSH action request")
http.Error(writer, "Invalid auth_id", http.StatusBadRequest)
return
}
verdict := <-auth.WaitForAuth()
if verdict.Accept() {
action.Reject = false
action.Accept = true
} else {
action.Reject = true
action.Accept = false
log.Trace().Caller().Str("auth_id", authID.String()).Err(verdict.Err).Msg("SSH action authentication rejected")
}
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
err := json.NewEncoder(writer).Encode(action)
if err != nil {
log.Error().Caller().Err(err).Msg("failed to encode SSH action response")
return
}
// Ensure response is flushed to client
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
}
}
// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
@ -198,7 +390,7 @@ func rejectUnsupported(
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) NoisePollNetMapHandler(
func (ns *noiseServer) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
@ -237,8 +429,8 @@ func regErr(err error) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{Error: err.Error()}
}
// NoiseRegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) NoiseRegistrationHandler(
// RegistrationHandler handles the actual registration process of a node.
func (ns *noiseServer) RegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {

View File

@ -12,7 +12,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@ -26,8 +25,8 @@ import (
const (
randomByteSize = 16
defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
authCacheExpiration = time.Minute * 15
authCacheCleanup = time.Minute * 20
)
var (
@ -44,17 +43,21 @@ var (
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
)
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
RegistrationID types.RegistrationID
Verifier *string
// AuthInfo contains both auth ID and verifier information for OIDC validation.
type AuthInfo struct {
AuthID types.AuthID
Verifier *string
Registration bool
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
registrationCache *zcache.Cache[string, RegistrationInfo]
h *Headscale
serverURL string
cfg *types.OIDCConfig
// authCache holds auth information between
// the auth and the callback steps.
authCache *zcache.Cache[string, AuthInfo]
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
authCache := zcache.New[string, AuthInfo](
authCacheExpiration,
authCacheCleanup,
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
registrationCache: registrationCache,
h: h,
serverURL: serverURL,
cfg: cfg,
authCache: authCache,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderOIDC) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, false)
}
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
registrationID.String())
authID.String())
}
// RegisterHandler registers the OIDC callback handler with the given router.
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
// Listens in /register/:registration_id.
// Listens in /register/:auth_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
a.authHandler(writer, req, true)
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
// authHandler takes an incoming request that needs to be authenticated and
// validates and prepares it for the OIDC flow.
func (a *AuthProviderOIDC) authHandler(
writer http.ResponseWriter,
req *http.Request,
registration bool,
) {
authID, err := authIDFromRequest(req)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
httpError(writer, err)
return
}
@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
return
}
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
RegistrationID: registrationId,
registrationInfo := AuthInfo{
AuthID: authID,
Registration: registration,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
a.registrationCache.Set(state, registrationInfo)
a.authCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
@ -302,16 +323,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
registrationId := a.getRegistrationIDFromState(state)
authInfo := a.getAuthInfoFromState(state)
if authInfo == nil {
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
return
}
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
// If this is a registration flow, then we need to register the node.
if authInfo.Registration {
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
@ -322,12 +347,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
if newNode {
verb = "Authenticated"
}
// TODO(kradalby): replace with go-elem
content := renderOIDCCallbackTemplate(user, verb)
content := renderRegistrationSuccessTemplate(user, newNode)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
@ -339,9 +359,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// If this is not a registration callback, then its a regular authentication callback
// and we need to send a response and confirm that the access was allowed.
authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID)
if !ok {
log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("auth session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
// Send a finish auth verdict with no errors to let the CLI know that the authentication was successful.
authReq.FinishAuth(types.AuthVerdict{})
content := renderAuthSuccessTemplate(user)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
util.LogErr(err, "Failed to write HTTP response")
}
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@ -374,7 +413,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
regInfo, ok := a.authCache.Get(state)
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
@ -507,14 +546,14 @@ func doOIDCAuthorization(
return nil
}
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
// getAuthInfoFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
authInfo, ok := a.authCache.Get(state)
if !ok {
return nil
}
return &regInfo.RegistrationID
return &authInfo
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@ -562,7 +601,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
registrationID types.AuthID,
expiry time.Time,
) (bool, error) {
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
@ -597,12 +636,38 @@ func (a *AuthProviderOIDC) handleRegistration(
return !nodeChange.IsEmpty(), nil
}
func renderOIDCCallbackTemplate(
func renderRegistrationSuccessTemplate(
user *types.User,
verb string,
newNode bool,
) *bytes.Buffer {
html := templates.OIDCCallback(user.Display(), verb).Render()
return bytes.NewBufferString(html)
result := templates.AuthSuccessResult{
Title: "Headscale - Node Reauthenticated",
Heading: "Node reauthenticated",
Verb: "Reauthenticated",
User: user.Display(),
Message: "You can now close this window.",
}
if newNode {
result.Title = "Headscale - Node Registered"
result.Heading = "Node registered"
result.Verb = "Registered"
}
return bytes.NewBufferString(templates.AuthSuccess(result).Render())
}
func renderAuthSuccessTemplate(
user *types.User,
) *bytes.Buffer {
result := templates.AuthSuccessResult{
Title: "Headscale - SSH Session Authorized",
Heading: "SSH session authorized",
Verb: "Authorized",
User: user.Display(),
Message: "You may return to your terminal.",
}
return bytes.NewBufferString(templates.AuthSuccess(result).Render())
}
// getCookieName generates a unique cookie name based on a cookie value.

View File

@ -7,35 +7,54 @@ import (
"github.com/stretchr/testify/assert"
)
func TestOIDCCallbackTemplate(t *testing.T) {
func TestAuthSuccessTemplate(t *testing.T) {
tests := []struct {
name string
userName string
verb string
name string
result templates.AuthSuccessResult
}{
{
name: "logged_in_user",
userName: "test@example.com",
verb: "Logged in",
name: "node_registered",
result: templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "newuser@example.com",
Message: "You can now close this window.",
},
},
{
name: "registered_user",
userName: "newuser@example.com",
verb: "Registered",
name: "node_reauthenticated",
result: templates.AuthSuccessResult{
Title: "Headscale - Node Reauthenticated",
Heading: "Node reauthenticated",
Verb: "Reauthenticated",
User: "test@example.com",
Message: "You can now close this window.",
},
},
{
name: "ssh_session_authorized",
result: templates.AuthSuccessResult{
Title: "Headscale - SSH Session Authorized",
Heading: "SSH session authorized",
Verb: "Authorized",
User: "test@example.com",
Message: "You may return to your terminal.",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Render using the elem-go template
html := templates.OIDCCallback(tt.userName, tt.verb).Render()
html := templates.AuthSuccess(tt.result).Render()
// Verify the HTML contains expected elements
// Verify the HTML contains expected structural elements
assert.Contains(t, html, "<!DOCTYPE html>")
assert.Contains(t, html, "<title>Headscale Authentication Succeeded</title>")
assert.Contains(t, html, tt.verb)
assert.Contains(t, html, tt.userName)
assert.Contains(t, html, "You can now close this window")
assert.Contains(t, html, "<title>"+tt.result.Title+"</title>")
assert.Contains(t, html, tt.result.Heading)
assert.Contains(t, html, tt.result.Verb+" as ")
assert.Contains(t, html, tt.result.User)
assert.Contains(t, html, tt.result.Message)
// Verify Material for MkDocs design system CSS is present
assert.Contains(t, html, "Material for MkDocs")

View File

@ -19,7 +19,7 @@ type PolicyManager interface {
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy(pol []byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)

View File

@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) {
"root": "",
},
Action: &tailcfg.SSHAction{
Accept: true,
Accept: false,
SessionDuration: 24 * time.Hour,
HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) {
require.NoError(t, err)
got, err := pm.SSHPolicy(tt.targetNode.View())
got, err := pm.SSHPolicy("unused-url", tt.targetNode.View())
require.NoError(t, err)
if diff := cmp.Diff(tt.wantSSH, got); diff != "" {

View File

@ -319,11 +319,27 @@ func (pol *Policy) compileACLWithAutogroupSelf(
return rules, nil
}
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
var sshAccept = tailcfg.SSHAction{
Reject: false,
Accept: true,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
}
func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
Reject: false,
Accept: false,
SessionDuration: duration,
// Replaced in the client:
// * $SRC_NODE_IP (URL escaped)
// * $SRC_NODE_ID (Node.ID as int64 string)
// * $DST_NODE_IP (URL escaped)
// * $DST_NODE_ID (Node.ID as int64 string)
// * $SSH_USER (URL escaped, ssh user requested)
// * $LOCAL_USER (URL escaped, local user mapped)
HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
@ -332,6 +348,7 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
//nolint:gocyclo // complex SSH policy compilation logic
func (pol *Policy) compileSSHPolicy(
baseURL string,
users types.Users,
node types.NodeView,
nodes views.Slice[types.NodeView],
@ -377,9 +394,9 @@ func (pol *Policy) compileSSHPolicy(
switch rule.Action {
case SSHActionAccept:
action = sshAction(true, 0)
action = sshAccept
case SSHActionCheck:
action = sshAction(true, time.Duration(rule.CheckPeriod))
action = sshCheck(baseURL, time.Duration(rule.CheckPeriod))
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}

View File

@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
require.NoError(t, err)
// Compile SSH policy
sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice())
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
require.NoError(t, err)
if tt.wantEmpty {
@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -704,8 +704,11 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
}
assert.Equal(t, expectedUsers, rule.SSHUsers)
// Verify check action with session duration
assert.True(t, rule.Action.Accept)
// Verify check action: Accept is false, HoldAndDelegate is set
assert.False(t, rule.Action.Accept)
assert.False(t, rule.Action.Reject)
assert.NotEmpty(t, rule.Action.HoldAndDelegate)
assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/")
assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration)
}
@ -756,7 +759,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
require.NoError(t, err)
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -806,7 +809,7 @@ func TestSSHJSONSerialization(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
@ -1413,7 +1416,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user1's first node
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1432,7 +1435,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user2's first node
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy2)
require.Len(t, sshPolicy2.Rules, 1)
@ -1451,7 +1454,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for tagged node (should have no SSH rules)
node5 := nodes[4].View()
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy3 != nil {
@ -1491,7 +1494,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user1's node: should allow SSH from user1's devices
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1508,7 +1511,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user2's node: should have no rules (user1's devices can't match user2's self)
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1551,7 +1554,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user1's node: should allow SSH from user1's devices only (not user2's)
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1568,7 +1571,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user3's node: should have no rules (not in group:admins)
node5 := nodes[4].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1610,7 +1613,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For untagged node: should only get principals from other untagged nodes
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1628,7 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For tagged node: should get no SSH rules
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1671,7 +1674,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 1: Compile for user1's device (should only match autogroup:self destination)
node1 := nodes[0].View()
sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy1)
require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)")
@ -1690,7 +1693,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 2: Compile for router (should only match tag:router destination)
routerNode := nodes[3].View() // user2-router
sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice())
sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicyRouter)
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")

View File

@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil
}
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
return sshPol, nil
}
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}

View File

@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
// ErrRegistrationExpired is returned when a registration has expired.
var ErrRegistrationExpired = errors.New("registration expired")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@ -82,8 +85,10 @@ type State struct {
derpMap atomic.Pointer[tailcfg.DERPMap]
// polMan handles policy evaluation and management
polMan policy.PolicyManager
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
}
@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
authCache := zcache.New[types.AuthID, types.AuthRequest](
cacheExpiration,
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
authCache.OnEvicted(
func(id types.AuthID, rn types.AuthRequest) {
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg,
registrationCache,
authCache,
)
if err != nil {
return nil, fmt.Errorf("initializing database: %w", err)
@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
authCache: authCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
}, nil
}
@ -851,7 +856,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
// SSHPolicy returns the SSH access policy for a node.
func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
return s.polMan.SSHPolicy(node)
return s.polMan.SSHPolicy(s.cfg.ServerURL, node)
}
// Filter returns the current network filter rules and matches.
@ -1042,9 +1047,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
entry, found := s.registrationCache.Get(id)
// GetAuthCacheEntry retrieves a node registration from cache.
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
entry, found := s.authCache.Get(id)
if !found {
return nil, false
}
@ -1052,26 +1057,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
return &entry, true
}
// SetRegistrationCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
s.registrationCache.Set(id, entry)
// SetAuthCacheEntry stores a node registration in cache.
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
s.authCache.Set(id, entry)
}
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
if hostinfo == nil {
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
if !nv.Hostinfo().Valid() {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had nil hostinfo, generated default hostname")
} else if hostinfo.Hostname == "" {
} else if nv.Hostinfo().Hostname() == "" {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had empty hostname, generated default")
@ -1113,7 +1116,7 @@ type authNodeUpdateParams struct {
// Node to update; must be valid and in NodeStore.
ExistingNode types.NodeView
// Client data: keys, hostinfo, endpoints.
RegEntry *types.RegisterNode
RegEntry *types.AuthRequest
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
ValidHostinfo *tailcfg.Hostinfo
// Hostname from hostinfo, or generated from keys if client omits it.
@ -1132,6 +1135,7 @@ type authNodeUpdateParams struct {
// an existing node. It updates the node in NodeStore, processes RequestTags, and
// persists changes to the database.
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
regNv := params.RegEntry.Node()
// Log the operation type
if params.IsConvertFromTag {
log.Info().
@ -1140,16 +1144,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
Msg("Converting tagged node to user-owned node")
} else {
log.Info().
EmbedObject(params.ExistingNode).
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
Object("existing", params.ExistingNode).
Object("incoming", regNv).
Msg("Updating existing node registration via reauth")
}
// Process RequestTags during reauth (#2979)
// Due to json:",omitempty", we treat empty/nil as "clear tags"
var requestTags []string
if params.RegEntry.Node.Hostinfo != nil {
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
if regNv.Hostinfo().Valid() {
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
}
oldTags := params.ExistingNode.Tags().AsSlice()
@ -1167,8 +1171,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
// Update existing node in NodeStore - validation passed, safe to mutate
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
node.NodeKey = params.RegEntry.Node.NodeKey
node.DiscoKey = params.RegEntry.Node.DiscoKey
node.NodeKey = regNv.NodeKey()
node.DiscoKey = regNv.DiscoKey()
node.Hostname = params.Hostname
// Preserve NetInfo from existing node when re-registering
@ -1179,7 +1183,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
params.ValidHostinfo,
)
node.Endpoints = params.RegEntry.Node.Endpoints
node.Endpoints = regNv.Endpoints().AsSlice()
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
@ -1188,7 +1192,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.IsConvertFromTag {
node.RegisterMethod = params.RegisterMethod
} else {
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
node.RegisterMethod = regNv.RegisterMethod()
}
// Track tagged status BEFORE processing tags
@ -1208,7 +1212,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !wasTagged && isTagged:
// Personal → Tagged: clear expiry (tagged nodes don't expire)
@ -1218,14 +1222,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !isTagged:
// Personal → Personal: update expiry from client
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
}
// Tagged → Tagged: keep existing expiry (nil) - no action needed
@ -1511,13 +1515,13 @@ func (s *State) processReauthTags(
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
authID types.AuthID,
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (types.NodeView, change.Change, error) {
// Get the registration entry from cache
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
regEntry, ok := s.GetAuthCacheEntry(authID)
if !ok {
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
}
@ -1530,25 +1534,27 @@ func (s *State) HandleNodeFromAuthPath(
// Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname(
regEntry.Node.Hostinfo,
regEntry.Node.MachineKey.String(),
regEntry.Node.NodeKey.String(),
regEntry.Node().Hostinfo(),
regEntry.Node().MachineKey().String(),
regEntry.Node().NodeKey().String(),
)
// Ensure we have valid hostinfo
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
hostinfo := &tailcfg.Hostinfo{}
if regEntry.Node().Hostinfo().Valid() {
hostinfo = regEntry.Node().Hostinfo().AsStruct()
}
hostinfo.Hostname = hostname
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
regEntry.Node(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
// Lookup existing nodes
machineKey := regEntry.Node.MachineKey
machineKey := regEntry.Node().MachineKey()
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
@ -1562,7 +1568,7 @@ func (s *State) HandleNodeFromAuthPath(
// Create logger with common fields for all auth operations
logger := log.With().
Str(zf.RegistrationID, registrationID.String()).
Str(zf.RegistrationID, authID.String()).
Str(zf.UserName, user.Name).
Str(zf.MachineKey, machineKey.ShortString()).
Str(zf.Method, registrationMethod).
@ -1571,7 +1577,7 @@ func (s *State) HandleNodeFromAuthPath(
// Common params for update operations
updateParams := authNodeUpdateParams{
RegEntry: regEntry,
ValidHostinfo: validHostinfo,
ValidHostinfo: hostinfo,
Hostname: hostname,
User: user,
Expiry: expiry,
@ -1605,7 +1611,7 @@ func (s *State) HandleNodeFromAuthPath(
Msg("Creating new node for different user (same machine key exists for another user)")
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, existingNodeAnyUser,
)
if err != nil {
@ -1613,7 +1619,7 @@ func (s *State) HandleNodeFromAuthPath(
}
} else {
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, types.NodeView{},
)
if err != nil {
@ -1622,10 +1628,10 @@ func (s *State) HandleNodeFromAuthPath(
}
// Signal to waiting clients
regEntry.SendAndClose(finalNode.AsStruct())
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
// Delete from registration cache
s.registrationCache.Delete(registrationID)
s.authCache.Delete(authID)
// Update policy managers
usersChange, err := s.updatePolicyManagerUsers()
@ -1654,7 +1660,7 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) createNewNodeFromAuth(
logger zerolog.Logger,
user *types.User,
regEntry *types.RegisterNode,
regEntry *types.AuthRequest,
hostname string,
validHostinfo *tailcfg.Hostinfo,
expiry *time.Time,
@ -1667,13 +1673,13 @@ func (s *State) createNewNodeFromAuth(
return s.createAndSaveNewNode(newNodeParams{
User: *user,
MachineKey: regEntry.Node.MachineKey,
NodeKey: regEntry.Node.NodeKey,
DiscoKey: regEntry.Node.DiscoKey,
MachineKey: regEntry.Node().MachineKey(),
NodeKey: regEntry.Node().NodeKey(),
DiscoKey: regEntry.Node().DiscoKey(),
Hostname: hostname,
Hostinfo: validHostinfo,
Endpoints: regEntry.Node.Endpoints,
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
Endpoints: regEntry.Node().Endpoints().AsSlice(),
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
RegisterMethod: registrationMethod,
ExistingNodeForNetinfo: existingNodeForNetinfo,
})
@ -1759,7 +1765,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname(
regReq.Hostinfo,
regReq.Hostinfo.View(),
machineKey.String(),
regReq.NodeKey.String(),
)
@ -1768,14 +1774,6 @@ func (s *State) HandleNodeFromPreAuthKey(
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),
pakUsername(),
hostname,
regReq.Hostinfo,
)
log.Debug().
Caller().
Str(zf.NodeName, hostname).

View File

@ -0,0 +1,62 @@
package templates
import (
"github.com/chasefleming/elem-go"
)
// AuthSuccessResult contains the text content for an authentication success page.
// Each field controls a distinct piece of user-facing text so that every auth
// flow (node registration, reauthentication, SSH check, …) can clearly
// communicate what just happened.
type AuthSuccessResult struct {
// Title is the browser tab / page title,
// e.g. "Headscale - Node Registered".
Title string
// Heading is the bold green text inside the success box,
// e.g. "Node registered".
Heading string
// Verb is the action prefix in the body text before "as <user>",
// e.g. "Registered", "Reauthenticated", "Authorized".
Verb string
// User is the display name shown in bold in the body text,
// e.g. "user@example.com".
User string
// Message is the follow-up instruction shown after the user name,
// e.g. "You can now close this window."
Message string
}
// AuthSuccess renders an authentication / authorisation success page.
// The caller controls every user-visible string via [AuthSuccessResult] so the
// page clearly describes what succeeded (registration, reauth, SSH check, …).
func AuthSuccess(result AuthSuccessResult) *elem.Element {
box := successBox(
result.Heading,
elem.Text(result.Verb+" as "),
elem.Strong(nil, elem.Text(result.User)),
elem.Text(". "+result.Message),
)
return HtmlStructure(
elem.Title(nil, elem.Text(result.Title)),
mdTypesetBody(
headscaleLogo(),
box,
H2(elem.Text("Getting started")),
P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")),
Ul(
elem.Li(nil,
externalLink("https://headscale.net/stable/", "Headscale documentation"),
),
elem.Li(nil,
externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"),
),
),
pageFooter(),
),
)
}

View File

@ -0,0 +1,21 @@
package templates
import (
"github.com/chasefleming/elem-go"
)
// AuthWeb renders a page that instructs an administrator to run a CLI command
// to complete an authentication or registration flow.
// It is used by both the registration and auth-approve web handlers.
func AuthWeb(title, description, command string) *elem.Element {
return HtmlStructure(
elem.Title(nil, elem.Text(title+" - Headscale")),
mdTypesetBody(
headscaleLogo(),
H1(elem.Text(title)),
P(elem.Text(description)),
Pre(PreCode(command)),
pageFooter(),
),
)
}

View File

@ -365,6 +365,47 @@ func orDivider() *elem.Element {
)
}
// successBox creates a green success feedback box with a checkmark icon.
// The heading is displayed as bold green text, and children are rendered below it.
// Pairs with warningBox for consistent feedback styling.
//
//nolint:unused // Used in auth_success.go template.
func successBox(heading string, children ...elem.Node) *elem.Element {
return elem.Div(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "flex",
styles.AlignItems: "center",
styles.Gap: spaceM,
styles.Padding: spaceL,
styles.BackgroundColor: colorSuccessLight,
styles.Border: "1px solid " + colorSuccess,
styles.BorderRadius: "0.5rem",
styles.MarginBottom: spaceXL,
}.ToInline(),
},
checkboxIcon(),
elem.Div(nil,
append([]elem.Node{
elem.Strong(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "block",
styles.Color: colorSuccess,
styles.FontSize: fontSizeH3,
styles.MarginBottom: spaceXS,
}.ToInline(),
}, elem.Text(heading)),
}, children...)...,
),
)
}
// checkboxIcon returns the success checkbox SVG icon as raw HTML.
func checkboxIcon() elem.Node {
return elem.Raw(`<svg id="checkbox" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 512 512">
<path d="M256 32C132.3 32 32 132.3 32 256s100.3 224 224 224 224-100.3 224-224S379.7 32 256 32zm114.9 149.1L231.8 359.6c-1.1 1.1-2.9 3.5-5.1 3.5-2.3 0-3.8-1.6-5.1-2.9-1.3-1.3-78.9-75.9-78.9-75.9l-1.5-1.5c-.6-.9-1.1-2-1.1-3.2 0-1.2.5-2.3 1.1-3.2.4-.4.7-.7 1.1-1.2 7.7-8.1 23.3-24.5 24.3-25.5 1.3-1.3 2.4-3 4.8-3 2.5 0 4.1 2.1 5.3 3.3 1.2 1.2 45 43.3 45 43.3l111.3-143c1-.8 2.2-1.4 3.5-1.4 1.3 0 2.5.5 3.5 1.3l30.6 24.1c.8 1 1.3 2.2 1.3 3.5.1 1.3-.4 2.4-1 3.3z"></path>
</svg>`)
}
// warningBox creates a warning message box with icon and content.
//
//nolint:unused // Used in apple.go template.

View File

@ -1,69 +0,0 @@
package templates
import (
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
"github.com/chasefleming/elem-go/styles"
)
// checkboxIcon returns the success checkbox SVG icon as raw HTML.
func checkboxIcon() elem.Node {
return elem.Raw(`<svg id="checkbox" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="48" height="48" viewBox="0 0 512 512">
<path d="M256 32C132.3 32 32 132.3 32 256s100.3 224 224 224 224-100.3 224-224S379.7 32 256 32zm114.9 149.1L231.8 359.6c-1.1 1.1-2.9 3.5-5.1 3.5-2.3 0-3.8-1.6-5.1-2.9-1.3-1.3-78.9-75.9-78.9-75.9l-1.5-1.5c-.6-.9-1.1-2-1.1-3.2 0-1.2.5-2.3 1.1-3.2.4-.4.7-.7 1.1-1.2 7.7-8.1 23.3-24.5 24.3-25.5 1.3-1.3 2.4-3 4.8-3 2.5 0 4.1 2.1 5.3 3.3 1.2 1.2 45 43.3 45 43.3l111.3-143c1-.8 2.2-1.4 3.5-1.4 1.3 0 2.5.5 3.5 1.3l30.6 24.1c.8 1 1.3 2.2 1.3 3.5.1 1.3-.4 2.4-1 3.3z"></path>
</svg>`)
}
// OIDCCallback renders the OIDC authentication success callback page.
func OIDCCallback(user, verb string) *elem.Element {
// Success message box
successBox := elem.Div(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "flex",
styles.AlignItems: "center",
styles.Gap: spaceM,
styles.Padding: spaceL,
styles.BackgroundColor: colorSuccessLight,
styles.Border: "1px solid " + colorSuccess,
styles.BorderRadius: "0.5rem",
styles.MarginBottom: spaceXL,
}.ToInline(),
},
checkboxIcon(),
elem.Div(nil,
elem.Strong(attrs.Props{
attrs.Style: styles.Props{
styles.Display: "block",
styles.Color: colorSuccess,
styles.FontSize: fontSizeH3,
styles.MarginBottom: spaceXS,
}.ToInline(),
}, elem.Text("Signed in successfully")),
elem.P(attrs.Props{
attrs.Style: styles.Props{
styles.Margin: "0",
styles.Color: colorTextPrimary,
styles.FontSize: fontSizeBase,
}.ToInline(),
}, elem.Text(verb), elem.Text(" as "), elem.Strong(nil, elem.Text(user)), elem.Text(". You can now close this window.")),
),
)
return HtmlStructure(
elem.Title(nil, elem.Text("Headscale Authentication Succeeded")),
mdTypesetBody(
headscaleLogo(),
successBox,
H2(elem.Text("Getting started")),
P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")),
Ul(
elem.Li(nil,
externalLink("https://headscale.net/stable/", "Headscale documentation"),
),
elem.Li(nil,
externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"),
),
),
pageFooter(),
),
)
}

View File

@ -1,21 +0,0 @@
package templates
import (
"fmt"
"github.com/chasefleming/elem-go"
"github.com/juanfont/headscale/hscontrol/types"
)
func RegisterWeb(registrationID types.RegistrationID) *elem.Element {
return HtmlStructure(
elem.Title(nil, elem.Text("Registration - Headscale")),
mdTypesetBody(
headscaleLogo(),
H1(elem.Text("Machine registration")),
P(elem.Text("Run the command below in the headscale server to add this machine to your network:")),
Pre(PreCode(fmt.Sprintf("headscale nodes register --key %s --user USERNAME", registrationID.String()))),
pageFooter(),
),
)
}

View File

@ -5,7 +5,6 @@ import (
"testing"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
)
@ -16,12 +15,30 @@ func TestTemplateHTMLConsistency(t *testing.T) {
html string
}{
{
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
},
{
name: "Windows Config",
@ -72,12 +89,30 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
html string
}{
{
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
},
{
name: "Windows Config",
@ -116,16 +151,35 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
externalURLs []string // URLs that should have security attributes
}{
{
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
externalURLs: []string{
"https://headscale.net/stable/",
"https://tailscale.com/kb/",
},
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
externalURLs: []string{}, // No external links
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
externalURLs: []string{}, // No external links
},
{
@ -185,12 +239,30 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
html string
}{
{
name: "OIDC Callback",
html: templates.OIDCCallback("test@example.com", "Logged in").Render(),
name: "Auth Success",
html: templates.AuthSuccess(templates.AuthSuccessResult{
Title: "Headscale - Node Registered",
Heading: "Node registered",
Verb: "Registered",
User: "test@example.com",
Message: "You can now close this window.",
}).Render(),
},
{
name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
name: "Auth Web Register",
html: templates.AuthWeb(
"Machine registration",
"Run the command below in the headscale server to add this machine to your network:",
"headscale auth register --auth-id test-key-123 --user USERNAME",
).Render(),
},
{
name: "Auth Web Approve",
html: templates.AuthWeb(
"Authentication check",
"Run the command below in the headscale server to approve this authentication request:",
"headscale auth approve --auth-id test-key-123",
).Render(),
},
{
name: "Windows Config",

View File

@ -22,8 +22,8 @@ const (
// Common errors.
var (
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
)
type StateUpdateType int
@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
const RegistrationIDLength = 24
const AuthIDLength = 24
type RegistrationID string
type AuthID string
func NewRegistrationID() (RegistrationID, error) {
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
func NewAuthID() (AuthID, error) {
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
if err != nil {
return "", err
}
return RegistrationID(rid), nil
return AuthID(rid), nil
}
func MustRegistrationID() RegistrationID {
rid, err := NewRegistrationID()
func MustAuthID() AuthID {
rid, err := NewAuthID()
if err != nil {
panic(err)
}
@ -181,43 +181,96 @@ func MustRegistrationID() RegistrationID {
return rid
}
func RegistrationIDFromString(str string) (RegistrationID, error) {
if len(str) != RegistrationIDLength {
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
func AuthIDFromString(str string) (AuthID, error) {
r := AuthID(str)
err := r.Validate()
if err != nil {
return "", err
}
return RegistrationID(str), nil
return r, nil
}
func (r RegistrationID) String() string {
func (r AuthID) String() string {
return string(r)
}
type RegisterNode struct {
Node Node
Registered chan *Node
closed *atomic.Bool
func (r AuthID) Validate() error {
if len(r) != AuthIDLength {
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
}
return nil
}
func NewRegisterNode(node Node) RegisterNode {
return RegisterNode{
Node: node,
Registered: make(chan *Node),
closed: &atomic.Bool{},
// AuthRequest represent a pending authentication request from a user or a node.
// If it is a registration request, the node field will be populate with the node that is trying to register.
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
type AuthRequest struct {
node *Node
finished chan AuthVerdict
closed *atomic.Bool
}
func NewAuthRequest() AuthRequest {
return AuthRequest{
finished: make(chan AuthVerdict),
closed: &atomic.Bool{},
}
}
func (rn *RegisterNode) SendAndClose(node *Node) {
func NewRegisterAuthRequest(node Node) AuthRequest {
return AuthRequest{
node: &node,
finished: make(chan AuthVerdict),
closed: &atomic.Bool{},
}
}
// Node returns the node that is trying to register.
// It will panic if the AuthRequest is not a registration request.
// Can _only_ be used in the registration path.
func (rn *AuthRequest) Node() NodeView {
if rn.node == nil {
panic("Node can only be used in registration requests")
}
return rn.node.View()
}
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
if rn.closed.Swap(true) {
return
}
select {
case rn.Registered <- node:
case rn.finished <- verdict:
default:
}
close(rn.Registered)
close(rn.finished)
}
func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict {
return rn.finished
}
type AuthVerdict struct {
// Err is the error that occurred during the authentication process, if any.
// If Err is nil, the authentication process has succeeded.
// If Err is not nil, the authentication process has failed and the node should not be authenticated.
Err error
// Node is the node that has been authenticated.
// Node is only valid if the auth request was a registration request
// and the authentication process has succeeded.
Node NodeView
}
func (v AuthVerdict) Accept() bool {
return v.Err == nil
}
// DefaultBatcherWorkers returns the default number of batcher workers.

View File

@ -295,8 +295,8 @@ func IsCI() bool {
// 3. If normalisation fails → generate invalid-<random> replacement
//
// Returns the guaranteed-valid hostname to use.
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" {
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
key := cmp.Or(machineKey, nodeKey)
if key == "" {
return "unknown-node"
@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
return "node-" + keyPrefix
}
lowercased := strings.ToLower(hostinfo.Hostname)
lowercased := strings.ToLower(hostinfo.Hostname())
err := ValidateHostname(lowercased)
if err == nil {

View File

@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") {
if !strings.HasPrefix(got, "invalid-") {
@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") {
if !strings.HasPrefix(gotHostname, "invalid-") {
@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo, "mkey", "nkey")
result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
}
@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
OS: "linux",
}
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)

View File

@ -312,7 +312,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
}
// Register all clients as user1 (this is where cross-user registration happens)
// This simulates: headscale nodes register --user user1 --key <key>
// This simulates: headscale auth register --auth-id <id> --user user1
_ = scenario.runHeadscaleRegister("user1", body)
}

View File

@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@ -1100,11 +1100,11 @@ func TestNodeCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"nodes",
"auth",
"register",
"--user",
"node-user",
"register",
"--key",
"--auth-id",
regID,
"--output",
"json",
@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
assert.Equal(t, "node-5", listAll[4].GetName())
otherUserRegIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
@ -1185,11 +1185,11 @@ func TestNodeCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"nodes",
"auth",
"register",
"--user",
"other-user",
"register",
"--key",
"--auth-id",
regID,
"--output",
"json",
@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@ -1359,11 +1359,11 @@ func TestNodeExpireCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"nodes",
"auth",
"register",
"--user",
"node-expire-user",
"register",
"--key",
"--auth-id",
regID,
"--output",
"json",
@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
types.MustAuthID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
@ -1496,11 +1496,11 @@ func TestNodeRenameCommand(t *testing.T) {
headscale,
[]string{
"headscale",
"nodes",
"auth",
"register",
"--user",
"node-rename-command",
"register",
"--key",
"--auth-id",
regID,
"--output",
"json",

View File

@ -16,6 +16,7 @@ import (
type ControlServer interface {
Shutdown() (string, string, error)
SaveLog(path string) (string, string, error)
ReadLog() (string, string, error)
SaveProfile(path string) error
Execute(command []string) (string, error)
WriteFile(path string, content []byte) error

View File

@ -699,6 +699,18 @@ func (t *HeadscaleInContainer) WriteLogs(stdout, stderr io.Writer) error {
return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr)
}
// ReadLog returns the current stdout and stderr logs from the headscale container.
func (t *HeadscaleInContainer) ReadLog() (string, string, error) {
var stdout, stderr bytes.Buffer
err := dockertestutil.WriteLog(t.pool, t.container, &stdout, &stderr)
if err != nil {
return "", "", fmt.Errorf("reading container logs: %w", err)
}
return stdout.String(), stderr.String(), nil
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) {

View File

@ -141,6 +141,12 @@ type ScenarioSpec struct {
// Versions is specific list of versions to use for the test.
Versions []string
// OIDCSkipUserCreation, if true, skips creating users via headscale CLI
// during environment setup. Useful for OIDC tests where the SSH policy
// references users by name, since OIDC login creates users automatically
// and pre-creating them via CLI causes duplicate user records.
OIDCSkipUserCreation bool
// OIDCUsers, if populated, will start a Mock OIDC server and populate
// the user login stack with the given users.
// If the NodesPerUser is set, it should align with this list to ensure
@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags(
}
for _, user := range s.spec.Users {
u, err := s.CreateUser(user)
if err != nil {
return err
var u *v1.User
if s.spec.OIDCSkipUserCreation {
// Only register locally — OIDC login will create the headscale user.
s.mu.Lock()
s.users[user] = &User{Clients: make(map[string]TailscaleClient)}
s.mu.Unlock()
} else {
u, err = s.CreateUser(user)
if err != nil {
return err
}
}
var userOpts []tsic.Option
@ -1169,7 +1184,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
return errParseAuthPage
}
keySep := strings.Split(codeSep[0], "key ")
keySep := strings.Split(codeSep[0], "--auth-id ")
if len(keySep) != 2 {
return errParseAuthPage
}
@ -1180,7 +1195,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr
_, err = headscale.Execute(
[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key},
[]string{"headscale", "auth", "register", "--user", userStr, "--auth-id", key},
)
if err != nil {
log.Printf("registering node: %s", err)

View File

@ -3,13 +3,17 @@ package integration
import (
"fmt"
"log"
"net/url"
"strings"
"testing"
"time"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/oauth2-proxy/mockoidc"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
@ -579,3 +583,558 @@ func TestSSHAutogroupSelf(t *testing.T) {
}
}
}
type sshCheckResult struct {
stdout string
stderr string
err error
}
// doSSHCheck runs SSH in a goroutine with a longer timeout, returning a channel
// for the result. The SSH command will block while waiting for auth approval in
// check mode.
func doSSHCheck(
t *testing.T,
client TailscaleClient,
peer TailscaleClient,
) chan sshCheckResult {
t.Helper()
peerFQDN, _ := peer.FQDN()
command := []string{
"/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=30",
fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN),
"'hostname'",
}
log.Printf(
"[SSH check] Running from %s to %s",
client.Hostname(),
peer.Hostname(),
)
ch := make(chan sshCheckResult, 1)
go func() {
stdout, stderr, err := client.Execute(
command,
dockertestutil.ExecuteCommandTimeout(60*time.Second),
)
ch <- sshCheckResult{stdout, stderr, err}
}()
return ch
}
// findSSHCheckAuthID polls headscale container logs for the SSH action auth-id.
// The SSH action handler logs "SSH action follow-up" with the auth_id on the
// follow-up request (where auth_id is non-empty).
func findSSHCheckAuthID(t *testing.T, headscale ControlServer) string {
t.Helper()
var authID string
assert.EventuallyWithT(t, func(c *assert.CollectT) {
_, stderr, err := headscale.ReadLog()
assert.NoError(c, err)
for line := range strings.SplitSeq(stderr, "\n") {
if !strings.Contains(line, "SSH action follow-up") {
continue
}
if idx := strings.Index(line, "auth_id="); idx != -1 {
start := idx + len("auth_id=")
end := strings.IndexByte(line[start:], ' ')
if end == -1 {
end = len(line[start:])
}
authID = line[start : start+end]
}
}
assert.NotEmpty(c, authID, "auth-id not found in headscale logs")
}, 10*time.Second, 500*time.Millisecond, "waiting for SSH check auth-id in headscale logs")
return authID
}
// sshCheckPolicy returns a policy with SSH "check" mode for group:integration-test
// targeting autogroup:member and autogroup:tagged destinations.
func sshCheckPolicy() *policyv2.Policy {
return &policyv2.Policy{
Groups: policyv2.Groups{
policyv2.Group("group:integration-test"): []policyv2.Username{
policyv2.Username("user1@"),
},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
SSHs: []policyv2.SSH{
{
Action: "check",
Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")},
Destinations: policyv2.SSHDstAliases{
new(policyv2.AutoGroupMember),
new(policyv2.AutoGroupTagged),
},
Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")},
},
},
}
}
// sshCheckPolicyWithPeriod returns a policy with SSH "check" mode and a
// specified checkPeriod for session duration.
func sshCheckPolicyWithPeriod(period time.Duration) *policyv2.Policy {
return &policyv2.Policy{
Groups: policyv2.Groups{
policyv2.Group("group:integration-test"): []policyv2.Username{
policyv2.Username("user1@"),
},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
SSHs: []policyv2.SSH{
{
Action: "check",
Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")},
Destinations: policyv2.SSHDstAliases{
new(policyv2.AutoGroupMember),
new(policyv2.AutoGroupTagged),
},
Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")},
CheckPeriod: model.Duration(period),
},
},
}
}
// findNewSSHCheckAuthID polls headscale logs for an SSH check auth-id
// that differs from excludeID. Used to verify re-authentication after
// session expiry.
func findNewSSHCheckAuthID(
t *testing.T,
headscale ControlServer,
excludeID string,
) string {
t.Helper()
var authID string
assert.EventuallyWithT(t, func(c *assert.CollectT) {
_, stderr, err := headscale.ReadLog()
assert.NoError(c, err)
for line := range strings.SplitSeq(stderr, "\n") {
if !strings.Contains(line, "SSH action follow-up") {
continue
}
if idx := strings.Index(line, "auth_id="); idx != -1 {
start := idx + len("auth_id=")
end := strings.IndexByte(line[start:], ' ')
if end == -1 {
end = len(line[start:])
}
id := line[start : start+end]
if id != excludeID {
authID = id
}
}
}
assert.NotEmpty(c, authID, "new auth-id not found in headscale logs")
}, 10*time.Second, 500*time.Millisecond, "waiting for new SSH check auth-id")
return authID
}
func TestSSHOneUserToOneCheckModeCLI(t *testing.T) {
IntegrationSkip(t)
scenario := sshScenario(t, sshCheckPolicy(), 1)
// defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 can SSH (via check) to all peers
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
// Start SSH — will block waiting for check auth
sshResult := doSSHCheck(t, client, peer)
// Find the auth-id from headscale logs
authID := findSSHCheckAuthID(t, headscale)
// Approve via CLI
_, err := headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", authID,
},
)
require.NoError(t, err)
// Wait for SSH to complete
select {
case result := <-sshResult:
require.NoError(t, result.err)
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("SSH did not complete after auth approval")
}
}
}
// user2 cannot SSH — not in the check policy group
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
OIDCSkipUserCreation: true,
OIDCUsers: []mockoidc.MockUser{
// First 2: consumed during node registration
oidcMockUser("user1", true),
oidcMockUser("user2", true),
// Extra: consumed during SSH check auth flows.
// Each SSH check pops one user from the queue.
oidcMockUser("user1", true),
},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
// defer scenario.ShutdownAssertNoPanics(t)
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(),
"HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(),
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
}
err = scenario.CreateHeadscaleEnvWithLoginURL(
[]tsic.Option{
tsic.WithSSH(),
tsic.WithNetfilter("off"),
tsic.WithPackages("openssh"),
tsic.WithExtraCommands("adduser ssh-it-user"),
tsic.WithDockerWorkdir("/"),
},
hsic.WithACLPolicy(sshCheckPolicy()),
hsic.WithTestName("sshcheckoidc"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithFileInContainer(
"/tmp/hs_client_oidc_secret",
[]byte(scenario.mockOIDC.ClientSecret()),
),
)
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 can SSH (via check) to all peers
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
// Start SSH — will block waiting for check auth
sshResult := doSSHCheck(t, client, peer)
// Find the auth-id from headscale logs
authID := findSSHCheckAuthID(t, headscale)
// Build auth URL and visit it to trigger OIDC flow.
// The mock OIDC server auto-authenticates from the user queue.
authURL := headscale.GetEndpoint() + "/auth/" + authID
parsedURL, err := url.Parse(authURL)
require.NoError(t, err)
_, err = doLoginURL("ssh-check-oidc", parsedURL)
require.NoError(t, err)
// Wait for SSH to complete
select {
case result := <-sshResult:
require.NoError(t, result.err)
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("SSH did not complete after OIDC auth")
}
}
}
// user2 cannot SSH — not in the check policy group
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
// TestSSHCheckModeUnapprovedTimeout verifies that SSH in check mode is rejected
// when nobody approves the auth request and the registration cache entry expires.
func TestSSHCheckModeUnapprovedTimeout(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{
tsic.WithSSH(),
tsic.WithNetfilter("off"),
tsic.WithPackages("openssh"),
tsic.WithExtraCommands("adduser ssh-it-user"),
tsic.WithDockerWorkdir("/"),
},
hsic.WithACLPolicy(sshCheckPolicy()),
hsic.WithTestName("sshchecktimeout"),
hsic.WithConfigEnv(map[string]string{
"HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "15s",
"HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "5s",
}),
)
require.NoError(t, err)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
user2Clients, err := scenario.ListTailscaleClients("user2")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// user1 attempts SSH — enters check flow, but nobody approves
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
// Confirm the check flow was entered
_ = findSSHCheckAuthID(t, headscale)
// Do NOT approve — wait for cache expiry and SSH rejection
select {
case result := <-sshResult:
require.Error(t, result.err, "SSH should be rejected when unapproved")
assert.Empty(t, result.stdout, "no command output expected on rejection")
case <-time.After(60 * time.Second):
t.Fatal("SSH did not complete after cache expiry timeout")
}
}
}
// user2 still gets immediate Permission Denied
for _, client := range user2Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
assertSSHPermissionDenied(t, client, peer)
}
}
}
// TestSSHCheckModeCheckPeriodCLI verifies that after approval with a short
// checkPeriod, the session expires and the next SSH connection requires
// re-authentication via a new check flow.
func TestSSHCheckModeCheckPeriodCLI(t *testing.T) {
IntegrationSkip(t)
// 1 minute is the documented minimum checkPeriod
scenario := sshScenario(t, sshCheckPolicyWithPeriod(time.Minute), 1)
defer scenario.ShutdownAssertNoPanics(t)
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
user1Clients, err := scenario.ListTailscaleClients("user1")
requireNoErrListClients(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// === Phase 1: First SSH check — approve, verify success ===
for _, client := range user1Clients {
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
}
sshResult := doSSHCheck(t, client, peer)
firstAuthID := findSSHCheckAuthID(t, headscale)
_, err := headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", firstAuthID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult:
require.NoError(t, result.err, "first SSH should succeed after approval")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("first SSH did not complete after auth approval")
}
// === Phase 2: Wait for checkPeriod to expire ===
//nolint:forbidigo // Intentional sleep: waiting for the check period session
// to expire. This is a time-based expiry, not a pollable condition — the
// Tailscale client caches the approval for SessionDuration and only
// re-triggers the check flow after it elapses.
time.Sleep(70 * time.Second)
// === Phase 3: Second SSH — must re-authenticate ===
sshResult2 := doSSHCheck(t, client, peer)
secondAuthID := findNewSSHCheckAuthID(t, headscale, firstAuthID)
require.NotEqual(
t,
firstAuthID,
secondAuthID,
"second SSH should trigger a new auth flow after checkPeriod expiry",
)
_, err = headscale.Execute(
[]string{
"headscale", "auth", "approve",
"--auth-id", secondAuthID,
},
)
require.NoError(t, err)
select {
case result := <-sshResult2:
require.NoError(t, result.err, "second SSH should succeed after re-approval")
require.Contains(
t,
peer.ContainerID(),
strings.ReplaceAll(result.stdout, "\n", ""),
)
case <-time.After(30 * time.Second):
t.Fatal("second SSH did not complete after re-auth approval")
}
}
}
}

View File

@ -3122,7 +3122,7 @@ func TestTagsAuthKeyWithoutUserRejectsAdvertisedTags(t *testing.T) {
// TestTagsAuthKeyConvertToUserViaCLIRegister reproduces the panic from
// issue #3038: register a node with a tags-only preauthkey (no user), then
// convert it to a user-owned node via "headscale nodes register --user <user> --key ...".
// convert it to a user-owned node via "headscale auth register --auth-id <id> --user <user>".
// The crash happens in the mapper's generateUserProfiles when node.User is nil
// after the tag→user conversion in processReauthTags.
//

View File

@ -0,0 +1,20 @@
syntax = "proto3";
package headscale.v1;
option go_package = "github.com/juanfont/headscale/gen/go/v1";
import "headscale/v1/node.proto";
message AuthRegisterRequest {
string user = 1;
string auth_id = 2;
}
message AuthRegisterResponse {
Node node = 1;
}
message AuthApproveRequest {
string auth_id = 1;
}
message AuthApproveResponse {}

View File

@ -8,6 +8,7 @@ import "headscale/v1/user.proto";
import "headscale/v1/preauthkey.proto";
import "headscale/v1/node.proto";
import "headscale/v1/apikey.proto";
import "headscale/v1/auth.proto";
import "headscale/v1/policy.proto";
service HeadscaleService {
@ -139,6 +140,22 @@ service HeadscaleService {
// --- Node end ---
// --- Auth start ---
rpc AuthRegister(AuthRegisterRequest) returns (AuthRegisterResponse) {
option (google.api.http) = {
post : "/api/v1/auth/register"
body : "*"
};
}
rpc AuthApprove(AuthApproveRequest) returns (AuthApproveResponse) {
option (google.api.http) = {
post : "/api/v1/auth/approve"
body : "*"
};
}
// --- Auth end ---
// --- ApiKeys start ---
rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse) {
option (google.api.http) = {