diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1dfd10ee..f836734d 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -253,6 +253,10 @@ jobs: - TestSSHIsBlockedInACL - TestSSHUserOnlyIsolation - TestSSHAutogroupSelf + - TestSSHOneUserToOneCheckModeCLI + - TestSSHOneUserToOneCheckModeOIDC + - TestSSHCheckModeUnapprovedTimeout + - TestSSHCheckModeCheckPeriodCLI - TestTagsAuthKeyWithTagRequestDifferentTag - TestTagsAuthKeyWithTagNoAdvertiseFlag - TestTagsAuthKeyWithTagCannotAddViaCLI diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b05f2566..f0242a4e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: # golangci-lint for Go code quality - id: golangci-lint name: golangci-lint - entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix + entry: nix develop --command -- golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix language: system types: [go] pass_filenames: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e01d43e..203e7292 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,19 @@ to understand how the packet filter should be generated. We discovered a few dif overall our implementation was very close. [#3036](https://github.com/juanfont/headscale/pull/3036) +### SSH check action + +SSH rules with `"action": "check"` are now supported. When a client initiates an SSH connection to a node +with a `check` action policy, the user is prompted to authenticate via OIDC or CLI approval before access +is granted. + +A new `headscale auth` CLI command group supports the approval flow: + +- `headscale auth approve --auth-id ` approves a pending authentication request (SSH check or web auth) +- `headscale auth register --auth-id --user ` registers a node (replaces deprecated `headscale nodes register`) + +[#1850](https://github.com/juanfont/headscale/pull/1850) + ### BREAKING - **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036) @@ -26,6 +39,8 @@ overall our implementation was very close. - **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036) - Previously, `proto:icmp` included both ICMPv4 and ICMPv6 - Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6 +- **CLI**: `headscale nodes register` is deprecated in favour of `headscale auth register --auth-id --user ` [#1850](https://github.com/juanfont/headscale/pull/1850) + - The old command continues to work but will be removed in a future release ### Changes @@ -35,6 +50,11 @@ overall our implementation was very close. - **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036) - Remove deprecated `--namespace` flag from `nodes list`, `nodes register`, and `debug create-node` commands (use `--user` instead) [#3093](https://github.com/juanfont/headscale/pull/3093) - Remove deprecated `namespace`/`ns` command aliases for `users` and `machine`/`machines` aliases for `nodes` [#3093](https://github.com/juanfont/headscale/pull/3093) +- Add SSH `check` action support with OIDC and CLI-based approval flows [#1850](https://github.com/juanfont/headscale/pull/1850) +- Add `headscale auth register` and `headscale auth approve` CLI commands [#1850](https://github.com/juanfont/headscale/pull/1850) +- Deprecate `headscale nodes register --key` in favour of `headscale auth register --auth-id` [#1850](https://github.com/juanfont/headscale/pull/1850) +- Generalise auth templates into reusable `AuthSuccess` and `AuthWeb` components [#1850](https://github.com/juanfont/headscale/pull/1850) +- Unify auth pipeline with `AuthVerdict` type, supporting registration, reauthentication, and SSH checks [#1850](https://github.com/juanfont/headscale/pull/1850) ## 0.28.0 (2026-02-04) diff --git a/cmd/headscale/cli/auth.go b/cmd/headscale/cli/auth.go new file mode 100644 index 00000000..cc854805 --- /dev/null +++ b/cmd/headscale/cli/auth.go @@ -0,0 +1,70 @@ +package cli + +import ( + "context" + "fmt" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(authCmd) + + authRegisterCmd.Flags().StringP("user", "u", "", "User") + authRegisterCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRegisterCmd, "user", "auth-id") + authCmd.AddCommand(authRegisterCmd) + + authApproveCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authApproveCmd, "auth-id") + authCmd.AddCommand(authApproveCmd) +} + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Manage node authentication and approval", +} + +var authRegisterCmd = &cobra.Command{ + Use: "register", + Short: "Register a node to your network", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + user, _ := cmd.Flags().GetString("user") + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRegisterRequest{ + AuthId: authID, + User: user, + } + + response, err := client.AuthRegister(ctx, request) + if err != nil { + return fmt.Errorf("registering node: %w", err) + } + + return printOutput( + cmd, + response.GetNode(), + fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName())) + }), +} + +var authApproveCmd = &cobra.Command{ + Use: "approve", + Short: "Approve a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthApproveRequest{ + AuthId: authID, + } + + response, err := client.AuthApprove(ctx, request) + if err != nil { + return fmt.Errorf("approving auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request approved") + }), +} diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index fac317fc..9e4a67fd 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{ name, _ := cmd.Flags().GetString("name") registrationID, _ := cmd.Flags().GetString("key") - _, err := types.RegistrationIDFromString(registrationID) + _, err := types.AuthIDFromString(registrationID) if err != nil { return fmt.Errorf("parsing machine key: %w", err) } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index dbc7e8bf..fa71034f 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -63,8 +63,9 @@ var nodeCmd = &cobra.Command{ } var registerNodeCmd = &cobra.Command{ - Use: "register", - Short: "Registers a node to your network", + Use: "register", + Short: "Registers a node to your network", + Deprecated: "use 'headscale auth register --auth-id --user ' instead", RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { user, _ := cmd.Flags().GetString("user") registrationID, _ := cmd.Flags().GetString("key") diff --git a/flake.nix b/flake.nix index 210b888e..ae02d0ff 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ let pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; buildGo = pkgs.buildGo126Module; - vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0="; + vendorHash = "sha256-oUN53ELb3+xn4yA7lEfXyT2c7NxbQC6RtbkGVq6+RLU="; in { headscale = buildGo { @@ -135,11 +135,6 @@ }; }; - # The package uses buildGo125Module, not the convention. - # goreleaser = prev.goreleaser.override { - # buildGoModule = buildGo; - # }; - gotestsum = prev.gotestsum.override { buildGoModule = buildGo; }; @@ -152,9 +147,9 @@ buildGoModule = buildGo; }; - # gopls = prev.gopls.override { - # buildGoModule = buildGo; - # }; + gopls = prev.gopls.override { + buildGoLatestModule = buildGo; + }; }; } // flake-utils.lib.eachDefaultSystem diff --git a/gen/go/headscale/v1/auth.pb.go b/gen/go/headscale/v1/auth.pb.go new file mode 100644 index 00000000..c4017b10 --- /dev/null +++ b/gen/go/headscale/v1/auth.pb.go @@ -0,0 +1,266 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: headscale/v1/auth.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type AuthRegisterRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + AuthId string `protobuf:"bytes,2,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthRegisterRequest) Reset() { + *x = AuthRegisterRequest{} + mi := &file_headscale_v1_auth_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthRegisterRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthRegisterRequest) ProtoMessage() {} + +func (x *AuthRegisterRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthRegisterRequest.ProtoReflect.Descriptor instead. +func (*AuthRegisterRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{0} +} + +func (x *AuthRegisterRequest) GetUser() string { + if x != nil { + return x.User + } + return "" +} + +func (x *AuthRegisterRequest) GetAuthId() string { + if x != nil { + return x.AuthId + } + return "" +} + +type AuthRegisterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthRegisterResponse) Reset() { + *x = AuthRegisterResponse{} + mi := &file_headscale_v1_auth_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthRegisterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthRegisterResponse) ProtoMessage() {} + +func (x *AuthRegisterResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthRegisterResponse.ProtoReflect.Descriptor instead. +func (*AuthRegisterResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{1} +} + +func (x *AuthRegisterResponse) GetNode() *Node { + if x != nil { + return x.Node + } + return nil +} + +type AuthApproveRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + AuthId string `protobuf:"bytes,1,opt,name=auth_id,json=authId,proto3" json:"auth_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthApproveRequest) Reset() { + *x = AuthApproveRequest{} + mi := &file_headscale_v1_auth_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthApproveRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthApproveRequest) ProtoMessage() {} + +func (x *AuthApproveRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthApproveRequest.ProtoReflect.Descriptor instead. +func (*AuthApproveRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{2} +} + +func (x *AuthApproveRequest) GetAuthId() string { + if x != nil { + return x.AuthId + } + return "" +} + +type AuthApproveResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthApproveResponse) Reset() { + *x = AuthApproveResponse{} + mi := &file_headscale_v1_auth_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthApproveResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthApproveResponse) ProtoMessage() {} + +func (x *AuthApproveResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_auth_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthApproveResponse.ProtoReflect.Descriptor instead. +func (*AuthApproveResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_auth_proto_rawDescGZIP(), []int{3} +} + +var File_headscale_v1_auth_proto protoreflect.FileDescriptor + +const file_headscale_v1_auth_proto_rawDesc = "" + + "\n" + + "\x17headscale/v1/auth.proto\x12\fheadscale.v1\x1a\x17headscale/v1/node.proto\"B\n" + + "\x13AuthRegisterRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\tR\x04user\x12\x17\n" + + "\aauth_id\x18\x02 \x01(\tR\x06authId\">\n" + + "\x14AuthRegisterResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"-\n" + + "\x12AuthApproveRequest\x12\x17\n" + + "\aauth_id\x18\x01 \x01(\tR\x06authId\"\x15\n" + + "\x13AuthApproveResponseB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" + +var ( + file_headscale_v1_auth_proto_rawDescOnce sync.Once + file_headscale_v1_auth_proto_rawDescData []byte +) + +func file_headscale_v1_auth_proto_rawDescGZIP() []byte { + file_headscale_v1_auth_proto_rawDescOnce.Do(func() { + file_headscale_v1_auth_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc))) + }) + return file_headscale_v1_auth_proto_rawDescData +} + +var file_headscale_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_headscale_v1_auth_proto_goTypes = []any{ + (*AuthRegisterRequest)(nil), // 0: headscale.v1.AuthRegisterRequest + (*AuthRegisterResponse)(nil), // 1: headscale.v1.AuthRegisterResponse + (*AuthApproveRequest)(nil), // 2: headscale.v1.AuthApproveRequest + (*AuthApproveResponse)(nil), // 3: headscale.v1.AuthApproveResponse + (*Node)(nil), // 4: headscale.v1.Node +} +var file_headscale_v1_auth_proto_depIdxs = []int32{ + 4, // 0: headscale.v1.AuthRegisterResponse.node:type_name -> headscale.v1.Node + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_headscale_v1_auth_proto_init() } +func file_headscale_v1_auth_proto_init() { + if File_headscale_v1_auth_proto != nil { + return + } + file_headscale_v1_node_proto_init() + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_auth_proto_rawDesc), len(file_headscale_v1_auth_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_headscale_v1_auth_proto_goTypes, + DependencyIndexes: file_headscale_v1_auth_proto_depIdxs, + MessageInfos: file_headscale_v1_auth_proto_msgTypes, + }.Build() + File_headscale_v1_auth_proto = out.File + file_headscale_v1_auth_proto_goTypes = nil + file_headscale_v1_auth_proto_depIdxs = nil +} diff --git a/gen/go/headscale/v1/headscale.pb.go b/gen/go/headscale/v1/headscale.pb.go index 3d16778c..f52ca7e0 100644 --- a/gen/go/headscale/v1/headscale.pb.go +++ b/gen/go/headscale/v1/headscale.pb.go @@ -106,10 +106,10 @@ var File_headscale_v1_headscale_proto protoreflect.FileDescriptor const file_headscale_v1_headscale_proto_rawDesc = "" + "\n" + - "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + + "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x17headscale/v1/auth.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + "\rHealthRequest\"E\n" + "\x0eHealthResponse\x123\n" + - "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" + + "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\xfa\x18\n" + "\x10HeadscaleService\x12h\n" + "\n" + "CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" + @@ -134,7 +134,9 @@ const file_headscale_v1_headscale_proto_rawDesc = "" + "\n" + "RenameNode\x12\x1f.headscale.v1.RenameNodeRequest\x1a .headscale.v1.RenameNodeResponse\"0\x82\xd3\xe4\x93\x02*\"(/api/v1/node/{node_id}/rename/{new_name}\x12b\n" + "\tListNodes\x12\x1e.headscale.v1.ListNodesRequest\x1a\x1f.headscale.v1.ListNodesResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/node\x12\x80\x01\n" + - "\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12p\n" + + "\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12w\n" + + "\fAuthRegister\x12!.headscale.v1.AuthRegisterRequest\x1a\".headscale.v1.AuthRegisterResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/auth/register\x12s\n" + + "\vAuthApprove\x12 .headscale.v1.AuthApproveRequest\x1a!.headscale.v1.AuthApproveResponse\"\x1f\x82\xd3\xe4\x93\x02\x19:\x01*\"\x14/api/v1/auth/approve\x12p\n" + "\fCreateApiKey\x12!.headscale.v1.CreateApiKeyRequest\x1a\".headscale.v1.CreateApiKeyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\"\x0e/api/v1/apikey\x12w\n" + "\fExpireApiKey\x12!.headscale.v1.ExpireApiKeyRequest\x1a\".headscale.v1.ExpireApiKeyResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/apikey/expire\x12j\n" + "\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" + @@ -177,36 +179,40 @@ var file_headscale_v1_headscale_proto_goTypes = []any{ (*RenameNodeRequest)(nil), // 17: headscale.v1.RenameNodeRequest (*ListNodesRequest)(nil), // 18: headscale.v1.ListNodesRequest (*BackfillNodeIPsRequest)(nil), // 19: headscale.v1.BackfillNodeIPsRequest - (*CreateApiKeyRequest)(nil), // 20: headscale.v1.CreateApiKeyRequest - (*ExpireApiKeyRequest)(nil), // 21: headscale.v1.ExpireApiKeyRequest - (*ListApiKeysRequest)(nil), // 22: headscale.v1.ListApiKeysRequest - (*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest - (*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest - (*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest - (*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse - (*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse - (*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse - (*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse - (*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse - (*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse - (*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse - (*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse - (*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse - (*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse - (*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse - (*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse - (*RegisterNodeResponse)(nil), // 38: headscale.v1.RegisterNodeResponse - (*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse - (*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse - (*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse - (*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse - (*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse - (*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse - (*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse - (*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse - (*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse - (*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse - (*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse + (*AuthRegisterRequest)(nil), // 20: headscale.v1.AuthRegisterRequest + (*AuthApproveRequest)(nil), // 21: headscale.v1.AuthApproveRequest + (*CreateApiKeyRequest)(nil), // 22: headscale.v1.CreateApiKeyRequest + (*ExpireApiKeyRequest)(nil), // 23: headscale.v1.ExpireApiKeyRequest + (*ListApiKeysRequest)(nil), // 24: headscale.v1.ListApiKeysRequest + (*DeleteApiKeyRequest)(nil), // 25: headscale.v1.DeleteApiKeyRequest + (*GetPolicyRequest)(nil), // 26: headscale.v1.GetPolicyRequest + (*SetPolicyRequest)(nil), // 27: headscale.v1.SetPolicyRequest + (*CreateUserResponse)(nil), // 28: headscale.v1.CreateUserResponse + (*RenameUserResponse)(nil), // 29: headscale.v1.RenameUserResponse + (*DeleteUserResponse)(nil), // 30: headscale.v1.DeleteUserResponse + (*ListUsersResponse)(nil), // 31: headscale.v1.ListUsersResponse + (*CreatePreAuthKeyResponse)(nil), // 32: headscale.v1.CreatePreAuthKeyResponse + (*ExpirePreAuthKeyResponse)(nil), // 33: headscale.v1.ExpirePreAuthKeyResponse + (*DeletePreAuthKeyResponse)(nil), // 34: headscale.v1.DeletePreAuthKeyResponse + (*ListPreAuthKeysResponse)(nil), // 35: headscale.v1.ListPreAuthKeysResponse + (*DebugCreateNodeResponse)(nil), // 36: headscale.v1.DebugCreateNodeResponse + (*GetNodeResponse)(nil), // 37: headscale.v1.GetNodeResponse + (*SetTagsResponse)(nil), // 38: headscale.v1.SetTagsResponse + (*SetApprovedRoutesResponse)(nil), // 39: headscale.v1.SetApprovedRoutesResponse + (*RegisterNodeResponse)(nil), // 40: headscale.v1.RegisterNodeResponse + (*DeleteNodeResponse)(nil), // 41: headscale.v1.DeleteNodeResponse + (*ExpireNodeResponse)(nil), // 42: headscale.v1.ExpireNodeResponse + (*RenameNodeResponse)(nil), // 43: headscale.v1.RenameNodeResponse + (*ListNodesResponse)(nil), // 44: headscale.v1.ListNodesResponse + (*BackfillNodeIPsResponse)(nil), // 45: headscale.v1.BackfillNodeIPsResponse + (*AuthRegisterResponse)(nil), // 46: headscale.v1.AuthRegisterResponse + (*AuthApproveResponse)(nil), // 47: headscale.v1.AuthApproveResponse + (*CreateApiKeyResponse)(nil), // 48: headscale.v1.CreateApiKeyResponse + (*ExpireApiKeyResponse)(nil), // 49: headscale.v1.ExpireApiKeyResponse + (*ListApiKeysResponse)(nil), // 50: headscale.v1.ListApiKeysResponse + (*DeleteApiKeyResponse)(nil), // 51: headscale.v1.DeleteApiKeyResponse + (*GetPolicyResponse)(nil), // 52: headscale.v1.GetPolicyResponse + (*SetPolicyResponse)(nil), // 53: headscale.v1.SetPolicyResponse } var file_headscale_v1_headscale_proto_depIdxs = []int32{ 2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest @@ -227,40 +233,44 @@ var file_headscale_v1_headscale_proto_depIdxs = []int32{ 17, // 15: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest 18, // 16: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest 19, // 17: headscale.v1.HeadscaleService.BackfillNodeIPs:input_type -> headscale.v1.BackfillNodeIPsRequest - 20, // 18: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest - 21, // 19: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest - 22, // 20: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest - 23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest - 24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest - 25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest - 0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest - 26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse - 27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse - 28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse - 29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse - 30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse - 31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse - 32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse - 33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse - 34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse - 35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse - 36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse - 37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse - 38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse - 39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse - 40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse - 41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse - 42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse - 43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse - 44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse - 45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse - 46, // 45: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse - 47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse - 48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse - 49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse - 1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse - 25, // [25:50] is the sub-list for method output_type - 0, // [0:25] is the sub-list for method input_type + 20, // 18: headscale.v1.HeadscaleService.AuthRegister:input_type -> headscale.v1.AuthRegisterRequest + 21, // 19: headscale.v1.HeadscaleService.AuthApprove:input_type -> headscale.v1.AuthApproveRequest + 22, // 20: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest + 23, // 21: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest + 24, // 22: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest + 25, // 23: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest + 26, // 24: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest + 27, // 25: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest + 0, // 26: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest + 28, // 27: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse + 29, // 28: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse + 30, // 29: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse + 31, // 30: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse + 32, // 31: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse + 33, // 32: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse + 34, // 33: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse + 35, // 34: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse + 36, // 35: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse + 37, // 36: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse + 38, // 37: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse + 39, // 38: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse + 40, // 39: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse + 41, // 40: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse + 42, // 41: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse + 43, // 42: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse + 44, // 43: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse + 45, // 44: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse + 46, // 45: headscale.v1.HeadscaleService.AuthRegister:output_type -> headscale.v1.AuthRegisterResponse + 47, // 46: headscale.v1.HeadscaleService.AuthApprove:output_type -> headscale.v1.AuthApproveResponse + 48, // 47: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse + 49, // 48: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse + 50, // 49: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse + 51, // 50: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse + 52, // 51: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse + 53, // 52: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse + 1, // 53: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse + 27, // [27:54] is the sub-list for method output_type + 0, // [0:27] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -275,6 +285,7 @@ func file_headscale_v1_headscale_proto_init() { file_headscale_v1_preauthkey_proto_init() file_headscale_v1_node_proto_init() file_headscale_v1_apikey_proto_init() + file_headscale_v1_auth_proto_init() file_headscale_v1_policy_proto_init() type x struct{} out := protoimpl.TypeBuilder{ diff --git a/gen/go/headscale/v1/headscale.pb.gw.go b/gen/go/headscale/v1/headscale.pb.gw.go index ab851614..1f769ed9 100644 --- a/gen/go/headscale/v1/headscale.pb.gw.go +++ b/gen/go/headscale/v1/headscale.pb.gw.go @@ -709,6 +709,60 @@ func local_request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marsh return msg, metadata, err } +func request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthRegisterRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.AuthRegister(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_AuthRegister_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthRegisterRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.AuthRegister(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthApproveRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.AuthApprove(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_AuthApprove_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq AuthApproveRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.AuthApprove(ctx, &protoReq) + return msg, metadata, err +} + func request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( protoReq CreateApiKeyRequest @@ -1272,6 +1326,46 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1758,6 +1852,40 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthRegister_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthRegister", runtime.WithHTTPPathPattern("/api/v1/auth/register")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_AuthRegister_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthRegister_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_AuthApprove_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/AuthApprove", runtime.WithHTTPPathPattern("/api/v1/auth/approve")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_AuthApprove_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_AuthApprove_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1899,6 +2027,8 @@ var ( pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, "")) pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, "")) pattern_HeadscaleService_BackfillNodeIPs_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "backfillips"}, "")) + pattern_HeadscaleService_AuthRegister_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "register"}, "")) + pattern_HeadscaleService_AuthApprove_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "approve"}, "")) pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, "")) pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) @@ -1927,6 +2057,8 @@ var ( forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage forward_HeadscaleService_BackfillNodeIPs_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_AuthRegister_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_AuthApprove_0 = runtime.ForwardResponseMessage forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage diff --git a/gen/go/headscale/v1/headscale_grpc.pb.go b/gen/go/headscale/v1/headscale_grpc.pb.go index a3963935..e763d9af 100644 --- a/gen/go/headscale/v1/headscale_grpc.pb.go +++ b/gen/go/headscale/v1/headscale_grpc.pb.go @@ -37,6 +37,8 @@ const ( HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode" HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes" HeadscaleService_BackfillNodeIPs_FullMethodName = "/headscale.v1.HeadscaleService/BackfillNodeIPs" + HeadscaleService_AuthRegister_FullMethodName = "/headscale.v1.HeadscaleService/AuthRegister" + HeadscaleService_AuthApprove_FullMethodName = "/headscale.v1.HeadscaleService/AuthApprove" HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey" HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey" HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys" @@ -71,6 +73,9 @@ type HeadscaleServiceClient interface { RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error) ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error) BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error) + // --- Auth start --- + AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) + AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) // --- ApiKeys start --- CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error) @@ -271,6 +276,26 @@ func (c *headscaleServiceClient) BackfillNodeIPs(ctx context.Context, in *Backfi return out, nil } +func (c *headscaleServiceClient) AuthRegister(ctx context.Context, in *AuthRegisterRequest, opts ...grpc.CallOption) (*AuthRegisterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthRegisterResponse) + err := c.cc.Invoke(ctx, HeadscaleService_AuthRegister_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) AuthApprove(ctx context.Context, in *AuthApproveRequest, opts ...grpc.CallOption) (*AuthApproveResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthApproveResponse) + err := c.cc.Invoke(ctx, HeadscaleService_AuthApprove_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CreateApiKeyResponse) @@ -366,6 +391,9 @@ type HeadscaleServiceServer interface { RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) + // --- Auth start --- + AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) + AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) // --- ApiKeys start --- CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error) @@ -440,6 +468,12 @@ func (UnimplementedHeadscaleServiceServer) ListNodes(context.Context, *ListNodes func (UnimplementedHeadscaleServiceServer) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) { return nil, status.Error(codes.Unimplemented, "method BackfillNodeIPs not implemented") } +func (UnimplementedHeadscaleServiceServer) AuthRegister(context.Context, *AuthRegisterRequest) (*AuthRegisterResponse, error) { + return nil, status.Error(codes.Unimplemented, "method AuthRegister not implemented") +} +func (UnimplementedHeadscaleServiceServer) AuthApprove(context.Context, *AuthApproveRequest) (*AuthApproveResponse, error) { + return nil, status.Error(codes.Unimplemented, "method AuthApprove not implemented") +} func (UnimplementedHeadscaleServiceServer) CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) { return nil, status.Error(codes.Unimplemented, "method CreateApiKey not implemented") } @@ -806,6 +840,42 @@ func _HeadscaleService_BackfillNodeIPs_Handler(srv interface{}, ctx context.Cont return interceptor(ctx, in, info, handler) } +func _HeadscaleService_AuthRegister_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthRegisterRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).AuthRegister(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_AuthRegister_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).AuthRegister(ctx, req.(*AuthRegisterRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _HeadscaleService_AuthApprove_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthApproveRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).AuthApprove(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_AuthApprove_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).AuthApprove(ctx, req.(*AuthApproveRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _HeadscaleService_CreateApiKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateApiKeyRequest) if err := dec(in); err != nil { @@ -1011,6 +1081,14 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "BackfillNodeIPs", Handler: _HeadscaleService_BackfillNodeIPs_Handler, }, + { + MethodName: "AuthRegister", + Handler: _HeadscaleService_AuthRegister_Handler, + }, + { + MethodName: "AuthApprove", + Handler: _HeadscaleService_AuthApprove_Handler, + }, { MethodName: "CreateApiKey", Handler: _HeadscaleService_CreateApiKey_Handler, diff --git a/gen/openapiv2/headscale/v1/auth.swagger.json b/gen/openapiv2/headscale/v1/auth.swagger.json new file mode 100644 index 00000000..2e99e1a7 --- /dev/null +++ b/gen/openapiv2/headscale/v1/auth.swagger.json @@ -0,0 +1,44 @@ +{ + "swagger": "2.0", + "info": { + "title": "headscale/v1/auth.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": {}, + "definitions": { + "protobufAny": { + "type": "object", + "properties": { + "@type": { + "type": "string" + } + }, + "additionalProperties": {} + }, + "rpcStatus": { + "type": "object", + "properties": { + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index 1db1db94..533bd73d 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -138,6 +138,71 @@ ] } }, + "/api/v1/auth/approve": { + "post": { + "operationId": "HeadscaleService_AuthApprove", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1AuthApproveResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/v1AuthApproveRequest" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, + "/api/v1/auth/register": { + "post": { + "summary": "--- Auth start ---", + "operationId": "HeadscaleService_AuthRegister", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1AuthRegisterResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/v1AuthRegisterRequest" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/debug/node": { "post": { "summary": "--- Node start ---", @@ -888,6 +953,36 @@ } } }, + "v1AuthApproveRequest": { + "type": "object", + "properties": { + "authId": { + "type": "string" + } + } + }, + "v1AuthApproveResponse": { + "type": "object" + }, + "v1AuthRegisterRequest": { + "type": "object", + "properties": { + "user": { + "type": "string" + }, + "authId": { + "type": "string" + } + } + }, + "v1AuthRegisterResponse": { + "type": "object", + "properties": { + "node": { + "$ref": "#/definitions/v1Node" + } + } + }, "v1BackfillNodeIPsResponse": { "type": "object", "properties": { diff --git a/go.mod b/go.mod index c99d4ddd..3adc7e48 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,8 @@ require ( github.com/docker/docker v28.5.2+incompatible github.com/fsnotify/fsnotify v1.9.0 github.com/glebarez/sqlite v1.11.0 + github.com/go-chi/chi/v5 v5.2.5 + github.com/go-chi/metrics v0.1.1 github.com/go-gormigrate/gormigrate/v2 v2.1.5 github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e github.com/gofrs/uuid/v5 v5.4.0 diff --git a/go.sum b/go.sum index e9c39e36..4c5f48ac 100644 --- a/go.sum +++ b/go.sum @@ -181,6 +181,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-chi/metrics v0.1.1 h1:CXhbnkAVVjb0k73EBRQ6Z2YdWFnbXZgNtg1Mboguibk= +github.com/go-chi/metrics v0.1.1/go.mod h1:mcGTM1pPalP7WCtb+akNYFO/lwNwBBLCuedepqjoPn4= github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= diff --git a/hscontrol/app.go b/hscontrol/app.go index abd29a45..87b37510 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -20,7 +20,9 @@ import ( "github.com/cenkalti/backoff/v5" "github.com/davecgh/go-spew/spew" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -457,50 +459,58 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { return os.Remove(h.cfg.UnixSocket) } -func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { - router := mux.NewRouter() - router.Use(prometheusMiddleware) +func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) - router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler). - Methods(http.MethodPost, http.MethodGet) + r.Post(ts2021UpgradePath, h.NoiseUpgradeHandler) - router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet) - router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) - router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet) - router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler). - Methods(http.MethodGet) + r.Get("/robots.txt", h.RobotsHandler) + r.Get("/health", h.HealthHandler) + r.Get("/version", h.VersionHandler) + r.Get("/key", h.KeyHandler) + r.Get("/register/{auth_id}", h.authProvider.RegisterHandler) + r.Get("/auth/{auth_id}", h.authProvider.AuthHandler) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { - router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) + r.Get("/oidc/callback", provider.OIDCCallbackHandler) } - router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). - Methods(http.MethodGet) - router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) + r.Get("/apple", h.AppleConfigMessage) + r.Get("/apple/{platform}", h.ApplePlatformConfig) + r.Get("/windows", h.WindowsConfigMessage) // TODO(kristoffer): move swagger into a package - router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet) - router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). - Methods(http.MethodGet) + r.Get("/swagger", headscale.SwaggerUI) + r.Get("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1) - router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost) + r.Post("/verify", h.VerifyHandler) if h.cfg.DERP.ServerEnabled { - router.HandleFunc("/derp", h.DERPServer.DERPHandler) - router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) - router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) - router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) + r.HandleFunc("/derp", h.DERPServer.DERPHandler) + r.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) + r.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) + r.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) } - apiRouter := router.PathPrefix("/api").Subrouter() - apiRouter.Use(h.httpAuthenticationMiddleware) - apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP) - router.HandleFunc("/favicon.ico", FaviconHandler) - router.PathPrefix("/").HandlerFunc(BlankHandler) + r.Route("/api", func(r chi.Router) { + r.Use(h.httpAuthenticationMiddleware) + r.HandleFunc("/v1/*", grpcMux.ServeHTTP) + }) + r.Get("/favicon.ico", FaviconHandler) + r.Get("/", BlankHandler) - return router + return r } // Serve launches the HTTP and gRPC server service Headscale and the API. diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fdc63461..ee301242 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -20,7 +20,9 @@ import ( type AuthProvider interface { RegisterHandler(w http.ResponseWriter, r *http.Request) - AuthURL(regID types.RegistrationID) string + AuthHandler(w http.ResponseWriter, r *http.Request) + RegisterURL(authID types.AuthID) string + AuthURL(authID types.AuthID) string } func (h *Headscale) handleRegister( @@ -261,22 +263,24 @@ func (h *Headscale) waitForFollowup( return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) } - followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) if err != nil { return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) } - if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { + if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok { select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.Registered: - if node == nil { - // registration is expired in the cache, instruct the client to try a new registration - return h.reqToNewRegisterResponse(req, machineKey) - } + case verdict := <-reg.WaitForAuth(): + if verdict.Accept() { + if !verdict.Node.Valid() { + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) + } - return nodeToRegisterResponse(node.View()), nil + return nodeToRegisterResponse(verdict.Node), nil + } } } @@ -291,14 +295,14 @@ func (h *Headscale) reqToNewRegisterResponse( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - newRegID, err := types.NewRegistrationID() + newAuthID, err := types.NewAuthID() if err != nil { return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -307,25 +311,25 @@ func (h *Headscale) reqToNewRegisterResponse( hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - log.Info().Msgf("new followup node registration using key: %s", newRegID) - h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + log.Info().Msgf("new followup node registration using key: %s", newAuthID) + h.state.SetAuthCacheEntry(newAuthID, authRegReq) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(newRegID), + AuthURL: h.authProvider.RegisterURL(newAuthID), }, nil } @@ -376,13 +380,6 @@ func (h *Headscale) handleRegisterWithAuthKey( // Send both changes. Empty changes are ignored by Change(). h.Change(changed, routesChange) - // TODO(kradalby): I think this is covered above, but we need to validate that. - // // If policy changed due to node registration, send a separate policy change - // if policyChanged { - // policyChange := change.PolicyChange() - // h.Change(policyChange) - // } - resp := &tailcfg.RegisterResponse{ MachineAuthorized: true, NodeKeyExpired: node.IsExpired(), @@ -404,14 +401,14 @@ func (h *Headscale) handleRegisterInteractive( req tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - registrationId, err := types.NewRegistrationID() + authID, err := types.NewAuthID() if err != nil { return nil, fmt.Errorf("generating registration ID: %w", err) } // Ensure we have a valid hostname hostname := util.EnsureHostname( - req.Hostinfo, + req.Hostinfo.View(), machineKey.String(), req.NodeKey.String(), ) @@ -434,28 +431,28 @@ func (h *Headscale) handleRegisterInteractive( hostinfo.Hostname = hostname - nodeToRegister := types.NewRegisterNode( - types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - }, - ) - - if !req.Expiry.IsZero() { - nodeToRegister.Node.Expiry = &req.Expiry + nodeToRegister := types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: new(time.Now()), } - h.state.SetRegistrationCacheEntry( - registrationId, - nodeToRegister, + if !req.Expiry.IsZero() { + nodeToRegister.Expiry = &req.Expiry + } + + authRegReq := types.NewRegisterAuthRequest(nodeToRegister) + + h.state.SetAuthCacheEntry( + authID, + authRegReq, ) - log.Info().Msgf("starting node registration using key: %s", registrationId) + log.Info().Msgf("starting node registration using key: %s", authID) return &tailcfg.RegisterResponse{ - AuthURL: h.authProvider.AuthURL(registrationId), + AuthURL: h.authProvider.RegisterURL(authID), }, nil } diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go index e7b74b75..7016af31 100644 --- a/hscontrol/auth_tags_test.go +++ b/hscontrol/auth_tags_test.go @@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 1: Create user-owned node WITH expiry set clientExpiry := time.Now().Add(24 * time.Hour) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "personal-to-tagged", @@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 2: Re-auth with tags (Personal → Tagged conversion) nodeKey2 := key.NewNode() - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "personal-to-tagged", @@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client still sends expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", @@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Create tagged node (expiry should be nil) - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "tagged-to-personal", @@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { RequestTags: []string{"tag:server"}, // Tagged node }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) node, _, err := app.state.HandleNodeFromAuthPath( registrationID1, types.UserID(user.ID), nil, "webauth", @@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { // Step 2: Re-auth with empty tags (Tagged → Personal conversion) nodeKey2 := key.NewNode() clientExpiry := time.Now().Add(48 * time.Hour) - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "tagged-to-personal", @@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { }, Expiry: &clientExpiry, // Client requests expiry }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) nodeAfter, _, err := app.state.HandleNodeFromAuthPath( registrationID2, types.UserID(user.ID), nil, "webauth", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index d28ed565..4c70cda4 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_success", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-success-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-success-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate successful registration - send to buffered channel - // The channel is buffered (size 1), so this can complete immediately - // and handleRegister will receive the value when it starts waiting + // Simulate successful registration + // handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - registered <- node + nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()}) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_timeout", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "followup-timeout-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) - // Don't send anything on channel - will timeout + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "followup-timeout-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) + // Don't call FinishRegistration - will timeout return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil }, @@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) { { name: "followup_registration_node_nil_response", setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } - registered := make(chan *types.Node, 1) - nodeToRegister := types.RegisterNode{ - Node: types.Node{ - Hostname: "nil-response-node", - }, - Registered: registered, - } - app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + Hostname: "nil-response-node", + }) + app.state.SetAuthCacheEntry(regID, nodeToRegister) - // Simulate registration that returns nil (cache expired during auth) - // The channel is buffered (size 1), so this can complete immediately + // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - registered <- nil // Nil indicates cache expiry + nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) { setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Generate a registration ID that doesn't exist in cache // This simulates an expired/missing cache entry - regID, err := types.NewRegistrationID() + regID, err := types.NewAuthID() if err != nil { return "", err } @@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) { // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") - newRegID, err := types.RegistrationIDFromString(newRegIDStr) + newRegID, err := types.AuthIDFromString(newRegIDStr) assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure // Verify new registration entry exists in cache - _, found := app.state.GetRegistrationCacheEntry(newRegID) + _, found := app.state.GetAuthCacheEntry(newRegID) assert.True(t, found, "new registration should exist in cache") }, }, @@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify cache entry exists - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) assert.True(t, found, "registration cache entry should exist initially") assert.NotNil(t, cacheEntry) @@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) { assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern // Cache entry should still exist after auth error (for retry scenarios) - _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) + _, stillFound := app.state.GetAuthCacheEntry(registrationID) assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") }, }, @@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) { assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") // Both cache entries should exist simultaneously - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) { require.NoError(t, err) // Verify both exist - _, found1 := app.state.GetRegistrationCacheEntry(regID1) - _, found2 := app.state.GetRegistrationCacheEntry(regID2) + _, found1 := app.state.GetAuthCacheEntry(regID1) + _, found2 := app.state.GetAuthCacheEntry(regID2) assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) { } // First registration should still be in cache (not completed) - _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + _, stillFound := app.state.GetAuthCacheEntry(regID1) assert.True(t, stillFound, "first registration should still be pending") }, }, @@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { var ( initialResp *tailcfg.RegisterResponse authURL string - registrationID types.RegistrationID + registrationID types.AuthID finalResp *tailcfg.RegisterResponse err error ) @@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if step.expectCacheEntry { // Verify registration cache entry was created - cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) require.True(t, found, "registration cache entry should exist") require.NotNil(t, cacheEntry, "cache entry should not be nil") - require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key") } case stepTypeAuthCompletion: @@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { // Check cache cleanup expectation for this step if step.expectCacheEntry == false && registrationID != "" { // Verify cache entry was cleaned up - _, found := app.state.GetRegistrationCacheEntry(registrationID) + _, found := app.state.GetAuthCacheEntry(registrationID) require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) } } @@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { } // extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. -func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { +func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" @@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err idStr := authURL[idx+len(registerPrefix):] - return types.RegistrationIDFromString(idStr) + return types.AuthIDFromString(idStr) } // validateCompleteRegistrationResponse performs comprehensive validation of a registration response. @@ -2962,7 +2948,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Scenario: // 1. Node registers with user1 via pre-auth key // 2. Node logs out (expires) -// 3. Admin runs: headscale nodes register --user user2 --key +// 3. Admin runs: headscale auth register --auth-id --user user2 // // Expected behavior: // - User1's original node should STILL EXIST (expired) @@ -3041,7 +3027,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { require.NotEmpty(t, regID, "Should have valid registration ID") // Step 4: Admin completes authentication via CLI - // This simulates: headscale nodes register --user user2 --key + // This simulates: headscale auth register --auth-id --user user2 node, _, err := app.state.HandleNodeFromAuthPath( regID, types.UserID(user2.ID), // Register to user2, not user1! @@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { nodeKey := key.NewNode() // Simulate a registration cache entry (as would be created during web auth) - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "webauth-tags-node", @@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete the web auth - should fail because tag is unauthorized _, _, err := app.state.HandleNodeFromAuthPath( @@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { nodeKey1 := key.NewNode() // Step 1: Initial registration with tags - registrationID1 := types.MustRegistrationID() - regEntry1 := types.NewRegisterNode(types.Node{ + registrationID1 := types.MustAuthID() + regEntry1 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "reauth-untag-node", @@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{"tag:valid-owned", "tag:second"}, }, }) - app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + app.state.SetAuthCacheEntry(registrationID1, regEntry1) // Complete initial registration with tags node, _, err := app.state.HandleNodeFromAuthPath( @@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 2: Reauth with EMPTY tags to untag nodeKey2 := key.NewNode() // New node key for reauth - registrationID2 := types.MustRegistrationID() - regEntry2 := types.NewRegisterNode(types.Node{ + registrationID2 := types.MustAuthID() + regEntry2 := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "reauth-untag-node", @@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + app.state.SetAuthCacheEntry(registrationID2, regEntry2) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned nodeKey2 := key.NewNode() // New node key for reauth - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "authkey-tagged-node", @@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { RequestTags: []string{}, // EMPTY - should untag }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // Complete reauth with empty tags nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( @@ -3956,10 +3942,10 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { require.NotNil(t, alice, "Alice user should be created") // Step 4: Re-register the node to alice via HandleNodeFromAuthPath - // This is what happens when running: headscale nodes register --user alice --key ... + // This is what happens when running: headscale auth register --auth-id --user alice nodeKey2 := key.NewNode() - registrationID := types.MustRegistrationID() - regEntry := types.NewRegisterNode(types.Node{ + registrationID := types.MustAuthID() + regEntry := types.NewRegisterAuthRequest(types.Node{ MachineKey: machineKey.Public(), // Same machine key as the tagged node NodeKey: nodeKey2.Public(), Hostname: "tagged-orphan-node", @@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { RequestTags: []string{}, // Empty - transition to user-owned }, }) - app.state.SetRegistrationCacheEntry(registrationID, regEntry) + app.state.SetAuthCacheEntry(registrationID, regEntry) // This should NOT panic - before the fix, this would panic with: // panic: runtime error: invalid memory address or nil pointer dereference diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 6841f446..69f71e36 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -47,7 +47,7 @@ const ( type HSDatabase struct { DB *gorm.DB cfg *types.Config - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + regCache *zcache.Cache[types.AuthID, types.AuthRequest] } // NewHeadscaleDatabase creates a new database connection and runs migrations. @@ -56,7 +56,7 @@ type HSDatabase struct { //nolint:gocyclo // complex database initialization with many migrations func NewHeadscaleDatabase( cfg *types.Config, - regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], + regCache *zcache.Cache[types.AuthID, types.AuthRequest], ) (*HSDatabase, error) { dbConn, err := openDB(cfg.Database) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3c687b39..151d9966 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } } -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 073c6677..4c953454 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode( Str(zf.RegistrationKey, registrationKey). Msg("registering node") - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: request.GetName(), } - registrationId, err := types.RegistrationIDFromString(request.GetKey()) + registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } - newNode := types.NewRegisterNode( - types.Node{ - NodeKey: key.NewNode().Public(), - MachineKey: key.NewMachine().Public(), - Hostname: request.GetName(), - User: user, + newNode := types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Hostinfo: &hostinfo, - }, - ) + Hostinfo: &hostinfo, + } log.Debug(). Caller(). Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") - api.h.state.SetRegistrationCacheEntry(registrationId, newNode) + authRegReq := types.NewRegisterAuthRequest(newNode) + api.h.state.SetAuthCacheEntry(registrationId, authRegReq) - return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil } func (api headscaleV1APIServer) Health( @@ -829,4 +828,38 @@ func (api headscaleV1APIServer) Health( return response, healthErr } +func (api headscaleV1APIServer) AuthRegister( + ctx context.Context, + request *v1.AuthRegisterRequest, +) (*v1.AuthRegisterResponse, error) { + resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{ + Key: request.GetAuthId(), + User: request.GetUser(), + }) + if err != nil { + return nil, err + } + + return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil +} + +func (api headscaleV1APIServer) AuthApprove( + ctx context.Context, + request *v1.AuthApproveRequest, +) (*v1.AuthApproveResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{}) + + return &v1.AuthApproveResponse{}, nil +} + func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7c45f1ec..57469ce0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/assets" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -245,11 +244,58 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { +func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationId.String()) + authID.String()) +} + +func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderWeb) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + authID, err := authIDFromRequest(req) + if err != nil { + httpError(writer, err) + return + } + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + _, err = writer.Write([]byte(templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id "+authID.String(), + ).Render())) + if err != nil { + log.Error().Err(err).Msg("failed to write auth response") + } +} + +func authIDFromRequest(req *http.Request) (types.AuthID, error) { + raw, err := urlParam[string](req, "auth_id") + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + registrationId, err := types.AuthIDFromString(raw) + if err != nil { + return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) + } + + return registrationId, nil } // RegisterHandler shows a simple message in the browser to point to the CLI @@ -261,22 +307,20 @@ func (a *AuthProviderWeb) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] - - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) + registrationId, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + _, err = writer.Write([]byte(templates.AuthWeb( + "Node registration", + "Run the command below in the headscale server to add this node to your network:", + fmt.Sprintf("headscale auth register --auth-id %s --user USERNAME", registrationId.String()), + ).Render())) if err != nil { log.Error().Err(err).Msg("failed to write register response") } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 9e544633..6f3fbccb 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{ } // emptyCache creates an empty registration cache for testing. -func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { - return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { + return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) } // Test configuration constants. diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 2880f33a..c232d5d2 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -7,8 +7,11 @@ import ( "fmt" "io" "net/http" + "net/url" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/metrics" "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" @@ -22,6 +25,12 @@ import ( // ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. var ErrUnsupportedClientVersion = errors.New("unsupported client version") +// ErrMissingURLParameter is returned when a required URL parameter is not provided. +var ErrMissingURLParameter = errors.New("missing URL parameter") + +// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type. +var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -69,7 +78,7 @@ func (h *Headscale) NoiseUpgradeHandler( return } - noiseServer := noiseServer{ + ns := noiseServer{ headscale: h, challenge: key.NewChallenge(), } @@ -79,42 +88,89 @@ func (h *Headscale) NoiseUpgradeHandler( writer, req, *h.noisePrivateKey, - noiseServer.earlyNoise, + ns.earlyNoise, ) if err != nil { httpError(writer, fmt.Errorf("upgrading noise connection: %w", err)) return } - noiseServer.conn = noiseConn - noiseServer.machineKey = noiseServer.conn.Peer() - noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion() + ns.conn = noiseConn + ns.machineKey = ns.conn.Peer() + ns.protocolVersion = ns.conn.ProtocolVersion() // This router is served only over the Noise connection, and exposes only the new API. // // The HTTP2 server that exposes this router is created for // a single hijacked connection from /ts2021, using netutil.NewOneConnListener - router := mux.NewRouter() - router.Use(prometheusMiddleware) - router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). - Methods(http.MethodPost) + r := chi.NewRouter() + r.Use(metrics.Collector(metrics.CollectorOpts{ + Host: false, + Proto: true, + Skip: func(r *http.Request) bool { + return r.Method != http.MethodOptions + }, + })) + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) - // Endpoints outside of the register endpoint must use getAndValidateNode to - // get the node to ensure that the MachineKey matches the Node setting up the - // connection. - router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) + r.Handle("/metrics", metrics.Handler()) - noiseServer.httpBaseConfig = &http.Server{ - Handler: router, + r.Route("/machine", func(r chi.Router) { + r.Post("/register", ns.RegistrationHandler) + r.Post("/map", ns.PollNetMapHandler) + + // SSH Check mode endpoint, consulted to validate if a given SSH connection should be accepted or rejected. + r.Get("/ssh/action/from/{src_node_id}/to/{dst_node_id}", ns.SSHActionHandler) + + // Not implemented yet + // + // /whoami is a debug endpoint to validate that the client can communicate over the connection, + // not clear if there is a specific response, it looks like it is just logged. + // https://github.com/tailscale/tailscale/blob/dfba01ca9bd8c4df02c3c32f400d9aeb897c5fc7/cmd/tailscale/cli/debug.go#L1138 + r.Get("/whoami", ns.NotImplementedHandler) + + // client sends a [tailcfg.SetDNSRequest] to this endpoints and expect + // the server to create or update this DNS record "somewhere". + // It is typically a TXT record for an ACME challenge. + r.Post("/set-dns", ns.NotImplementedHandler) + + // A patch of [tailcfg.SetDeviceAttributesRequest] to update device attributes. + // We currently do not support device attributes. + r.Patch("/set-device-attr", ns.NotImplementedHandler) + + // A [tailcfg.AuditLogRequest] to send audit log entries to the server. + // The server is expected to store them "somewhere". + // We currently do not support device attributes. + r.Post("/audit-log", ns.NotImplementedHandler) + + // handles requests to get an OIDC ID token. Receives a [tailcfg.TokenRequest]. + r.Post("/id-token", ns.NotImplementedHandler) + + // Asks the server if a feature is available and receive information about how to enable it. + // Gets a [tailcfg.QueryFeatureRequest] and returns a [tailcfg.QueryFeatureResponse]. + r.Post("/feature/query", ns.NotImplementedHandler) + + r.Post("/update-health", ns.NotImplementedHandler) + + r.Route("/webclient", func(r chi.Router) {}) + + r.Post("/c2n", ns.NotImplementedHandler) + }) + + ns.httpBaseConfig = &http.Server{ + Handler: r, ReadHeaderTimeout: types.HTTPTimeout, } - noiseServer.http2Server = &http2.Server{} + ns.http2Server = &http2.Server{} - noiseServer.http2Server.ServeConn( + ns.http2Server.ServeConn( noiseConn, &http2.ServeConnOpts{ - BaseConfig: noiseServer.httpBaseConfig, + BaseConfig: ns.httpBaseConfig, }, ) } @@ -189,7 +245,143 @@ func rejectUnsupported( return false } -// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol +func (ns *noiseServer) NotImplementedHandler(writer http.ResponseWriter, req *http.Request) { + d, _ := io.ReadAll(req.Body) + log.Trace().Caller().Str("path", req.URL.String()).Bytes("body", d).Msgf("not implemented handler hit") + http.Error(writer, "Not implemented yet", http.StatusNotImplemented) +} + +func urlParam[T any](req *http.Request, key string) (T, error) { + var zero T + + param := chi.URLParam(req, key) + if param == "" { + return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key) + } + + var value T + switch any(value).(type) { + case string: + v, ok := any(param).(T) + if !ok { + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + value = v + default: + return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value) + } + + return value, nil +} + +// SSHActionHandler handles the /ssh-action endpoint, it returns a [tailcfg.SSHActionHandler] +// to the client with the verdict of an SSH access request. +func (ns *noiseServer) SSHActionHandler(writer http.ResponseWriter, req *http.Request) { + srcNodeID, _ := urlParam[types.NodeID](req, "src_node_id") + dstNodeID, _ := urlParam[types.NodeID](req, "dst_node_id") + + sshUser := req.URL.Query().Get("ssh_user") + localUser := req.URL.Query().Get("local_user") + + // Set if this is a follow up request. + authIDStr := req.URL.Query().Get("auth_id") + + log.Trace().Caller(). + Str("path", req.URL.String()). + Uint64("src_node_id", srcNodeID.Uint64()). + Uint64("dst_node_id", dstNodeID.Uint64()). + Str("ssh_user", sshUser). + Str("local_user", localUser). + Str("auth_id", authIDStr). + Msg("got SSH action request") + + var action tailcfg.SSHAction + + action.AllowAgentForwarding = true + action.AllowLocalPortForwarding = true + action.AllowRemotePortForwarding = true + + if authIDStr == "" { + holdURL, err := url.Parse(ns.headscale.cfg.ServerURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER") + if err != nil { + log.Error().Caller().Err(err).Msg("failed to parse SSH action URL") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + authID, err := types.NewAuthID() + if err != nil { + log.Error().Caller().Err(err).Msg("failed to generate auth ID for SSH action") + http.Error(writer, "Internal error", http.StatusInternalServerError) + + return + } + + ns.headscale.state.SetAuthCacheEntry(authID, types.NewAuthRequest()) + + authURL := ns.headscale.authProvider.AuthURL(authID) + + q := holdURL.Query() + q.Set("auth_id", authID.String()) + holdURL.RawQuery = q.Encode() + + action.HoldAndDelegate = holdURL.String() + // TODO(kradalby): here we can also send a very tiny mapresponse + // "popping" the url and opening it for the user. + action.Message = fmt.Sprintf(`# Headscale SSH requires an additional check. +# To authenticate, visit: %s +# Authentication checked with Headscale SSH. +`, authURL) + } else { + authID, err := types.AuthIDFromString(authIDStr) + if err != nil { + log.Error().Caller().Err(err).Str("auth_id", authIDStr).Msg("invalid auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + log.Trace().Caller().Str("auth_id", authID.String()).Msg("SSH action follow-up request with auth_id") + + auth, ok := ns.headscale.state.GetAuthCacheEntry(authID) + if !ok { + log.Error().Caller().Str("auth_id", authID.String()).Msg("no auth session found for auth_id in SSH action request") + http.Error(writer, "Invalid auth_id", http.StatusBadRequest) + + return + } + + verdict := <-auth.WaitForAuth() + + if verdict.Accept() { + action.Reject = false + action.Accept = true + } else { + action.Reject = true + action.Accept = false + + log.Trace().Caller().Str("auth_id", authID.String()).Err(verdict.Err).Msg("SSH action authentication rejected") + } + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + err := json.NewEncoder(writer).Encode(action) + if err != nil { + log.Error().Caller().Err(err).Msg("failed to encode SSH action response") + return + } + + // Ensure response is flushed to client + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } +} + +// PollNetMapHandler takes care of /machine/:id/map using the Noise protocol // // This is the busiest endpoint, as it keeps the HTTP long poll that updates // the clients when something in the network changes. @@ -198,7 +390,7 @@ func rejectUnsupported( // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (ns *noiseServer) NoisePollNetMapHandler( +func (ns *noiseServer) PollNetMapHandler( writer http.ResponseWriter, req *http.Request, ) { @@ -237,8 +429,8 @@ func regErr(err error) *tailcfg.RegisterResponse { return &tailcfg.RegisterResponse{Error: err.Error()} } -// NoiseRegistrationHandler handles the actual registration process of a node. -func (ns *noiseServer) NoiseRegistrationHandler( +// RegistrationHandler handles the actual registration process of a node. +func (ns *noiseServer) RegistrationHandler( writer http.ResponseWriter, req *http.Request, ) { diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 9d284921..ee6dbeb9 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -12,7 +12,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -26,8 +25,8 @@ import ( const ( randomByteSize = 16 defaultOAuthOptionsCount = 3 - registerCacheExpiration = time.Minute * 15 - registerCacheCleanup = time.Minute * 20 + authCacheExpiration = time.Minute * 15 + authCacheCleanup = time.Minute * 20 ) var ( @@ -44,17 +43,21 @@ var ( errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") ) -// RegistrationInfo contains both machine key and verifier information for OIDC validation. -type RegistrationInfo struct { - RegistrationID types.RegistrationID - Verifier *string +// AuthInfo contains both auth ID and verifier information for OIDC validation. +type AuthInfo struct { + AuthID types.AuthID + Verifier *string + Registration bool } type AuthProviderOIDC struct { - h *Headscale - serverURL string - cfg *types.OIDCConfig - registrationCache *zcache.Cache[string, RegistrationInfo] + h *Headscale + serverURL string + cfg *types.OIDCConfig + + // authCache holds auth information between + // the auth and the callback steps. + authCache *zcache.Cache[string, AuthInfo] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -81,45 +84,63 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - registrationCache := zcache.New[string, RegistrationInfo]( - registerCacheExpiration, - registerCacheCleanup, + authCache := zcache.New[string, AuthInfo]( + authCacheExpiration, + authCacheCleanup, ) return &AuthProviderOIDC{ - h: h, - serverURL: serverURL, - cfg: cfg, - registrationCache: registrationCache, + h: h, + serverURL: serverURL, + cfg: cfg, + authCache: authCache, oidcProvider: oidcProvider, oauth2Config: oauth2Config, }, nil } -func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { +func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string { + return fmt.Sprintf( + "%s/auth/%s", + strings.TrimSuffix(a.serverURL, "/"), + authID.String()) +} + +func (a *AuthProviderOIDC) AuthHandler( + writer http.ResponseWriter, + req *http.Request, +) { + a.authHandler(writer, req, false) +} + +func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - registrationID.String()) + authID.String()) } // RegisterHandler registers the OIDC callback handler with the given router. // It puts NodeKey in cache so the callback can retrieve it using the oidc state param. -// Listens in /register/:registration_id. +// Listens in /register/:auth_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { - vars := mux.Vars(req) - registrationIdStr := vars["registration_id"] + a.authHandler(writer, req, true) +} - // We need to make sure we dont open for XSS style injections, if the parameter that - // is passed as a key is not parsable/validated as a NodePublic key, then fail to render - // the template and log an error. - registrationId, err := types.RegistrationIDFromString(registrationIdStr) +// authHandler takes an incoming request that needs to be authenticated and +// validates and prepares it for the OIDC flow. +func (a *AuthProviderOIDC) authHandler( + writer http.ResponseWriter, + req *http.Request, + registration bool, +) { + authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + httpError(writer, err) return } @@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler( return } - // Initialize registration info with machine key - registrationInfo := RegistrationInfo{ - RegistrationID: registrationId, + registrationInfo := AuthInfo{ + AuthID: authID, + Registration: registration, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler( extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.registrationCache.Set(state, registrationInfo) + a.authCache.Set(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) @@ -302,16 +323,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // If the node exists, then the node should be reauthenticated, // if the node does not exist, and the machine key exists, then // this is a new node that should be registered. - registrationId := a.getRegistrationIDFromState(state) + authInfo := a.getAuthInfoFromState(state) + if authInfo == nil { + log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) - // Register the node if it does not exist. - if registrationId != nil { - verb := "Reauthenticated" + return + } - newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) + // If this is a registration flow, then we need to register the node. + if authInfo.Registration { + newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed") httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) return @@ -322,12 +347,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - if newNode { - verb = "Authenticated" - } - - // TODO(kradalby): replace with go-elem - content := renderOIDCCallbackTemplate(user, verb) + content := renderRegistrationSuccessTemplate(user, newNode) writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -339,9 +359,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // Neither node nor machine key was found in the state cache meaning - // that we could not reauth nor register the node. - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + // If this is not a registration callback, then its a regular authentication callback + // and we need to send a response and confirm that the access was allowed. + + authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID) + if !ok { + log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("auth session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + + return + } + + // Send a finish auth verdict with no errors to let the CLI know that the authentication was successful. + authReq.FinishAuth(types.AuthVerdict{}) + + content := renderAuthSuccessTemplate(user) + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr + util.LogErr(err, "Failed to write HTTP response") + } } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -374,7 +413,7 @@ func (a *AuthProviderOIDC) getOauth2Token( var exchangeOpts []oauth2.AuthCodeOption if a.cfg.PKCE.Enabled { - regInfo, ok := a.registrationCache.Get(state) + regInfo, ok := a.authCache.Get(state) if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } @@ -507,14 +546,14 @@ func doOIDCAuthorization( return nil } -// getRegistrationIDFromState retrieves the registration ID from the state. -func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { - regInfo, ok := a.registrationCache.Get(state) +// getAuthInfoFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo { + authInfo, ok := a.authCache.Get(state) if !ok { return nil } - return ®Info.RegistrationID + return &authInfo } func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( @@ -562,7 +601,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) handleRegistration( user *types.User, - registrationID types.RegistrationID, + registrationID types.AuthID, expiry time.Time, ) (bool, error) { node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( @@ -597,12 +636,38 @@ func (a *AuthProviderOIDC) handleRegistration( return !nodeChange.IsEmpty(), nil } -func renderOIDCCallbackTemplate( +func renderRegistrationSuccessTemplate( user *types.User, - verb string, + newNode bool, ) *bytes.Buffer { - html := templates.OIDCCallback(user.Display(), verb).Render() - return bytes.NewBufferString(html) + result := templates.AuthSuccessResult{ + Title: "Headscale - Node Reauthenticated", + Heading: "Node reauthenticated", + Verb: "Reauthenticated", + User: user.Display(), + Message: "You can now close this window.", + } + if newNode { + result.Title = "Headscale - Node Registered" + result.Heading = "Node registered" + result.Verb = "Registered" + } + + return bytes.NewBufferString(templates.AuthSuccess(result).Render()) +} + +func renderAuthSuccessTemplate( + user *types.User, +) *bytes.Buffer { + result := templates.AuthSuccessResult{ + Title: "Headscale - SSH Session Authorized", + Heading: "SSH session authorized", + Verb: "Authorized", + User: user.Display(), + Message: "You may return to your terminal.", + } + + return bytes.NewBufferString(templates.AuthSuccess(result).Render()) } // getCookieName generates a unique cookie name based on a cookie value. diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go index 367451b1..24dfc0b0 100644 --- a/hscontrol/oidc_template_test.go +++ b/hscontrol/oidc_template_test.go @@ -7,35 +7,54 @@ import ( "github.com/stretchr/testify/assert" ) -func TestOIDCCallbackTemplate(t *testing.T) { +func TestAuthSuccessTemplate(t *testing.T) { tests := []struct { - name string - userName string - verb string + name string + result templates.AuthSuccessResult }{ { - name: "logged_in_user", - userName: "test@example.com", - verb: "Logged in", + name: "node_registered", + result: templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "newuser@example.com", + Message: "You can now close this window.", + }, }, { - name: "registered_user", - userName: "newuser@example.com", - verb: "Registered", + name: "node_reauthenticated", + result: templates.AuthSuccessResult{ + Title: "Headscale - Node Reauthenticated", + Heading: "Node reauthenticated", + Verb: "Reauthenticated", + User: "test@example.com", + Message: "You can now close this window.", + }, + }, + { + name: "ssh_session_authorized", + result: templates.AuthSuccessResult{ + Title: "Headscale - SSH Session Authorized", + Heading: "SSH session authorized", + Verb: "Authorized", + User: "test@example.com", + Message: "You may return to your terminal.", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Render using the elem-go template - html := templates.OIDCCallback(tt.userName, tt.verb).Render() + html := templates.AuthSuccess(tt.result).Render() - // Verify the HTML contains expected elements + // Verify the HTML contains expected structural elements assert.Contains(t, html, "") - assert.Contains(t, html, "Headscale Authentication Succeeded") - assert.Contains(t, html, tt.verb) - assert.Contains(t, html, tt.userName) - assert.Contains(t, html, "You can now close this window") + assert.Contains(t, html, ""+tt.result.Title+"") + assert.Contains(t, html, tt.result.Heading) + assert.Contains(t, html, tt.result.Verb+" as ") + assert.Contains(t, html, tt.result.User) + assert.Contains(t, html, tt.result.Message) // Verify Material for MkDocs design system CSS is present assert.Contains(t, html, "Material for MkDocs") diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 6dfacd91..2de2e8dd 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,7 +19,7 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9c97e39c..536c86f3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) { "root": "", }, Action: &tailcfg.SSHAction{ - Accept: true, + Accept: false, SessionDuration: 24 * time.Hour, + HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) { require.NoError(t, err) - got, err := pm.SSHPolicy(tt.targetNode.View()) + got, err := pm.SSHPolicy("unused-url", tt.targetNode.View()) require.NoError(t, err) if diff := cmp.Diff(tt.wantSSH, got); diff != "" { diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 9c2c5f17..f15093aa 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -319,11 +319,27 @@ func (pol *Policy) compileACLWithAutogroupSelf( return rules, nil } -func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { +var sshAccept = tailcfg.SSHAction{ + Reject: false, + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, +} + +func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { return tailcfg.SSHAction{ - Reject: !accept, - Accept: accept, - SessionDuration: duration, + Reject: false, + Accept: false, + SessionDuration: duration, + // Replaced in the client: + // * $SRC_NODE_IP (URL escaped) + // * $SRC_NODE_ID (Node.ID as int64 string) + // * $DST_NODE_IP (URL escaped) + // * $DST_NODE_ID (Node.ID as int64 string) + // * $SSH_USER (URL escaped, ssh user requested) + // * $LOCAL_USER (URL escaped, local user mapped) + HoldAndDelegate: baseURL + "/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -332,6 +348,7 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { //nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( + baseURL string, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], @@ -377,9 +394,9 @@ func (pol *Policy) compileSSHPolicy( switch rule.Action { case SSHActionAccept: - action = sshAction(true, 0) + action = sshAccept case SSHActionCheck: - action = sshAction(true, time.Duration(rule.CheckPeriod)) + action = sshCheck(baseURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index cdf7c131..1c15f732 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { require.NoError(t, err) // Compile SSH policy - sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice()) + sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) require.NoError(t, err) if tt.wantEmpty { @@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -704,8 +704,11 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } assert.Equal(t, expectedUsers, rule.SSHUsers) - // Verify check action with session duration - assert.True(t, rule.Action.Accept) + // Verify check action: Accept is false, HoldAndDelegate is set + assert.False(t, rule.Action.Accept) + assert.False(t, rule.Action.Reject) + assert.NotEmpty(t, rule.Action.HoldAndDelegate) + assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/") assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) } @@ -756,7 +759,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { require.NoError(t, err) // Test SSH policy compilation for node2 (owned by user2, who is in the group) - sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -806,7 +809,7 @@ func TestSSHJSONSerialization(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) @@ -1413,7 +1416,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user1's first node node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1432,7 +1435,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user2's first node node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy2) require.Len(t, sshPolicy2.Rules, 1) @@ -1451,7 +1454,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() - sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy3 != nil { @@ -1491,7 +1494,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user1's node: should allow SSH from user1's devices node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1508,7 +1511,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1551,7 +1554,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user1's node: should allow SSH from user1's devices only (not user2's) node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1568,7 +1571,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1610,7 +1613,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For untagged node: should only get principals from other untagged nodes node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1628,7 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For tagged node: should get no SSH rules node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1671,7 +1674,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 1: Compile for user1's device (should only match autogroup:self destination) node1 := nodes[0].View() - sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy1) require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") @@ -1690,7 +1693,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 2: Compile for router (should only match tag:router destination) routerNode := nodes[3].View() // user2-router - sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicyRouter) require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 74b7ba6a..744f52c7 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { +func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err return sshPol, nil } - sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes) if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e421d5bd..1ec3eedf 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") // ErrNodeNameNotUnique is returned when a node name is not unique. var ErrNodeNameNotUnique = errors.New("node name is not unique") +// ErrRegistrationExpired is returned when a registration has expired. +var ErrRegistrationExpired = errors.New("registration expired") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -82,8 +85,10 @@ type State struct { derpMap atomic.Pointer[tailcfg.DERPMap] // polMan handles policy evaluation and management polMan policy.PolicyManager - // registrationCache caches node registration data to reduce database load - registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + + // authCache caches any pending authentication requests, from either auth type (Web and OIDC). + authCache *zcache.Cache[types.AuthID, types.AuthRequest] + // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes } @@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) { cacheCleanup = cfg.Tuning.RegisterCacheCleanup } - registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( + authCache := zcache.New[types.AuthID, types.AuthRequest]( cacheExpiration, cacheCleanup, ) - registrationCache.OnEvicted( - func(id types.RegistrationID, rn types.RegisterNode) { - rn.SendAndClose(nil) + authCache.OnEvicted( + func(id types.AuthID, rn types.AuthRequest) { + rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) }, ) db, err := hsdb.NewHeadscaleDatabase( cfg, - registrationCache, + authCache, ) if err != nil { return nil, fmt.Errorf("initializing database: %w", err) @@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) { return &State{ cfg: cfg, - db: db, - ipAlloc: ipAlloc, - polMan: polMan, - registrationCache: registrationCache, - primaryRoutes: routes.New(), - nodeStore: nodeStore, + db: db, + ipAlloc: ipAlloc, + polMan: polMan, + authCache: authCache, + primaryRoutes: routes.New(), + nodeStore: nodeStore, }, nil } @@ -851,7 +856,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // SSHPolicy returns the SSH access policy for a node. func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - return s.polMan.SSHPolicy(node) + return s.polMan.SSHPolicy(s.cfg.ServerURL, node) } // Filter returns the current network filter rules and matches. @@ -1042,9 +1047,9 @@ func (s *State) DeletePreAuthKey(id uint64) error { return s.db.DeletePreAuthKey(id) } -// GetRegistrationCacheEntry retrieves a node registration from cache. -func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { - entry, found := s.registrationCache.Get(id) +// GetAuthCacheEntry retrieves a node registration from cache. +func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) { + entry, found := s.authCache.Get(id) if !found { return nil, false } @@ -1052,26 +1057,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis return &entry, true } -// SetRegistrationCacheEntry stores a node registration in cache. -func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { - s.registrationCache.Set(id, entry) +// SetAuthCacheEntry stores a node registration in cache. +func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { + s.authCache.Set(id, entry) } // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. -func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { - if hostinfo == nil { +func logHostinfoValidation(nv types.NodeView, username, hostname string) { + if !nv.Hostinfo().Valid() { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had nil hostinfo, generated default hostname") - } else if hostinfo.Hostname == "" { + } else if nv.Hostinfo().Hostname() == "" { log.Warn(). Caller(). - Str(zf.MachineKey, machineKey). - Str(zf.NodeKey, nodeKey). + EmbedObject(nv). Str(zf.UserName, username). Str(zf.GeneratedHostname, hostname). Msg("Registration had empty hostname, generated default") @@ -1113,7 +1116,7 @@ type authNodeUpdateParams struct { // Node to update; must be valid and in NodeStore. ExistingNode types.NodeView // Client data: keys, hostinfo, endpoints. - RegEntry *types.RegisterNode + RegEntry *types.AuthRequest // Pre-validated hostinfo; NetInfo preserved from ExistingNode. ValidHostinfo *tailcfg.Hostinfo // Hostname from hostinfo, or generated from keys if client omits it. @@ -1132,6 +1135,7 @@ type authNodeUpdateParams struct { // an existing node. It updates the node in NodeStore, processes RequestTags, and // persists changes to the database. func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { + regNv := params.RegEntry.Node() // Log the operation type if params.IsConvertFromTag { log.Info(). @@ -1140,16 +1144,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView Msg("Converting tagged node to user-owned node") } else { log.Info(). - EmbedObject(params.ExistingNode). - Interface("hostinfo", params.RegEntry.Node.Hostinfo). + Object("existing", params.ExistingNode). + Object("incoming", regNv). Msg("Updating existing node registration via reauth") } // Process RequestTags during reauth (#2979) // Due to json:",omitempty", we treat empty/nil as "clear tags" var requestTags []string - if params.RegEntry.Node.Hostinfo != nil { - requestTags = params.RegEntry.Node.Hostinfo.RequestTags + if regNv.Hostinfo().Valid() { + requestTags = regNv.Hostinfo().RequestTags().AsSlice() } oldTags := params.ExistingNode.Tags().AsSlice() @@ -1167,8 +1171,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // Update existing node in NodeStore - validation passed, safe to mutate updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { - node.NodeKey = params.RegEntry.Node.NodeKey - node.DiscoKey = params.RegEntry.Node.DiscoKey + node.NodeKey = regNv.NodeKey() + node.DiscoKey = regNv.DiscoKey() node.Hostname = params.Hostname // Preserve NetInfo from existing node when re-registering @@ -1179,7 +1183,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView params.ValidHostinfo, ) - node.Endpoints = params.RegEntry.Node.Endpoints + node.Endpoints = regNv.Endpoints().AsSlice() node.IsOnline = new(false) node.LastSeen = new(time.Now()) @@ -1188,7 +1192,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.IsConvertFromTag { node.RegisterMethod = params.RegisterMethod } else { - node.RegisterMethod = params.RegEntry.Node.RegisterMethod + node.RegisterMethod = regNv.RegisterMethod() } // Track tagged status BEFORE processing tags @@ -1208,7 +1212,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !wasTagged && isTagged: // Personal → Tagged: clear expiry (tagged nodes don't expire) @@ -1218,14 +1222,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } case !isTagged: // Personal → Personal: update expiry from client if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = params.RegEntry.Node.Expiry + node.Expiry = regNv.Expiry().Clone() } } // Tagged → Tagged: keep existing expiry (nil) - no action needed @@ -1511,13 +1515,13 @@ func (s *State) processReauthTags( // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). func (s *State) HandleNodeFromAuthPath( - registrationID types.RegistrationID, + authID types.AuthID, userID types.UserID, expiry *time.Time, registrationMethod string, ) (types.NodeView, change.Change, error) { // Get the registration entry from cache - regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + regEntry, ok := s.GetAuthCacheEntry(authID) if !ok { return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache } @@ -1530,25 +1534,27 @@ func (s *State) HandleNodeFromAuthPath( // Ensure we have a valid hostname from the registration cache entry hostname := util.EnsureHostname( - regEntry.Node.Hostinfo, - regEntry.Node.MachineKey.String(), - regEntry.Node.NodeKey.String(), + regEntry.Node().Hostinfo(), + regEntry.Node().MachineKey().String(), + regEntry.Node().NodeKey().String(), ) // Ensure we have valid hostinfo - validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) - validHostinfo.Hostname = hostname + hostinfo := &tailcfg.Hostinfo{} + if regEntry.Node().Hostinfo().Valid() { + hostinfo = regEntry.Node().Hostinfo().AsStruct() + } + + hostinfo.Hostname = hostname logHostinfoValidation( - regEntry.Node.MachineKey.ShortString(), - regEntry.Node.NodeKey.String(), + regEntry.Node(), user.Name, hostname, - regEntry.Node.Hostinfo, ) // Lookup existing nodes - machineKey := regEntry.Node.MachineKey + machineKey := regEntry.Node().MachineKey() existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) @@ -1562,7 +1568,7 @@ func (s *State) HandleNodeFromAuthPath( // Create logger with common fields for all auth operations logger := log.With(). - Str(zf.RegistrationID, registrationID.String()). + Str(zf.RegistrationID, authID.String()). Str(zf.UserName, user.Name). Str(zf.MachineKey, machineKey.ShortString()). Str(zf.Method, registrationMethod). @@ -1571,7 +1577,7 @@ func (s *State) HandleNodeFromAuthPath( // Common params for update operations updateParams := authNodeUpdateParams{ RegEntry: regEntry, - ValidHostinfo: validHostinfo, + ValidHostinfo: hostinfo, Hostname: hostname, User: user, Expiry: expiry, @@ -1605,7 +1611,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Creating new node for different user (same machine key exists for another user)") finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, existingNodeAnyUser, ) if err != nil { @@ -1613,7 +1619,7 @@ func (s *State) HandleNodeFromAuthPath( } } else { finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, validHostinfo, + logger, user, regEntry, hostname, hostinfo, expiry, registrationMethod, types.NodeView{}, ) if err != nil { @@ -1622,10 +1628,10 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.SendAndClose(finalNode.AsStruct()) + regEntry.FinishAuth(types.AuthVerdict{Node: finalNode}) // Delete from registration cache - s.registrationCache.Delete(registrationID) + s.authCache.Delete(authID) // Update policy managers usersChange, err := s.updatePolicyManagerUsers() @@ -1654,7 +1660,7 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) createNewNodeFromAuth( logger zerolog.Logger, user *types.User, - regEntry *types.RegisterNode, + regEntry *types.AuthRequest, hostname string, validHostinfo *tailcfg.Hostinfo, expiry *time.Time, @@ -1667,13 +1673,13 @@ func (s *State) createNewNodeFromAuth( return s.createAndSaveNewNode(newNodeParams{ User: *user, - MachineKey: regEntry.Node.MachineKey, - NodeKey: regEntry.Node.NodeKey, - DiscoKey: regEntry.Node.DiscoKey, + MachineKey: regEntry.Node().MachineKey(), + NodeKey: regEntry.Node().NodeKey(), + DiscoKey: regEntry.Node().DiscoKey(), Hostname: hostname, Hostinfo: validHostinfo, - Endpoints: regEntry.Node.Endpoints, - Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + Endpoints: regEntry.Node().Endpoints().AsSlice(), + Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()), RegisterMethod: registrationMethod, ExistingNodeForNetinfo: existingNodeForNetinfo, }) @@ -1759,7 +1765,7 @@ func (s *State) HandleNodeFromPreAuthKey( // Ensure we have a valid hostname - handle nil/empty cases hostname := util.EnsureHostname( - regReq.Hostinfo, + regReq.Hostinfo.View(), machineKey.String(), regReq.NodeKey.String(), ) @@ -1768,14 +1774,6 @@ func (s *State) HandleNodeFromPreAuthKey( validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) validHostinfo.Hostname = hostname - logHostinfoValidation( - machineKey.ShortString(), - regReq.NodeKey.ShortString(), - pakUsername(), - hostname, - regReq.Hostinfo, - ) - log.Debug(). Caller(). Str(zf.NodeName, hostname). diff --git a/hscontrol/templates/auth_success.go b/hscontrol/templates/auth_success.go new file mode 100644 index 00000000..1a212b6e --- /dev/null +++ b/hscontrol/templates/auth_success.go @@ -0,0 +1,62 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" +) + +// AuthSuccessResult contains the text content for an authentication success page. +// Each field controls a distinct piece of user-facing text so that every auth +// flow (node registration, reauthentication, SSH check, …) can clearly +// communicate what just happened. +type AuthSuccessResult struct { + // Title is the browser tab / page title, + // e.g. "Headscale - Node Registered". + Title string + + // Heading is the bold green text inside the success box, + // e.g. "Node registered". + Heading string + + // Verb is the action prefix in the body text before "as ", + // e.g. "Registered", "Reauthenticated", "Authorized". + Verb string + + // User is the display name shown in bold in the body text, + // e.g. "user@example.com". + User string + + // Message is the follow-up instruction shown after the user name, + // e.g. "You can now close this window." + Message string +} + +// AuthSuccess renders an authentication / authorisation success page. +// The caller controls every user-visible string via [AuthSuccessResult] so the +// page clearly describes what succeeded (registration, reauth, SSH check, …). +func AuthSuccess(result AuthSuccessResult) *elem.Element { + box := successBox( + result.Heading, + elem.Text(result.Verb+" as "), + elem.Strong(nil, elem.Text(result.User)), + elem.Text(". "+result.Message), + ) + + return HtmlStructure( + elem.Title(nil, elem.Text(result.Title)), + mdTypesetBody( + headscaleLogo(), + box, + H2(elem.Text("Getting started")), + P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")), + Ul( + elem.Li(nil, + externalLink("https://headscale.net/stable/", "Headscale documentation"), + ), + elem.Li(nil, + externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"), + ), + ), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/auth_web.go b/hscontrol/templates/auth_web.go new file mode 100644 index 00000000..8b6d6f97 --- /dev/null +++ b/hscontrol/templates/auth_web.go @@ -0,0 +1,21 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" +) + +// AuthWeb renders a page that instructs an administrator to run a CLI command +// to complete an authentication or registration flow. +// It is used by both the registration and auth-approve web handlers. +func AuthWeb(title, description, command string) *elem.Element { + return HtmlStructure( + elem.Title(nil, elem.Text(title+" - Headscale")), + mdTypesetBody( + headscaleLogo(), + H1(elem.Text(title)), + P(elem.Text(description)), + Pre(PreCode(command)), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/design.go b/hscontrol/templates/design.go index 615c0e41..221eaf11 100644 --- a/hscontrol/templates/design.go +++ b/hscontrol/templates/design.go @@ -365,6 +365,47 @@ func orDivider() *elem.Element { ) } +// successBox creates a green success feedback box with a checkmark icon. +// The heading is displayed as bold green text, and children are rendered below it. +// Pairs with warningBox for consistent feedback styling. +// +//nolint:unused // Used in auth_success.go template. +func successBox(heading string, children ...elem.Node) *elem.Element { + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.AlignItems: "center", + styles.Gap: spaceM, + styles.Padding: spaceL, + styles.BackgroundColor: colorSuccessLight, + styles.Border: "1px solid " + colorSuccess, + styles.BorderRadius: "0.5rem", + styles.MarginBottom: spaceXL, + }.ToInline(), + }, + checkboxIcon(), + elem.Div(nil, + append([]elem.Node{ + elem.Strong(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "block", + styles.Color: colorSuccess, + styles.FontSize: fontSizeH3, + styles.MarginBottom: spaceXS, + }.ToInline(), + }, elem.Text(heading)), + }, children...)..., + ), + ) +} + +// checkboxIcon returns the success checkbox SVG icon as raw HTML. +func checkboxIcon() elem.Node { + return elem.Raw(``) +} + // warningBox creates a warning message box with icon and content. // //nolint:unused // Used in apple.go template. diff --git a/hscontrol/templates/oidc_callback.go b/hscontrol/templates/oidc_callback.go deleted file mode 100644 index 16c08fde..00000000 --- a/hscontrol/templates/oidc_callback.go +++ /dev/null @@ -1,69 +0,0 @@ -package templates - -import ( - "github.com/chasefleming/elem-go" - "github.com/chasefleming/elem-go/attrs" - "github.com/chasefleming/elem-go/styles" -) - -// checkboxIcon returns the success checkbox SVG icon as raw HTML. -func checkboxIcon() elem.Node { - return elem.Raw(``) -} - -// OIDCCallback renders the OIDC authentication success callback page. -func OIDCCallback(user, verb string) *elem.Element { - // Success message box - successBox := elem.Div(attrs.Props{ - attrs.Style: styles.Props{ - styles.Display: "flex", - styles.AlignItems: "center", - styles.Gap: spaceM, - styles.Padding: spaceL, - styles.BackgroundColor: colorSuccessLight, - styles.Border: "1px solid " + colorSuccess, - styles.BorderRadius: "0.5rem", - styles.MarginBottom: spaceXL, - }.ToInline(), - }, - checkboxIcon(), - elem.Div(nil, - elem.Strong(attrs.Props{ - attrs.Style: styles.Props{ - styles.Display: "block", - styles.Color: colorSuccess, - styles.FontSize: fontSizeH3, - styles.MarginBottom: spaceXS, - }.ToInline(), - }, elem.Text("Signed in successfully")), - elem.P(attrs.Props{ - attrs.Style: styles.Props{ - styles.Margin: "0", - styles.Color: colorTextPrimary, - styles.FontSize: fontSizeBase, - }.ToInline(), - }, elem.Text(verb), elem.Text(" as "), elem.Strong(nil, elem.Text(user)), elem.Text(". You can now close this window.")), - ), - ) - - return HtmlStructure( - elem.Title(nil, elem.Text("Headscale Authentication Succeeded")), - mdTypesetBody( - headscaleLogo(), - successBox, - H2(elem.Text("Getting started")), - P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")), - Ul( - elem.Li(nil, - externalLink("https://headscale.net/stable/", "Headscale documentation"), - ), - elem.Li(nil, - externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"), - ), - ), - pageFooter(), - ), - ) -} diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go deleted file mode 100644 index 829af7fb..00000000 --- a/hscontrol/templates/register_web.go +++ /dev/null @@ -1,21 +0,0 @@ -package templates - -import ( - "fmt" - - "github.com/chasefleming/elem-go" - "github.com/juanfont/headscale/hscontrol/types" -) - -func RegisterWeb(registrationID types.RegistrationID) *elem.Element { - return HtmlStructure( - elem.Title(nil, elem.Text("Registration - Headscale")), - mdTypesetBody( - headscaleLogo(), - H1(elem.Text("Machine registration")), - P(elem.Text("Run the command below in the headscale server to add this machine to your network:")), - Pre(PreCode(fmt.Sprintf("headscale nodes register --key %s --user USERNAME", registrationID.String()))), - pageFooter(), - ), - ) -} diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go index 369639cc..4836c1d1 100644 --- a/hscontrol/templates_consistency_test.go +++ b/hscontrol/templates_consistency_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/juanfont/headscale/hscontrol/templates" - "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" ) @@ -16,12 +15,30 @@ func TestTemplateHTMLConsistency(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", @@ -72,12 +89,30 @@ func TestTemplateModernHTMLFeatures(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", @@ -116,16 +151,35 @@ func TestTemplateExternalLinkSecurity(t *testing.T) { externalURLs []string // URLs that should have security attributes }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), externalURLs: []string{ "https://headscale.net/stable/", "https://tailscale.com/kb/", }, }, { - name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + externalURLs: []string{}, // No external links + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), externalURLs: []string{}, // No external links }, { @@ -185,12 +239,30 @@ func TestTemplateAccessibilityAttributes(t *testing.T) { html string }{ { - name: "OIDC Callback", - html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + name: "Auth Success", + html: templates.AuthSuccess(templates.AuthSuccessResult{ + Title: "Headscale - Node Registered", + Heading: "Node registered", + Verb: "Registered", + User: "test@example.com", + Message: "You can now close this window.", + }).Render(), }, { - name: "Register Web", - html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + name: "Auth Web Register", + html: templates.AuthWeb( + "Machine registration", + "Run the command below in the headscale server to add this machine to your network:", + "headscale auth register --auth-id test-key-123 --user USERNAME", + ).Render(), + }, + { + name: "Auth Web Approve", + html: templates.AuthWeb( + "Authentication check", + "Run the command below in the headscale server to approve this authentication request:", + "headscale auth approve --auth-id test-key-123", + ).Render(), }, { name: "Windows Config", diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index d852753e..01429dc9 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -22,8 +22,8 @@ const ( // Common errors. var ( - ErrCannotParsePrefix = errors.New("cannot parse prefix") - ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidAuthIDLength = errors.New("registration ID has invalid length") ) type StateUpdateType int @@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { } } -const RegistrationIDLength = 24 +const AuthIDLength = 24 -type RegistrationID string +type AuthID string -func NewRegistrationID() (RegistrationID, error) { - rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) +func NewAuthID() (AuthID, error) { + rid, err := util.GenerateRandomStringURLSafe(AuthIDLength) if err != nil { return "", err } - return RegistrationID(rid), nil + return AuthID(rid), nil } -func MustRegistrationID() RegistrationID { - rid, err := NewRegistrationID() +func MustAuthID() AuthID { + rid, err := NewAuthID() if err != nil { panic(err) } @@ -181,43 +181,96 @@ func MustRegistrationID() RegistrationID { return rid } -func RegistrationIDFromString(str string) (RegistrationID, error) { - if len(str) != RegistrationIDLength { - return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str)) +func AuthIDFromString(str string) (AuthID, error) { + r := AuthID(str) + + err := r.Validate() + if err != nil { + return "", err } - return RegistrationID(str), nil + return r, nil } -func (r RegistrationID) String() string { +func (r AuthID) String() string { return string(r) } -type RegisterNode struct { - Node Node - Registered chan *Node - closed *atomic.Bool +func (r AuthID) Validate() error { + if len(r) != AuthIDLength { + return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r)) + } + + return nil } -func NewRegisterNode(node Node) RegisterNode { - return RegisterNode{ - Node: node, - Registered: make(chan *Node), - closed: &atomic.Bool{}, +// AuthRequest represent a pending authentication request from a user or a node. +// If it is a registration request, the node field will be populate with the node that is trying to register. +// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel. +// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. +type AuthRequest struct { + node *Node + finished chan AuthVerdict + closed *atomic.Bool +} + +func NewAuthRequest() AuthRequest { + return AuthRequest{ + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, } } -func (rn *RegisterNode) SendAndClose(node *Node) { +func NewRegisterAuthRequest(node Node) AuthRequest { + return AuthRequest{ + node: &node, + finished: make(chan AuthVerdict), + closed: &atomic.Bool{}, + } +} + +// Node returns the node that is trying to register. +// It will panic if the AuthRequest is not a registration request. +// Can _only_ be used in the registration path. +func (rn *AuthRequest) Node() NodeView { + if rn.node == nil { + panic("Node can only be used in registration requests") + } + + return rn.node.View() +} + +func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return } select { - case rn.Registered <- node: + case rn.finished <- verdict: default: } - close(rn.Registered) + close(rn.finished) +} + +func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict { + return rn.finished +} + +type AuthVerdict struct { + // Err is the error that occurred during the authentication process, if any. + // If Err is nil, the authentication process has succeeded. + // If Err is not nil, the authentication process has failed and the node should not be authenticated. + Err error + + // Node is the node that has been authenticated. + // Node is only valid if the auth request was a registration request + // and the authentication process has succeeded. + Node NodeView +} + +func (v AuthVerdict) Accept() bool { + return v.Err == nil } // DefaultBatcherWorkers returns the default number of batcher workers. diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index cbce663b..034779b5 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -295,8 +295,8 @@ func IsCI() bool { // 3. If normalisation fails → generate invalid- replacement // // Returns the guaranteed-valid hostname to use. -func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { - if hostinfo == nil || hostinfo.Hostname == "" { +func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string { + if !hostinfo.Valid() || hostinfo.Hostname() == "" { key := cmp.Or(machineKey, nodeKey) if key == "" { return "unknown-node" @@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri return "node-" + keyPrefix } - lowercased := strings.ToLower(hostinfo.Hostname) + lowercased := strings.ToLower(hostinfo.Hostname()) err := ValidateHostname(lowercased) if err == nil { diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 5cca4990..6e7a0630 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { if !strings.HasPrefix(got, "invalid-") { @@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { if !strings.HasPrefix(gotHostname, "invalid-") { @@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} - result := EnsureHostname(hostinfo, "mkey", "nkey") + result := EnsureHostname(hostinfo.View(), "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) } @@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) { OS: "linux", } - hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") - hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") + hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") + hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey") if hostname1 != hostname2 { t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index eba2ebbf..d00c5fdd 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -312,7 +312,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } // Register all clients as user1 (this is where cross-user registration happens) - // This simulates: headscale nodes register --user user1 --key + // This simulates: headscale auth register --auth-id --user user1 _ = scenario.runHeadscaleRegister("user1", body) } diff --git a/integration/cli_test.go b/integration/cli_test.go index a1174277..a7696bb4 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1100,11 +1100,11 @@ func TestNodeCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-5", listAll[4].GetName()) otherUserRegIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) @@ -1185,11 +1185,11 @@ func TestNodeCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "other-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1359,11 +1359,11 @@ func TestNodeExpireCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-expire-user", - "register", - "--key", + "--auth-id", regID, "--output", "json", @@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) { require.NoError(t, err) regIDs := []string{ - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), - types.MustRegistrationID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), + types.MustAuthID().String(), } nodes := make([]*v1.Node, len(regIDs)) @@ -1496,11 +1496,11 @@ func TestNodeRenameCommand(t *testing.T) { headscale, []string{ "headscale", - "nodes", + "auth", + "register", "--user", "node-rename-command", - "register", - "--key", + "--auth-id", regID, "--output", "json", diff --git a/integration/control.go b/integration/control.go index f390d080..d9273ae6 100644 --- a/integration/control.go +++ b/integration/control.go @@ -16,6 +16,7 @@ import ( type ControlServer interface { Shutdown() (string, string, error) SaveLog(path string) (string, string, error) + ReadLog() (string, string, error) SaveProfile(path string) error Execute(command []string) (string, error) WriteFile(path string, content []byte) error diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 3ef4d5d4..cd60c20d 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -699,6 +699,18 @@ func (t *HeadscaleInContainer) WriteLogs(stdout, stderr io.Writer) error { return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr) } +// ReadLog returns the current stdout and stderr logs from the headscale container. +func (t *HeadscaleInContainer) ReadLog() (string, string, error) { + var stdout, stderr bytes.Buffer + + err := dockertestutil.WriteLog(t.pool, t.container, &stdout, &stderr) + if err != nil { + return "", "", fmt.Errorf("reading container logs: %w", err) + } + + return stdout.String(), stderr.String(), nil +} + // SaveLog saves the current stdout log of the container to a path // on the host system. func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) { diff --git a/integration/scenario.go b/integration/scenario.go index cd43b78f..ba99a392 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -141,6 +141,12 @@ type ScenarioSpec struct { // Versions is specific list of versions to use for the test. Versions []string + // OIDCSkipUserCreation, if true, skips creating users via headscale CLI + // during environment setup. Useful for OIDC tests where the SSH policy + // references users by name, since OIDC login creates users automatically + // and pre-creating them via CLI causes duplicate user records. + OIDCSkipUserCreation bool + // OIDCUsers, if populated, will start a Mock OIDC server and populate // the user login stack with the given users. // If the NodesPerUser is set, it should align with this list to ensure @@ -866,9 +872,18 @@ func (s *Scenario) createHeadscaleEnvWithTags( } for _, user := range s.spec.Users { - u, err := s.CreateUser(user) - if err != nil { - return err + var u *v1.User + + if s.spec.OIDCSkipUserCreation { + // Only register locally — OIDC login will create the headscale user. + s.mu.Lock() + s.users[user] = &User{Clients: make(map[string]TailscaleClient)} + s.mu.Unlock() + } else { + u, err = s.CreateUser(user) + if err != nil { + return err + } } var userOpts []tsic.Option @@ -1169,7 +1184,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { return errParseAuthPage } - keySep := strings.Split(codeSep[0], "key ") + keySep := strings.Split(codeSep[0], "--auth-id ") if len(keySep) != 2 { return errParseAuthPage } @@ -1180,7 +1195,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr _, err = headscale.Execute( - []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, + []string{"headscale", "auth", "register", "--user", userStr, "--auth-id", key}, ) if err != nil { log.Printf("registering node: %s", err) diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 45bc2dc7..5a46f598 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -3,13 +3,17 @@ package integration import ( "fmt" "log" + "net/url" "strings" "testing" "time" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" + "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" @@ -579,3 +583,558 @@ func TestSSHAutogroupSelf(t *testing.T) { } } } + +type sshCheckResult struct { + stdout string + stderr string + err error +} + +// doSSHCheck runs SSH in a goroutine with a longer timeout, returning a channel +// for the result. The SSH command will block while waiting for auth approval in +// check mode. +func doSSHCheck( + t *testing.T, + client TailscaleClient, + peer TailscaleClient, +) chan sshCheckResult { + t.Helper() + + peerFQDN, _ := peer.FQDN() + + command := []string{ + "/usr/bin/ssh", "-o StrictHostKeyChecking=no", "-o ConnectTimeout=30", + fmt.Sprintf("%s@%s", "ssh-it-user", peerFQDN), + "'hostname'", + } + + log.Printf( + "[SSH check] Running from %s to %s", + client.Hostname(), + peer.Hostname(), + ) + + ch := make(chan sshCheckResult, 1) + + go func() { + stdout, stderr, err := client.Execute( + command, + dockertestutil.ExecuteCommandTimeout(60*time.Second), + ) + ch <- sshCheckResult{stdout, stderr, err} + }() + + return ch +} + +// findSSHCheckAuthID polls headscale container logs for the SSH action auth-id. +// The SSH action handler logs "SSH action follow-up" with the auth_id on the +// follow-up request (where auth_id is non-empty). +func findSSHCheckAuthID(t *testing.T, headscale ControlServer) string { + t.Helper() + + var authID string + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, stderr, err := headscale.ReadLog() + assert.NoError(c, err) + + for line := range strings.SplitSeq(stderr, "\n") { + if !strings.Contains(line, "SSH action follow-up") { + continue + } + + if idx := strings.Index(line, "auth_id="); idx != -1 { + start := idx + len("auth_id=") + + end := strings.IndexByte(line[start:], ' ') + if end == -1 { + end = len(line[start:]) + } + + authID = line[start : start+end] + } + } + + assert.NotEmpty(c, authID, "auth-id not found in headscale logs") + }, 10*time.Second, 500*time.Millisecond, "waiting for SSH check auth-id in headscale logs") + + return authID +} + +// sshCheckPolicy returns a policy with SSH "check" mode for group:integration-test +// targeting autogroup:member and autogroup:tagged destinations. +func sshCheckPolicy() *policyv2.Policy { + return &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{ + policyv2.Username("user1@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + } +} + +// sshCheckPolicyWithPeriod returns a policy with SSH "check" mode and a +// specified checkPeriod for session duration. +func sshCheckPolicyWithPeriod(period time.Duration) *policyv2.Policy { + return &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{ + policyv2.Username("user1@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "check", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ + new(policyv2.AutoGroupMember), + new(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + CheckPeriod: model.Duration(period), + }, + }, + } +} + +// findNewSSHCheckAuthID polls headscale logs for an SSH check auth-id +// that differs from excludeID. Used to verify re-authentication after +// session expiry. +func findNewSSHCheckAuthID( + t *testing.T, + headscale ControlServer, + excludeID string, +) string { + t.Helper() + + var authID string + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, stderr, err := headscale.ReadLog() + assert.NoError(c, err) + + for line := range strings.SplitSeq(stderr, "\n") { + if !strings.Contains(line, "SSH action follow-up") { + continue + } + + if idx := strings.Index(line, "auth_id="); idx != -1 { + start := idx + len("auth_id=") + + end := strings.IndexByte(line[start:], ' ') + if end == -1 { + end = len(line[start:]) + } + + id := line[start : start+end] + if id != excludeID { + authID = id + } + } + } + + assert.NotEmpty(c, authID, "new auth-id not found in headscale logs") + }, 10*time.Second, 500*time.Millisecond, "waiting for new SSH check auth-id") + + return authID +} + +func TestSSHOneUserToOneCheckModeCLI(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, sshCheckPolicy(), 1) + // defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 can SSH (via check) to all peers + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Approve via CLI + _, err := headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", authID, + }, + ) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after auth approval") + } + } + } + + // user2 cannot SSH — not in the check policy group + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +func TestSSHOneUserToOneCheckModeOIDC(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCSkipUserCreation: true, + OIDCUsers: []mockoidc.MockUser{ + // First 2: consumed during node registration + oidcMockUser("user1", true), + oidcMockUser("user2", true), + // Extra: consumed during SSH check auth flows. + // Each SSH check pops one user from the queue. + oidcMockUser("user1", true), + }, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + // defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(sshCheckPolicy()), + hsic.WithTestName("sshcheckoidc"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer( + "/tmp/hs_client_oidc_secret", + []byte(scenario.mockOIDC.ClientSecret()), + ), + ) + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 can SSH (via check) to all peers + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + // Start SSH — will block waiting for check auth + sshResult := doSSHCheck(t, client, peer) + + // Find the auth-id from headscale logs + authID := findSSHCheckAuthID(t, headscale) + + // Build auth URL and visit it to trigger OIDC flow. + // The mock OIDC server auto-authenticates from the user queue. + authURL := headscale.GetEndpoint() + "/auth/" + authID + parsedURL, err := url.Parse(authURL) + require.NoError(t, err) + + _, err = doLoginURL("ssh-check-oidc", parsedURL) + require.NoError(t, err) + + // Wait for SSH to complete + select { + case result := <-sshResult: + require.NoError(t, result.err) + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("SSH did not complete after OIDC auth") + } + } + } + + // user2 cannot SSH — not in the check policy group + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +// TestSSHCheckModeUnapprovedTimeout verifies that SSH in check mode is rejected +// when nobody approves the auth request and the registration cache entry expires. +func TestSSHCheckModeUnapprovedTimeout(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithSSH(), + tsic.WithNetfilter("off"), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(sshCheckPolicy()), + hsic.WithTestName("sshchecktimeout"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "15s", + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "5s", + }), + ) + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // user1 attempts SSH — enters check flow, but nobody approves + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + sshResult := doSSHCheck(t, client, peer) + + // Confirm the check flow was entered + _ = findSSHCheckAuthID(t, headscale) + + // Do NOT approve — wait for cache expiry and SSH rejection + select { + case result := <-sshResult: + require.Error(t, result.err, "SSH should be rejected when unapproved") + assert.Empty(t, result.stdout, "no command output expected on rejection") + case <-time.After(60 * time.Second): + t.Fatal("SSH did not complete after cache expiry timeout") + } + } + } + + // user2 still gets immediate Permission Denied + for _, client := range user2Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHPermissionDenied(t, client, peer) + } + } +} + +// TestSSHCheckModeCheckPeriodCLI verifies that after approval with a short +// checkPeriod, the session expires and the next SSH connection requires +// re-authentication via a new check flow. +func TestSSHCheckModeCheckPeriodCLI(t *testing.T) { + IntegrationSkip(t) + + // 1 minute is the documented minimum checkPeriod + scenario := sshScenario(t, sshCheckPolicyWithPeriod(time.Minute), 1) + defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + // === Phase 1: First SSH check — approve, verify success === + for _, client := range user1Clients { + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + sshResult := doSSHCheck(t, client, peer) + firstAuthID := findSSHCheckAuthID(t, headscale) + + _, err := headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", firstAuthID, + }, + ) + require.NoError(t, err) + + select { + case result := <-sshResult: + require.NoError(t, result.err, "first SSH should succeed after approval") + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("first SSH did not complete after auth approval") + } + + // === Phase 2: Wait for checkPeriod to expire === + //nolint:forbidigo // Intentional sleep: waiting for the check period session + // to expire. This is a time-based expiry, not a pollable condition — the + // Tailscale client caches the approval for SessionDuration and only + // re-triggers the check flow after it elapses. + time.Sleep(70 * time.Second) + + // === Phase 3: Second SSH — must re-authenticate === + sshResult2 := doSSHCheck(t, client, peer) + secondAuthID := findNewSSHCheckAuthID(t, headscale, firstAuthID) + + require.NotEqual( + t, + firstAuthID, + secondAuthID, + "second SSH should trigger a new auth flow after checkPeriod expiry", + ) + + _, err = headscale.Execute( + []string{ + "headscale", "auth", "approve", + "--auth-id", secondAuthID, + }, + ) + require.NoError(t, err) + + select { + case result := <-sshResult2: + require.NoError(t, result.err, "second SSH should succeed after re-approval") + require.Contains( + t, + peer.ContainerID(), + strings.ReplaceAll(result.stdout, "\n", ""), + ) + case <-time.After(30 * time.Second): + t.Fatal("second SSH did not complete after re-auth approval") + } + } + } +} diff --git a/integration/tags_test.go b/integration/tags_test.go index b4fe678b..617f688d 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -3122,7 +3122,7 @@ func TestTagsAuthKeyWithoutUserRejectsAdvertisedTags(t *testing.T) { // TestTagsAuthKeyConvertToUserViaCLIRegister reproduces the panic from // issue #3038: register a node with a tags-only preauthkey (no user), then -// convert it to a user-owned node via "headscale nodes register --user --key ...". +// convert it to a user-owned node via "headscale auth register --auth-id --user ". // The crash happens in the mapper's generateUserProfiles when node.User is nil // after the tag→user conversion in processReauthTags. // diff --git a/proto/headscale/v1/auth.proto b/proto/headscale/v1/auth.proto new file mode 100644 index 00000000..8292400e --- /dev/null +++ b/proto/headscale/v1/auth.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; +package headscale.v1; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; + +import "headscale/v1/node.proto"; + +message AuthRegisterRequest { + string user = 1; + string auth_id = 2; +} + +message AuthRegisterResponse { + Node node = 1; +} + +message AuthApproveRequest { + string auth_id = 1; +} + +message AuthApproveResponse {} diff --git a/proto/headscale/v1/headscale.proto b/proto/headscale/v1/headscale.proto index 5e556255..5a0dd288 100644 --- a/proto/headscale/v1/headscale.proto +++ b/proto/headscale/v1/headscale.proto @@ -8,6 +8,7 @@ import "headscale/v1/user.proto"; import "headscale/v1/preauthkey.proto"; import "headscale/v1/node.proto"; import "headscale/v1/apikey.proto"; +import "headscale/v1/auth.proto"; import "headscale/v1/policy.proto"; service HeadscaleService { @@ -139,6 +140,22 @@ service HeadscaleService { // --- Node end --- + // --- Auth start --- + rpc AuthRegister(AuthRegisterRequest) returns (AuthRegisterResponse) { + option (google.api.http) = { + post : "/api/v1/auth/register" + body : "*" + }; + } + + rpc AuthApprove(AuthApproveRequest) returns (AuthApproveResponse) { + option (google.api.http) = { + post : "/api/v1/auth/approve" + body : "*" + }; + } + // --- Auth end --- + // --- ApiKeys start --- rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse) { option (google.api.http) = {