diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 6dfacd91..2de2e8dd 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,7 +19,7 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 9c97e39c..536c86f3 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) { "root": "", }, Action: &tailcfg.SSHAction{ - Accept: true, + Accept: false, SessionDuration: 24 * time.Hour, + HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER", AllowAgentForwarding: true, AllowLocalPortForwarding: true, AllowRemotePortForwarding: true, @@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) { require.NoError(t, err) - got, err := pm.SSHPolicy(tt.targetNode.View()) + got, err := pm.SSHPolicy("unused-url", tt.targetNode.View()) require.NoError(t, err) if diff := cmp.Diff(tt.wantSSH, got); diff != "" { diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index d75e1914..526a0cb1 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -12,7 +12,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" - "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/views" @@ -349,6 +348,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction { //nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( + baseURL string, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], @@ -392,14 +392,11 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction - // HACK HACK HACK - serverURL := viper.GetString("server_url") - switch rule.Action { case SSHActionAccept: action = sshAccept case SSHActionCheck: - action = sshCheck(serverURL, time.Duration(rule.CheckPeriod)) + action = sshCheck(baseURL, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index cdf7c131..1c15f732 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { require.NoError(t, err) // Compile SSH policy - sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice()) + sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice()) require.NoError(t, err) if tt.wantEmpty { @@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -704,8 +704,11 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } assert.Equal(t, expectedUsers, rule.SSHUsers) - // Verify check action with session duration - assert.True(t, rule.Action.Accept) + // Verify check action: Accept is false, HoldAndDelegate is set + assert.False(t, rule.Action.Accept) + assert.False(t, rule.Action.Reject) + assert.NotEmpty(t, rule.Action.HoldAndDelegate) + assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/") assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) } @@ -756,7 +759,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { require.NoError(t, err) // Test SSH policy compilation for node2 (owned by user2, who is in the group) - sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -806,7 +809,7 @@ func TestSSHJSONSerialization(t *testing.T) { err := policy.validate() require.NoError(t, err) - sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) @@ -1413,7 +1416,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user1's first node node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1432,7 +1435,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for user2's first node node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy2) require.Len(t, sshPolicy2.Rules, 1) @@ -1451,7 +1454,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() - sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy3 != nil { @@ -1491,7 +1494,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user1's node: should allow SSH from user1's devices node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1508,7 +1511,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1551,7 +1554,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user1's node: should allow SSH from user1's devices only (not user2's) node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1568,7 +1571,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1610,7 +1613,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For untagged node: should only get principals from other untagged nodes node1 := nodes[0].View() - sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy) require.Len(t, sshPolicy.Rules, 1) @@ -1628,7 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { // For tagged node: should get no SSH rules node3 := nodes[2].View() - sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice()) require.NoError(t, err) if sshPolicy2 != nil { @@ -1671,7 +1674,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 1: Compile for user1's device (should only match autogroup:self destination) node1 := nodes[0].View() - sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicy1) require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") @@ -1690,7 +1693,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Test 2: Compile for router (should only match tag:router destination) routerNode := nodes[3].View() // user2-router - sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, sshPolicyRouter) require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 74b7ba6a..744f52c7 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { return true, nil } -func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { +func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err return sshPol, nil } - sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes) if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index e421d5bd..f546f7a4 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -851,7 +851,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha // SSHPolicy returns the SSH access policy for a node. func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - return s.polMan.SSHPolicy(node) + return s.polMan.SSHPolicy(s.cfg.ServerURL, node) } // Filter returns the current network filter rules and matches.