diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 37aa792e..735c50bf 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -70,6 +70,7 @@ jobs: - TestTaildrop - TestUpdateHostnameFromClient - TestExpireNode + - TestSetNodeExpiryInFuture - TestNodeOnlineStatus - TestPingAllByIPManyUpDown - Test2118DeletingOnlineNodePanics diff --git a/CHANGELOG.md b/CHANGELOG.md index da547451..02986867 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Next +### Changes + +- Expire nodes with a custom timestamp + [#2828](https://github.com/juanfont/headscale/pull/2828) + ## 0.27.0 (2025-10-27) **Minimum supported Tailscale client version: v1.64.0** diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index e1b8e7b3..e1b040f0 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -15,6 +15,7 @@ import ( "github.com/samber/lo" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/types/key" ) @@ -51,6 +52,7 @@ func init() { nodeCmd.AddCommand(registerNodeCmd) expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.") err = expireNodeCmd.MarkFlagRequired("identifier") if err != nil { log.Fatal(err.Error()) @@ -289,12 +291,37 @@ var expireNodeCmd = &cobra.Command{ ) } + expiry, err := cmd.Flags().GetString("expiry") + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting expiry to string: %s", err), + output, + ) + + return + } + expiryTime := time.Now() + if expiry != "" { + expiryTime, err = time.Parse(time.RFC3339, expiry) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting expiry to string: %s", err), + output, + ) + + return + } + } + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.ExpireNodeRequest{ NodeId: identifier, + Expiry: timestamppb.New(expiryTime), } response, err := client.ExpireNode(ctx, request) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index b8a9a2ad..f99d5390 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -127,12 +127,6 @@ var setPolicy = &cobra.Command{ ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) } - _, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{}) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) - return - } - if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { confirm := false force, _ := cmd.Flags().GetBool("force") @@ -159,6 +153,17 @@ var setPolicy = &cobra.Command{ ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) } + users, err := d.ListUsers() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to load users for policy validation: %s", err), output) + } + + _, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{}) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) + return + } + _, err = d.SetPolicy(string(policyBytes)) if err != nil { ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) diff --git a/flake.nix b/flake.nix index c064c7fe..f8eb6dd1 100644 --- a/flake.nix +++ b/flake.nix @@ -19,7 +19,7 @@ overlay = _: prev: let pkgs = nixpkgs.legacyPackages.${prev.system}; buildGo = pkgs.buildGo125Module; - vendorHash = "sha256-GUIzlPRsyEq1uSTzRNds9p1uVu4pTeH5PAxrJ5Njhis="; + vendorHash = "sha256-VOi4PGZ8I+2MiwtzxpKc/4smsL5KcH/pHVkjJfAFPJ0="; in { headscale = buildGo { pname = "headscale"; diff --git a/gen/go/headscale/v1/apikey.pb.go b/gen/go/headscale/v1/apikey.pb.go index 38aaf55a..a9f6a7b8 100644 --- a/gen/go/headscale/v1/apikey.pb.go +++ b/gen/go/headscale/v1/apikey.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/apikey.proto diff --git a/gen/go/headscale/v1/device.pb.go b/gen/go/headscale/v1/device.pb.go index c31bd754..8b150f96 100644 --- a/gen/go/headscale/v1/device.pb.go +++ b/gen/go/headscale/v1/device.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/device.proto diff --git a/gen/go/headscale/v1/headscale.pb.go b/gen/go/headscale/v1/headscale.pb.go index e9fdfd7f..2c594f5a 100644 --- a/gen/go/headscale/v1/headscale.pb.go +++ b/gen/go/headscale/v1/headscale.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/headscale.proto diff --git a/gen/go/headscale/v1/headscale.pb.gw.go b/gen/go/headscale/v1/headscale.pb.gw.go index fcd7fa2b..2a8ac365 100644 --- a/gen/go/headscale/v1/headscale.pb.gw.go +++ b/gen/go/headscale/v1/headscale.pb.gw.go @@ -471,6 +471,8 @@ func local_request_HeadscaleService_DeleteNode_0(ctx context.Context, marshaler return msg, metadata, err } +var filter_HeadscaleService_ExpireNode_0 = &utilities.DoubleArray{Encoding: map[string]int{"node_id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} + func request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( protoReq ExpireNodeRequest @@ -485,6 +487,12 @@ func request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtim if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := client.ExpireNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err } @@ -503,6 +511,12 @@ func local_request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := server.ExpireNode(ctx, &protoReq) return msg, metadata, err } diff --git a/gen/go/headscale/v1/node.pb.go b/gen/go/headscale/v1/node.pb.go index 60d8fb95..f04c7e2d 100644 --- a/gen/go/headscale/v1/node.pb.go +++ b/gen/go/headscale/v1/node.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/node.proto @@ -729,6 +729,7 @@ func (*DeleteNodeResponse) Descriptor() ([]byte, []int) { type ExpireNodeRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + Expiry *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiry,proto3" json:"expiry,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -770,6 +771,13 @@ func (x *ExpireNodeRequest) GetNodeId() uint64 { return 0 } +func (x *ExpireNodeRequest) GetExpiry() *timestamppb.Timestamp { + if x != nil { + return x.Expiry + } + return nil +} + type ExpireNodeResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` @@ -1349,9 +1357,10 @@ const file_headscale_v1_node_proto_rawDesc = "" + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\",\n" + "\x11DeleteNodeRequest\x12\x17\n" + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\"\x14\n" + - "\x12DeleteNodeResponse\",\n" + + "\x12DeleteNodeResponse\"`\n" + "\x11ExpireNodeRequest\x12\x17\n" + - "\anode_id\x18\x01 \x01(\x04R\x06nodeId\"<\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\x122\n" + + "\x06expiry\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x06expiry\"<\n" + "\x12ExpireNodeResponse\x12&\n" + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"G\n" + "\x11RenameNodeRequest\x12\x17\n" + @@ -1439,16 +1448,17 @@ var file_headscale_v1_node_proto_depIdxs = []int32{ 1, // 7: headscale.v1.GetNodeResponse.node:type_name -> headscale.v1.Node 1, // 8: headscale.v1.SetTagsResponse.node:type_name -> headscale.v1.Node 1, // 9: headscale.v1.SetApprovedRoutesResponse.node:type_name -> headscale.v1.Node - 1, // 10: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node - 1, // 11: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node - 1, // 12: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node - 1, // 13: headscale.v1.MoveNodeResponse.node:type_name -> headscale.v1.Node - 1, // 14: headscale.v1.DebugCreateNodeResponse.node:type_name -> headscale.v1.Node - 15, // [15:15] is the sub-list for method output_type - 15, // [15:15] is the sub-list for method input_type - 15, // [15:15] is the sub-list for extension type_name - 15, // [15:15] is the sub-list for extension extendee - 0, // [0:15] is the sub-list for field type_name + 25, // 10: headscale.v1.ExpireNodeRequest.expiry:type_name -> google.protobuf.Timestamp + 1, // 11: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node + 1, // 12: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node + 1, // 13: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node + 1, // 14: headscale.v1.MoveNodeResponse.node:type_name -> headscale.v1.Node + 1, // 15: headscale.v1.DebugCreateNodeResponse.node:type_name -> headscale.v1.Node + 16, // [16:16] is the sub-list for method output_type + 16, // [16:16] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_headscale_v1_node_proto_init() } diff --git a/gen/go/headscale/v1/policy.pb.go b/gen/go/headscale/v1/policy.pb.go index 4ac6e3b2..fefcfb22 100644 --- a/gen/go/headscale/v1/policy.pb.go +++ b/gen/go/headscale/v1/policy.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/policy.proto diff --git a/gen/go/headscale/v1/preauthkey.pb.go b/gen/go/headscale/v1/preauthkey.pb.go index de7f3248..661f170d 100644 --- a/gen/go/headscale/v1/preauthkey.pb.go +++ b/gen/go/headscale/v1/preauthkey.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/preauthkey.proto diff --git a/gen/go/headscale/v1/user.pb.go b/gen/go/headscale/v1/user.pb.go index 97fcaff9..fa6d49bb 100644 --- a/gen/go/headscale/v1/user.pb.go +++ b/gen/go/headscale/v1/user.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.8 +// protoc-gen-go v1.36.10 // protoc (unknown) // source: headscale/v1/user.proto diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index 2900d65f..6a7b48ad 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -406,6 +406,13 @@ "required": true, "type": "string", "format": "uint64" + }, + { + "name": "expiry", + "in": "query", + "required": false, + "type": "string", + "format": "date-time" } ], "tags": [ diff --git a/go.mod b/go.mod index b96cedf1..67c6c089 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/tailscale/hujson v0.0.0-20250226034555-ec1d1c113d33 - github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694 + github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993 github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97 github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e go4.org/netipx v0.0.0-20231129151722-fdeea329fbba @@ -115,7 +115,7 @@ require ( github.com/containerd/errdefs v0.3.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect - github.com/creachadair/mds v0.25.2 // indirect + github.com/creachadair/mds v0.25.10 // indirect github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 // indirect github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e // indirect github.com/distribution/reference v0.6.0 // indirect @@ -159,7 +159,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jsimonetti/rtnetlink v1.4.1 // indirect - github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/compress v1.18.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/lib/pq v1.10.9 // indirect diff --git a/go.sum b/go.sum index 1b09acc5..e78e9aff 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,8 @@ github.com/creachadair/flax v0.0.5 h1:zt+CRuXQASxwQ68e9GHAOnEgAU29nF0zYMHOCrL5wz github.com/creachadair/flax v0.0.5/go.mod h1:F1PML0JZLXSNDMNiRGK2yjm5f+L9QCHchyHBldFymj8= github.com/creachadair/mds v0.25.2 h1:xc0S0AfDq5GX9KUR5sLvi5XjA61/P6S5e0xFs1vA18Q= github.com/creachadair/mds v0.25.2/go.mod h1:+s4CFteFRj4eq2KcGHW8Wei3u9NyzSPzNV32EvjyK/Q= +github.com/creachadair/mds v0.25.10 h1:9k9JB35D1xhOCFl0liBhagBBp8fWWkKZrA7UXsfoHtA= +github.com/creachadair/mds v0.25.10/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs= github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -278,6 +280,8 @@ github.com/jsimonetti/rtnetlink v1.4.1 h1:JfD4jthWBqZMEffc5RjgmlzpYttAVw1sdnmiNa github.com/jsimonetti/rtnetlink v1.4.1/go.mod h1:xJjT7t59UIZ62GLZbv6PLLo8VFrostJMPBAheR6OM8w= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= +github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= @@ -461,6 +465,8 @@ github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d h1:mnqtPWYyvNiPU9l github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d/go.mod h1:9BzmlFc3OLqLzLTF/5AY+BMs+clxMqyhSGzgXIm8mNI= github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694 h1:95eIP97c88cqAFU/8nURjgI9xxPbD+Ci6mY/a79BI/w= github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694/go.mod h1:veguaG8tVg1H/JG5RfpoUW41I+O8ClPElo/fTYr8mMk= +github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993 h1:FyiiAvDAxpB0DrW2GW3KOVfi3YFOtsQUEeFWbf55JJU= +github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993/go.mod h1:xJkMmR3t+thnUQhA3Q4m2VSlS5pcOq+CIjmU/xfKKx4= github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97 h1:JJkDnrAhHvOCttk8z9xeZzcDlzzkRA7+Duxj9cwOyxk= github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97/go.mod h1:9jS8HxwsP2fU4ESZ7DZL+fpH/U66EVlVMzdgznH12RM= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index e18f2e5d..04c6cc0a 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -932,6 +932,26 @@ AND auth_key_id NOT IN ( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + // Drop all tables that are no longer in use and has existed. + // They potentially still present from broken migrations in the past. + ID: "202510311551", + Migrate: func(tx *gorm.DB) error { + for _, oldTable := range []string{"namespaces", "machines", "shared_machines", "kvs", "pre_auth_key_acl_tags", "routes"} { + err := tx.Migrator().DropTable(oldTable) + if err != nil { + log.Trace().Str("table", oldTable). + Err(err). + Msg("Error dropping old table, continuing...") + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, + }, // From this point, the following rules must be followed: // - NEVER use gorm.AutoMigrate, write the exact migration steps needed // - AutoMigrate depends on the struct staying exactly the same, which it won't over time. @@ -962,7 +982,17 @@ AND auth_key_id NOT IN ( ctx, cancel := context.WithTimeout(context.Background(), contextTimeoutSecs*time.Second) defer cancel() - if err := squibble.Validate(ctx, sqlConn, dbSchema); err != nil { + opts := squibble.DigestOptions{ + IgnoreTables: []string{ + // Litestream tables, these are inserted by + // litestream and not part of our schema + // https://litestream.io/how-it-works + "_litestream_lock", + "_litestream_seq", + }, + } + + if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { return nil, fmt.Errorf("validating schema: %w", err) } } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 5493a55c..70d3afaf 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -18,6 +18,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" + "tailscale.com/net/tsaddr" "tailscale.com/types/key" "tailscale.com/types/ptr" ) @@ -27,9 +28,7 @@ const ( NodeGivenNameTrimSize = 2 ) -var ( - invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") -) +var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var ( ErrNodeNotFound = errors.New("node not found") @@ -234,6 +233,17 @@ func SetApprovedRoutes( return nil } + // When approving exit routes, ensure both IPv4 and IPv6 are included + // If either 0.0.0.0/0 or ::/0 is being approved, both should be approved + hasIPv4Exit := slices.Contains(routes, tsaddr.AllIPv4()) + hasIPv6Exit := slices.Contains(routes, tsaddr.AllIPv6()) + + if hasIPv4Exit && !hasIPv6Exit { + routes = append(routes, tsaddr.AllIPv6()) + } else if hasIPv6Exit && !hasIPv4Exit { + routes = append(routes, tsaddr.AllIPv4()) + } + b, err := json.Marshal(routes) if err != nil { return err diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index b51dba1c..0efd0e8b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -476,7 +476,7 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) } - newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), node.ApprovedRoutes, tt.routes) + newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), nodeTagged.ApprovedRoutes, tt.routes) if changed2 { err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2) require.NoError(t, err) @@ -490,7 +490,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes1) == 0 { expectedRoutes1 = nil } - if diff := cmp.Diff(expectedRoutes1, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" { + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -501,7 +501,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes2) == 0 { expectedRoutes2 = nil } - if diff := cmp.Diff(expectedRoutes2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" { + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } }) diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql new file mode 100644 index 00000000..388fefbc --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql @@ -0,0 +1,40 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; + +-- Create all the old tables we have had and ensure they are clean up. +CREATE TABLE `namespaces` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `machines` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `kvs` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `shared_machines` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `pre_auth_key_acl_tags` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `routes` (`id` text,PRIMARY KEY (`id`)); +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.1_schema-litestream.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.1_schema-litestream.sql new file mode 100644 index 00000000..3fc2b319 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.1_schema-litestream.sql @@ -0,0 +1,14 @@ +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +CREATE TABLE _litestream_seq (id INTEGER PRIMARY KEY, seq INTEGER); +CREATE TABLE _litestream_lock (id INTEGER); diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 6c8244f5..42d74abe 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -12,6 +12,7 @@ import ( "net/url" "os" "reflect" + "slices" "sync" "time" @@ -126,7 +127,17 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) { return } - for id, region := range dm.Regions { + // Collect region IDs and sort them to ensure deterministic iteration order. + // Map iteration order is non-deterministic in Go, which would cause the + // shuffle to be non-deterministic even with a fixed seed. + ids := make([]int, 0, len(dm.Regions)) + for id := range dm.Regions { + ids = append(ids, id) + } + slices.Sort(ids) + + for _, id := range ids { + region := dm.Regions[id] if len(region.Nodes) == 0 { continue } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index 9334de05..91d605a6 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -83,9 +83,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { RegionCode: "sea", RegionName: "Seattle", Nodes: []*tailcfg.DERPNode{ - {Name: "10b", RegionID: 10, HostName: "derp10b.tailscale.com"}, - {Name: "10c", RegionID: 10, HostName: "derp10c.tailscale.com"}, {Name: "10d", RegionID: 10, HostName: "derp10d.tailscale.com"}, + {Name: "10c", RegionID: 10, HostName: "derp10c.tailscale.com"}, + {Name: "10b", RegionID: 10, HostName: "derp10b.tailscale.com"}, }, }, 2: { @@ -93,9 +93,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { RegionCode: "sfo", RegionName: "San Francisco", Nodes: []*tailcfg.DERPNode{ - {Name: "2f", RegionID: 2, HostName: "derp2f.tailscale.com"}, - {Name: "2e", RegionID: 2, HostName: "derp2e.tailscale.com"}, {Name: "2d", RegionID: 2, HostName: "derp2d.tailscale.com"}, + {Name: "2e", RegionID: 2, HostName: "derp2e.tailscale.com"}, + {Name: "2f", RegionID: 2, HostName: "derp2f.tailscale.com"}, }, }, }, @@ -169,6 +169,74 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { }, }, }, + { + name: "same dataset with another base domain", + baseDomain: "another.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "same dataset with yet another base domain", + baseDomain: "yetanother.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + }, + }, + }, + }, + }, } for _, tt := range tests { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 1d620ba6..6d5189b8 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -416,9 +416,12 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { - now := time.Now() + expiry := time.Now() + if request.GetExpiry() != nil { + expiry = request.GetExpiry().AsTime() + } - node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now) + node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry) if err != nil { return nil, err } diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index ac96028e..3a3b39d1 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -108,11 +108,12 @@ func TestTailNode(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), + tsaddr.AllIPv6(), netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("172.0.0.0/10"), }, }, - ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")}, + ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6(), netip.MustParsePrefix("192.168.0.0/24")}, CreatedAt: created, }, dnsConfig: &tailcfg.DNSConfig{}, @@ -150,6 +151,7 @@ func TestTailNode(t *testing.T) { Hostinfo: hiview(tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ tsaddr.AllIPv4(), + tsaddr.AllIPv6(), netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("172.0.0.0/10"), }, diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index aac5a5f3..afc3cf68 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -91,3 +92,12 @@ func (m *Match) SrcsOverlapsPrefixes(prefixes ...netip.Prefix) bool { func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix) } + +// DestsIsTheInternet reports if the destination is equal to "the internet" +// which is a IPSet that represents "autogroup:internet" and is special +// cased for exit nodes. +func (m Match) DestsIsTheInternet() bool { + return m.dests.Equal(util.TheInternet()) || + m.dests.ContainsPrefix(tsaddr.AllIPv4()) || + m.dests.ContainsPrefix(tsaddr.AllIPv6()) +} diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index b849d470..c016fa58 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -782,12 +783,287 @@ func TestReduceNodes(t *testing.T) { got = append(got, v.AsStruct()) } if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { - t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) + t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) + t.Log("Matchers: ") + for _, m := range matchers { + t.Log("\t+", m.DebugString()) + } } }) } } +func TestReduceNodesFromPolicy(t *testing.T) { + n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { + var routes []netip.Prefix + for _, route := range routess { + routes = append(routes, netip.MustParsePrefix(route)) + } + + return &types.Node{ + ID: id, + IPv4: ap(ip), + Hostname: hostname, + User: types.User{Name: username}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: routes, + }, + ApprovedRoutes: routes, + } + } + + type args struct { + } + tests := []struct { + name string + nodes types.Nodes + policy string + node *types.Node + want types.Nodes + wantMatchers int + }{ + { + name: "2788-exit-node-too-visible", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + }, + wantMatchers: 1, + }, + { + name: "2788-exit-node-autogroup:internet", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2788-exit-node-0000-route", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "0.0.0.0/0:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2788-exit-node-::0-route", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "::0/0:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2784-split-exit-node-access", + nodes: types.Nodes{ + n(1, "100.64.0.1", "user", "user"), + n(2, "100.64.0.2", "exit1", "exit", "0.0.0.0/0", "::/0"), + n(3, "100.64.0.3", "exit2", "exit", "0.0.0.0/0", "::/0"), + n(4, "100.64.0.4", "otheruser", "otheruser"), + }, + policy: ` +{ + "hosts": { + "user": "100.64.0.1/32", + "exit1": "100.64.0.2/32", + "exit2": "100.64.0.3/32", + "otheruser": "100.64.0.4/32", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "user" + ], + "dst": [ + "exit1:*" + ] + }, + { + "action": "accept", + "src": [ + "otheruser" + ], + "dst": [ + "exit2:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "user", "user"), + want: types.Nodes{ + n(2, "100.64.0.2", "exit1", "exit", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + } + + for _, tt := range tests { + for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + var pm PolicyManager + var err error + pm, err = pmf(nil, tt.nodes.ViewSlice()) + require.NoError(t, err) + + matchers, err := pm.MatchersForNode(tt.node.View()) + require.NoError(t, err) + assert.Len(t, matchers, tt.wantMatchers) + + gotViews := ReduceNodes( + tt.node.View(), + tt.nodes.ViewSlice(), + matchers, + ) + // Convert views back to nodes for comparison in tests + var got types.Nodes + for _, v := range gotViews.All() { + got = append(got, v.AsStruct()) + } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) + t.Log("Matchers: ") + for _, m := range matchers { + t.Log("\t+", m.DebugString()) + } + } + }) + } + } +} + func TestSSHPolicyRules(t *testing.T) { users := []types.User{ {Name: "user1", Model: gorm.Model{ID: 1}}, diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index abdd4ffb..bb7d089a 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -99,14 +99,16 @@ func (pol *Policy) compileFilterRulesForNode( return nil, ErrInvalidAction } - rule, err := pol.compileACLWithAutogroupSelf(acl, users, node, nodes) + aclRules, err := pol.compileACLWithAutogroupSelf(acl, users, node, nodes) if err != nil { log.Trace().Err(err).Msgf("compiling ACL") continue } - if rule != nil { - rules = append(rules, *rule) + for _, rule := range aclRules { + if rule != nil { + rules = append(rules, *rule) + } } } @@ -115,27 +117,32 @@ func (pol *Policy) compileFilterRulesForNode( // compileACLWithAutogroupSelf compiles a single ACL rule, handling // autogroup:self per-node while supporting all other alias types normally. +// It returns a slice of filter rules because when an ACL has both autogroup:self +// and other destinations, they need to be split into separate rules with different +// source filtering logic. func (pol *Policy) compileACLWithAutogroupSelf( acl ACL, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], -) (*tailcfg.FilterRule, error) { - // Check if any destination uses autogroup:self - hasAutogroupSelfInDst := false +) ([]*tailcfg.FilterRule, error) { + var autogroupSelfDests []AliasWithPorts + var otherDests []AliasWithPorts for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - hasAutogroupSelfInDst = true - break + autogroupSelfDests = append(autogroupSelfDests, dest) + } else { + otherDests = append(otherDests, dest) } } - var srcIPs netipx.IPSetBuilder + protocols, _ := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule + + var resolvedSrcIPs []*netipx.IPSet - // Resolve sources to only include devices from the same user as the target node. for _, src := range acl.Sources { - // autogroup:self is not allowed in sources if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return nil, fmt.Errorf("autogroup:self cannot be used in sources") } @@ -147,92 +154,121 @@ func (pol *Policy) compileACLWithAutogroupSelf( } if ips != nil { - if hasAutogroupSelfInDst { - // Instead of iterating all addresses (which could be millions), - // check each node's IPs against the source set - for _, n := range nodes.All() { - if n.User().ID == node.User().ID && !n.IsTagged() { - // Check if any of this node's IPs are in the source set - for _, nodeIP := range n.IPs() { - if ips.Contains(nodeIP) { - n.AppendToIPSet(&srcIPs) - break // Found this node, move to next - } - } - } - } - } else { - // No autogroup:self in destination, use all resolved sources - srcIPs.AddSet(ips) - } + resolvedSrcIPs = append(resolvedSrcIPs, ips) } } - srcSet, err := srcIPs.IPSet() - if err != nil { - return nil, err + if len(resolvedSrcIPs) == 0 { + return rules, nil } - if srcSet == nil || len(srcSet.Prefixes()) == 0 { - // No sources resolved, skip this rule - return nil, nil //nolint:nilnil - } + // Handle autogroup:self destinations (if any) + if len(autogroupSelfDests) > 0 { + // Pre-filter to same-user untagged devices once - reuse for both sources and destinations + sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { + if n.User().ID == node.User().ID && !n.IsTagged() { + sameUserNodes = append(sameUserNodes, n) + } + } - protocols, _ := acl.Protocol.parseProtocol() - - var destPorts []tailcfg.NetPortRange - - for _, dest := range acl.Destinations { - if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - for _, n := range nodes.All() { - if n.User().ID == node.User().ID && !n.IsTagged() { - for _, port := range dest.Ports { - for _, ip := range n.IPs() { - pr := tailcfg.NetPortRange{ - IP: ip.String(), - Ports: port, - } - destPorts = append(destPorts, pr) + if len(sameUserNodes) > 0 { + // Filter sources to only same-user untagged devices + var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { + for _, n := range sameUserNodes { + // Check if any of this node's IPs are in the source set + for _, nodeIP := range n.IPs() { + if ips.Contains(nodeIP) { + n.AppendToIPSet(&srcIPs) + break } } } } - } else { - ips, err := dest.Resolve(pol, users, nodes) + + srcSet, err := srcIPs.IPSet() if err != nil { - log.Trace().Err(err).Msgf("resolving destination ips") - continue + return nil, err } - if ips == nil { - log.Debug().Msgf("destination resolved to nil ips: %v", dest) - continue - } - - prefixes := ips.Prefixes() - - for _, pref := range prefixes { - for _, port := range dest.Ports { - pr := tailcfg.NetPortRange{ - IP: pref.String(), - Ports: port, + if srcSet != nil && len(srcSet.Prefixes()) > 0 { + var destPorts []tailcfg.NetPortRange + for _, dest := range autogroupSelfDests { + for _, n := range sameUserNodes { + for _, port := range dest.Ports { + for _, ip := range n.IPs() { + destPorts = append(destPorts, tailcfg.NetPortRange{ + IP: ip.String(), + Ports: port, + }) + } + } } - destPorts = append(destPorts, pr) + } + + if len(destPorts) > 0 { + rules = append(rules, &tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcSet), + DstPorts: destPorts, + IPProto: protocols, + }) } } } } - if len(destPorts) == 0 { - // No destinations resolved, skip this rule - return nil, nil //nolint:nilnil + if len(otherDests) > 0 { + var srcIPs netipx.IPSetBuilder + + for _, ips := range resolvedSrcIPs { + srcIPs.AddSet(ips) + } + + srcSet, err := srcIPs.IPSet() + if err != nil { + return nil, err + } + + if srcSet != nil && len(srcSet.Prefixes()) > 0 { + var destPorts []tailcfg.NetPortRange + + for _, dest := range otherDests { + ips, err := dest.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving destination ips") + continue + } + + if ips == nil { + log.Debug().Msgf("destination resolved to nil ips: %v", dest) + continue + } + + prefixes := ips.Prefixes() + + for _, pref := range prefixes { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + if len(destPorts) > 0 { + rules = append(rules, &tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcSet), + DstPorts: destPorts, + IPProto: protocols, + }) + } + } } - return &tailcfg.FilterRule{ - SrcIPs: ipSetToPrefixStringList(srcSet), - DstPorts: destPorts, - IPProto: protocols, - }, nil + return rules, nil } func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { @@ -260,46 +296,30 @@ func (pol *Policy) compileSSHPolicy( var rules []*tailcfg.SSHRule for index, rule := range pol.SSHs { - // Check if any destination uses autogroup:self - hasAutogroupSelfInDst := false + // Separate destinations into autogroup:self and others + // This is needed because autogroup:self requires filtering sources to same-user only, + // while other destinations should use all resolved sources + var autogroupSelfDests []Alias + var otherDests []Alias + for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - hasAutogroupSelfInDst = true - break - } - } - - // If autogroup:self is used, skip tagged nodes - if hasAutogroupSelfInDst && node.IsTagged() { - continue - } - - var dest netipx.IPSetBuilder - for _, src := range rule.Destinations { - // Handle autogroup:self specially - if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - // For autogroup:self, only include the target user's untagged devices - for _, n := range nodes.All() { - if n.User().ID == node.User().ID && !n.IsTagged() { - n.AppendToIPSet(&dest) - } - } + autogroupSelfDests = append(autogroupSelfDests, dst) } else { - ips, err := src.Resolve(pol, users, nodes) - if err != nil { - log.Trace().Caller().Err(err).Msgf("resolving destination ips") - continue - } - dest.AddSet(ips) + otherDests = append(otherDests, dst) } } - destSet, err := dest.IPSet() + // Note: Tagged nodes can't match autogroup:self destinations, but can still match other destinations + + // Resolve sources once - we'll use them differently for each destination type + srcIPs, err := rule.Sources.Resolve(pol, users, nodes) if err != nil { - return nil, err + log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) + continue // Skip this rule if we can't resolve sources } - if !node.InIPSet(destSet) { + if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { continue } @@ -313,50 +333,9 @@ func (pol *Policy) compileSSHPolicy( return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } - var principals []*tailcfg.SSHPrincipal - srcIPs, err := rule.Sources.Resolve(pol, users, nodes) - if err != nil { - log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) - continue // Skip this rule if we can't resolve sources - } - - // If autogroup:self is in destinations, filter sources to same user only - if hasAutogroupSelfInDst { - var filteredSrcIPs netipx.IPSetBuilder - // Instead of iterating all addresses, check each node's IPs - for _, n := range nodes.All() { - if n.User().ID == node.User().ID && !n.IsTagged() { - // Check if any of this node's IPs are in the source set - for _, nodeIP := range n.IPs() { - if srcIPs.Contains(nodeIP) { - n.AppendToIPSet(&filteredSrcIPs) - break // Found this node, move to next - } - } - } - } - - srcIPs, err = filteredSrcIPs.IPSet() - if err != nil { - return nil, err - } - - if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { - // No valid sources after filtering, skip this rule - continue - } - } - - for addr := range util.IPSetAddrIter(srcIPs) { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: addr.String(), - }) - } - userMap := make(map[string]string, len(rule.Users)) if rule.Users.ContainsNonRoot() { userMap["*"] = "=" - // by default, we do not allow root unless explicitly stated userMap["root"] = "" } @@ -366,11 +345,108 @@ func (pol *Policy) compileSSHPolicy( for _, u := range rule.Users.NormalUsers() { userMap[u.String()] = u.String() } - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) + + // Handle autogroup:self destinations (if any) + // Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes + if len(autogroupSelfDests) > 0 && !node.IsTagged() { + // Build destination set for autogroup:self (same-user untagged devices only) + var dest netipx.IPSetBuilder + for _, n := range nodes.All() { + if n.User().ID == node.User().ID && !n.IsTagged() { + n.AppendToIPSet(&dest) + } + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + // Only create rule if this node is in the destination set + if node.InIPSet(destSet) { + // Filter sources to only same-user untagged devices + // Pre-filter to same-user untagged devices for efficiency + sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { + if n.User().ID == node.User().ID && !n.IsTagged() { + sameUserNodes = append(sameUserNodes, n) + } + } + + var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { + // Check if any of this node's IPs are in the source set + for _, nodeIP := range n.IPs() { + if srcIPs.Contains(nodeIP) { + n.AppendToIPSet(&filteredSrcIPs) + break // Found this node, move to next + } + } + } + + filteredSrcSet, err := filteredSrcIPs.IPSet() + if err != nil { + return nil, err + } + + if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 { + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(filteredSrcSet) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + } + } + } + + // Handle other destinations (if any) + if len(otherDests) > 0 { + // Build destination set for other destinations + var dest netipx.IPSetBuilder + for _, dst := range otherDests { + ips, err := dst.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Caller().Err(err).Msgf("resolving destination ips") + continue + } + if ips != nil { + dest.AddSet(ips) + } + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + // Only create rule if this node is in the destination set + if node.InIPSet(destSet) { + // For non-autogroup:self destinations, use all resolved sources (no filtering) + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(srcIPs) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + } + } } return &tailcfg.SSHPolicy{ diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 9f2845ac..37ff8730 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -1339,3 +1339,70 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") } } + +// TestSSHWithAutogroupSelfAndMixedDestinations tests that SSH rules can have both +// autogroup:self and other destinations (like tag:router) in the same rule, and that +// autogroup:self filtering only applies to autogroup:self destinations, not others. +func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + {User: users[0], IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, + {User: users[0], IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, + {User: users[1], IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, + {User: users[1], IPv4: ap("100.64.0.4"), Hostname: "user2-router", ForcedTags: []string{"tag:router"}}, + } + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:router"): Owners{up("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{agp("autogroup:self"), tp("tag:router")}, + Users: []SSHUser{"admin"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // Test 1: Compile for user1's device (should only match autogroup:self destination) + node1 := nodes[0].View() + sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy1) + require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") + + // Verify autogroup:self rule has filtered sources (only same-user devices) + selfRule := sshPolicy1.Rules[0] + require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) + for i, p := range selfRule.Principals { + selfPrincipals[i] = p.NodeIP + } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, + "autogroup:self rule should only include same-user untagged devices") + + // Test 2: Compile for router (should only match tag:router destination) + routerNode := nodes[3].View() // user2-router + sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicyRouter) + require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") + + routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) + for i, p := range routerRule.Principals { + routerPrincipals[i] = p.NodeIP + } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") + require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") + require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)") +} diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 5191368a..bbde136e 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -2,6 +2,7 @@ package v2 import ( "net/netip" + "slices" "testing" "github.com/google/go-cmp/cmp" @@ -439,3 +440,82 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { require.Empty(t, peerMap[node1.ID], "node1 should have no peers (can only reach itself)") require.Empty(t, peerMap[node2.ID], "node2 should have no peers") } + +// When separate ACL rules exist (one with autogroup:self, one with tag:router), +// the autogroup:self rule should not prevent the tag:router rule from working. +// This ensures that autogroup:self doesn't interfere with other ACL rules. +func TestAutogroupSelfWithOtherRules(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "test-1", Email: "test-1@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "test-2", Email: "test-2@example.com"}, + } + + // test-1 has a regular device + test1Node := &types.Node{ + ID: 1, + Hostname: "test-1-device", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: users[0], + UserID: users[0].ID, + Hostinfo: &tailcfg.Hostinfo{}, + } + + // test-2 has a router device with tag:node-router + test2RouterNode := &types.Node{ + ID: 2, + Hostname: "test-2-router", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: users[1], + UserID: users[1].ID, + ForcedTags: []string{"tag:node-router"}, + Hostinfo: &tailcfg.Hostinfo{}, + } + + nodes := types.Nodes{test1Node, test2RouterNode} + + // This matches the exact policy from issue #2838: + // - First rule: autogroup:member -> autogroup:self (allows users to see their own devices) + // - Second rule: group:home -> tag:node-router (should allow group members to see router) + policy := `{ + "groups": { + "group:home": ["test-1@example.com", "test-2@example.com"] + }, + "tagOwners": { + "tag:node-router": ["group:home"] + }, + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + }, + { + "action": "accept", + "src": ["group:home"], + "dst": ["tag:node-router:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // test-1 (in group:home) should see: + // 1. Their own node (from autogroup:self rule) + // 2. The router node (from group:home -> tag:node-router rule) + test1Peers := peerMap[test1Node.ID] + + // Verify test-1 can see the router (group:home -> tag:node-router rule) + require.True(t, slices.ContainsFunc(test1Peers, func(n types.NodeView) bool { + return n.ID() == test2RouterNode.ID + }), "test-1 should see test-2's router via group:home -> tag:node-router rule, even when autogroup:self rule exists (issue #2838)") + + // Verify that test-1 has filter rules (including autogroup:self and tag:node-router access) + rules, err := pm.FilterForNode(test1Node.View()) + require.NoError(t, err) + require.NotEmpty(t, rules, "test-1 should have filter rules from both ACL rules") +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 1d450cb6..c340adc2 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -456,9 +456,9 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet { log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") // Use the node's current routes for primary route update - // SubnetRoutes() returns only the intersection of announced AND approved routes - // We MUST use SubnetRoutes() to maintain the security model - routeChange := s.primaryRoutes.SetRoutes(id, node.SubnetRoutes()...) + // AllApprovedRoutes() returns only the intersection of announced AND approved routes + // We MUST use AllApprovedRoutes() to maintain the security model + routeChange := s.primaryRoutes.SetRoutes(id, node.AllApprovedRoutes()...) if routeChange { c = append(c, change.NodeAdded(id)) @@ -656,7 +656,7 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // Update primary routes table based on SubnetRoutes (intersection of announced and approved). // The primary routes table is what the mapper uses to generate network maps, so updating it // here ensures that route changes are distributed to peers. - routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.SubnetRoutes()...) + routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.AllApprovedRoutes()...) // If routes changed or the changeset isn't already a full update, trigger a policy change // to ensure all nodes get updated network maps @@ -1711,7 +1711,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest } if needsRouteUpdate { - // SetNodeRoutes sets the active/distributed routes, so we must use SubnetRoutes() + // SetNodeRoutes sets the active/distributed routes, so we must use AllApprovedRoutes() // which returns only the intersection of announced AND approved routes. // Using AnnouncedRoutes() would bypass the security model and auto-approve everything. log.Debug(). @@ -1719,9 +1719,9 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest Uint64("node.id", id.Uint64()). Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())). Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())). - Strs("subnetRoutes", util.PrefixesToString(updatedNode.SubnetRoutes())). + Strs("allApprovedRoutes", util.PrefixesToString(updatedNode.AllApprovedRoutes())). Msg("updating node routes for distribution") - nodeRouteChange = s.SetNodeRoutes(id, updatedNode.SubnetRoutes()...) + nodeRouteChange = s.SetNodeRoutes(id, updatedNode.AllApprovedRoutes()...) } _, policyChange, err := s.persistNodeToDB(updatedNode) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 8cf40ced..c6429669 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -269,11 +269,19 @@ func (node *Node) Prefixes() []netip.Prefix { // node has any exit routes enabled. // If none are enabled, it will return nil. func (node *Node) ExitRoutes() []netip.Prefix { - if slices.ContainsFunc(node.SubnetRoutes(), tsaddr.IsExitRoute) { - return tsaddr.ExitRoutes() + var routes []netip.Prefix + + for _, route := range node.AnnouncedRoutes() { + if tsaddr.IsExitRoute(route) && slices.Contains(node.ApprovedRoutes, route) { + routes = append(routes, route) + } } - return nil + return routes +} + +func (node *Node) IsExitNode() bool { + return len(node.ExitRoutes()) > 0 } func (node *Node) IPsAsString() []string { @@ -311,9 +319,16 @@ func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { return true } + // Check if the node has access to routes that might be part of a + // smaller subnet that is served from node2 as a subnet router. if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) { return true } + + // If the dst is "the internet" and node2 is an exit node, allow access. + if matcher.DestsIsTheInternet() && node2.IsExitNode() { + return true + } } return false @@ -440,16 +455,22 @@ func (node *Node) AnnouncedRoutes() []netip.Prefix { return node.Hostinfo.RoutableIPs } -// SubnetRoutes returns the list of routes that the node announces and are approved. +// SubnetRoutes returns the list of routes (excluding exit routes) that the node +// announces and are approved. // -// IMPORTANT: This method is used for internal data structures and should NOT be used -// for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated manually -// with PrimaryRoutes to ensure it includes only routes actively served by the node. -// See the comment in Proto() method and the implementation in grpcv1.go/nodesToProto. +// IMPORTANT: This method is used for internal data structures and should NOT be +// used for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated +// manually with PrimaryRoutes to ensure it includes only routes actively served +// by the node. See the comment in Proto() method and the implementation in +// grpcv1.go/nodesToProto. func (node *Node) SubnetRoutes() []netip.Prefix { var routes []netip.Prefix for _, route := range node.AnnouncedRoutes() { + if tsaddr.IsExitRoute(route) { + continue + } + if slices.Contains(node.ApprovedRoutes, route) { routes = append(routes, route) } @@ -463,6 +484,11 @@ func (node *Node) IsSubnetRouter() bool { return len(node.SubnetRoutes()) > 0 } +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +func (node *Node) AllApprovedRoutes() []netip.Prefix { + return append(node.SubnetRoutes(), node.ExitRoutes()...) +} + func (node *Node) String() string { return node.Hostname } @@ -653,6 +679,7 @@ func (node Node) DebugString() string { fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) + fmt.Fprintf(&sb, "\tExitRoutes: %v\n", node.ExitRoutes()) sb.WriteString("\n") return sb.String() @@ -678,27 +705,11 @@ func (v NodeView) InIPSet(set *netipx.IPSet) bool { } func (v NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool { - if !v.Valid() || !node2.Valid() { + if !v.Valid() { return false } - src := v.IPs() - allowedIPs := node2.IPs() - for _, matcher := range matchers { - if !matcher.SrcsContainsIPs(src...) { - continue - } - - if matcher.DestsContainsIP(allowedIPs...) { - return true - } - - if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) { - return true - } - } - - return false + return v.ж.CanAccess(matchers, node2.AsStruct()) } func (v NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { @@ -730,6 +741,13 @@ func (v NodeView) IsSubnetRouter() bool { return v.ж.IsSubnetRouter() } +func (v NodeView) AllApprovedRoutes() []netip.Prefix { + if !v.Valid() { + return nil + } + return v.ж.AllApprovedRoutes() +} + func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) { if !v.Valid() { return @@ -808,6 +826,13 @@ func (v NodeView) ExitRoutes() []netip.Prefix { return v.ж.ExitRoutes() } +func (v NodeView) IsExitNode() bool { + if !v.Valid() { + return false + } + return v.ж.IsExitNode() +} + // RequestTags returns the ACL tags that the node is requesting. func (v NodeView) RequestTags() []string { if !v.Valid() || !v.Hostinfo().Valid() { diff --git a/integration/acl_test.go b/integration/acl_test.go index 122eeea7..50924891 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1611,37 +1611,170 @@ func TestACLAutogroupTagged(t *testing.T) { } // Test that only devices owned by the same user can access each other and cannot access devices of other users +// Test structure: +// - user1: 2 regular nodes (tests autogroup:self for same-user access) +// - user2: 2 regular nodes (tests autogroup:self for same-user access and cross-user isolation) +// - user-router: 1 node with tag:router-node (tests that autogroup:self doesn't interfere with other rules) func TestACLAutogroupSelf(t *testing.T) { IntegrationSkip(t) - scenario := aclScenario(t, - &policyv2.Policy{ - ACLs: []policyv2.ACL{ - { - Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, - Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), - }, + // Policy with TWO separate ACL rules: + // 1. autogroup:member -> autogroup:self (same-user access) + // 2. group:home -> tag:router-node (router access) + // This tests that autogroup:self doesn't prevent other rules from working + policy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:home"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:router-node"): policyv2.Owners{ + usernameOwner("user-router@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:home")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:router-node"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:router-node")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:home"), tailcfg.PortRangeAny), }, }, }, - 2, - ) + } + + // Create custom scenario: user1 and user2 with regular nodes, plus user-router with tagged node + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) - err := scenario.WaitForTailscaleSyncWithPeerCount(1, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-autogroup-self"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) require.NoError(t, err) + // Add router node for user-router (single shared router node) + networks := scenario.Networks() + var network *dockertest.Network + if len(networks) > 0 { + network = networks[0] + } + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + routerUser, err := scenario.CreateUser("user-router") + require.NoError(t, err) + + authKey, err := scenario.CreatePreAuthKey(routerUser.GetId(), true, false) + require.NoError(t, err) + + // Create router node (tagged with tag:router-node) + routerClient, err := tsic.New( + scenario.Pool(), + "unstable", + tsic.WithCACert(headscale.GetCert()), + tsic.WithHeadscaleName(headscale.GetHostname()), + tsic.WithNetwork(network), + tsic.WithTags([]string{"tag:router-node"}), + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + ) + require.NoError(t, err) + + err = routerClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + err = routerClient.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + err = routerClient.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + userRouterObj := scenario.GetOrCreateUser("user-router") + userRouterObj.Clients[routerClient.Hostname()] = routerClient + user1Clients, err := scenario.GetClients("user1") require.NoError(t, err) - user2Clients, err := scenario.GetClients("user2") require.NoError(t, err) - // Test that user1's devices can access each other + var user1Regular, user2Regular []TailscaleClient for _, client := range user1Clients { - for _, peer := range user1Clients { + status, err := client.Status() + require.NoError(t, err) + if status.Self != nil && (status.Self.Tags == nil || status.Self.Tags.Len() == 0) { + user1Regular = append(user1Regular, client) + } + } + for _, client := range user2Clients { + status, err := client.Status() + require.NoError(t, err) + if status.Self != nil && (status.Self.Tags == nil || status.Self.Tags.Len() == 0) { + user2Regular = append(user2Regular, client) + } + } + + require.NotEmpty(t, user1Regular, "user1 should have regular (untagged) devices") + require.NotEmpty(t, user2Regular, "user2 should have regular (untagged) devices") + require.NotNil(t, routerClient, "router node should exist") + + // Wait for all nodes to sync with their expected peer counts + // With our ACL policy: + // - Regular nodes (user1/user2): 1 same-user regular peer + 1 router-node = 2 peers + // - Router node: 2 user1 regular + 2 user2 regular = 4 peers + for _, client := range user1Regular { + err := client.WaitForPeers(2, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "user1 regular device %s should see 2 peers (1 same-user peer + 1 router)", client.Hostname()) + } + for _, client := range user2Regular { + err := client.WaitForPeers(2, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "user2 regular device %s should see 2 peers (1 same-user peer + 1 router)", client.Hostname()) + } + err = routerClient.WaitForPeers(4, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "router should see 4 peers (all group:home regular nodes)") + + // Test that user1's regular devices can access each other + for _, client := range user1Regular { + for _, peer := range user1Regular { if client.Hostname() == peer.Hostname() { continue } @@ -1656,13 +1789,13 @@ func TestACLAutogroupSelf(t *testing.T) { result, err := client.Curl(url) assert.NoError(c, err) assert.Len(c, result, 13) - }, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device") + }, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device via autogroup:self") } } - // Test that user2's devices can access each other - for _, client := range user2Clients { - for _, peer := range user2Clients { + // Test that user2's regular devices can access each other + for _, client := range user2Regular { + for _, peer := range user2Regular { if client.Hostname() == peer.Hostname() { continue } @@ -1677,36 +1810,64 @@ func TestACLAutogroupSelf(t *testing.T) { result, err := client.Curl(url) assert.NoError(c, err) assert.Len(c, result, 13) - }, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device") + }, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device via autogroup:self") } } - // Test that devices from different users cannot access each other - for _, client := range user1Clients { - for _, peer := range user2Clients { + // Test that user1's regular devices can access router-node + for _, client := range user1Regular { + fqdn, err := routerClient.FQDN() + require.NoError(t, err) + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user1) to %s (router-node) - should SUCCEED", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.NotEmpty(c, result, "user1 should be able to access router-node via group:home -> tag:router-node rule") + }, 10*time.Second, 200*time.Millisecond, "user1 device should reach router-node (proves autogroup:self doesn't interfere)") + } + + // Test that user2's regular devices can access router-node + for _, client := range user2Regular { + fqdn, err := routerClient.FQDN() + require.NoError(t, err) + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user2) to %s (router-node) - should SUCCEED", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.NotEmpty(c, result, "user2 should be able to access router-node via group:home -> tag:router-node rule") + }, 10*time.Second, 200*time.Millisecond, "user2 device should reach router-node (proves autogroup:self doesn't interfere)") + } + + // Test that devices from different users cannot access each other's regular devices + for _, client := range user1Regular { + for _, peer := range user2Regular { fqdn, err := peer.FQDN() require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) - t.Logf("url from %s (user1) to %s (user2) - should FAIL", client.Hostname(), fqdn) + t.Logf("url from %s (user1) to %s (user2 regular) - should FAIL", client.Hostname(), fqdn) result, err := client.Curl(url) - assert.Empty(t, result, "user1 should not be able to access user2's devices with autogroup:self") - assert.Error(t, err, "connection from user1 to user2 should fail") + assert.Empty(t, result, "user1 should not be able to access user2's regular devices (autogroup:self isolation)") + assert.Error(t, err, "connection from user1 to user2 regular device should fail") } } - for _, client := range user2Clients { - for _, peer := range user1Clients { + for _, client := range user2Regular { + for _, peer := range user1Regular { fqdn, err := peer.FQDN() require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) - t.Logf("url from %s (user2) to %s (user1) - should FAIL", client.Hostname(), fqdn) + t.Logf("url from %s (user2) to %s (user1 regular) - should FAIL", client.Hostname(), fqdn) result, err := client.Curl(url) - assert.Empty(t, result, "user2 should not be able to access user1's devices with autogroup:self") - assert.Error(t, err, "connection from user2 to user1 should fail") + assert.Empty(t, result, "user2 should not be able to access user1's regular devices (autogroup:self isolation)") + assert.Error(t, err, "connection from user2 to user1 regular device should fail") } } } diff --git a/integration/general_test.go b/integration/general_test.go index 2432db9c..c68768f7 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -819,6 +819,104 @@ func TestExpireNode(t *testing.T) { } } +// TestSetNodeExpiryInFuture tests setting arbitrary expiration date +// New expiration date should be stored in the db and propagated to all peers +func TestSetNodeExpiryInFuture(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenodefuture")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + targetExpiry := time.Now().Add(2 * time.Hour).Round(time.Second).UTC() + + result, err := headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--expiry", targetExpiry.Format(time.RFC3339), + }, + ) + require.NoError(t, err) + + var node v1.Node + err = json.Unmarshal([]byte(result), &node) + require.NoError(t, err) + + require.True(t, node.GetExpiry().AsTime().After(time.Now())) + require.WithinDuration(t, targetExpiry, node.GetExpiry().AsTime(), 2*time.Second) + + var nodeKey key.NodePublic + err = nodeKey.UnmarshalText([]byte(node.GetNodeKey())) + require.NoError(t, err) + + for _, client := range allClients { + if client.Hostname() == node.GetName() { + continue + } + + assert.EventuallyWithT( + t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + + peerStatus, ok := status.Peer[nodeKey] + assert.True(ct, ok, "node key should be present in peer list") + + if !ok { + return + } + + assert.NotNil(ct, peerStatus.KeyExpiry) + assert.NotNil(ct, peerStatus.Expired) + + if peerStatus.KeyExpiry != nil { + assert.WithinDuration( + ct, + targetExpiry, + *peerStatus.KeyExpiry, + 5*time.Second, + "node %q should have key expiry near the requested future time", + peerStatus.HostName, + ) + + assert.Truef( + ct, + peerStatus.KeyExpiry.After(time.Now()), + "node %q should have a key expiry timestamp in the future", + peerStatus.HostName, + ) + } + + assert.Falsef( + ct, + peerStatus.Expired, + "node %q should not be marked as expired", + peerStatus.HostName, + ) + }, 3*time.Minute, 5*time.Second, "Waiting for future expiry to propagate", + ) + } +} + func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t) diff --git a/proto/headscale/v1/node.proto b/proto/headscale/v1/node.proto index 89d2c347..fb074008 100644 --- a/proto/headscale/v1/node.proto +++ b/proto/headscale/v1/node.proto @@ -82,7 +82,10 @@ message DeleteNodeRequest { uint64 node_id = 1; } message DeleteNodeResponse {} -message ExpireNodeRequest { uint64 node_id = 1; } +message ExpireNodeRequest { + uint64 node_id = 1; + google.protobuf.Timestamp expiry = 2; +} message ExpireNodeResponse { Node node = 1; }