1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-07 20:04:00 +01:00

policy: patch serverURL into ssh policy

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2026-02-10 13:59:12 +01:00
parent 0291fa8644
commit d1364194ef
No known key found for this signature in database
6 changed files with 29 additions and 28 deletions

View File

@ -19,7 +19,7 @@ type PolicyManager interface {
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy(pol []byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)

View File

@ -1188,8 +1188,9 @@ func TestSSHPolicyRules(t *testing.T) {
"root": "",
},
Action: &tailcfg.SSHAction{
Accept: true,
Accept: false,
SessionDuration: 24 * time.Hour,
HoldAndDelegate: "unused-url/machine/ssh/action/from/$SRC_NODE_ID/to/$DST_NODE_ID?ssh_user=$SSH_USER&local_user=$LOCAL_USER",
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
AllowRemotePortForwarding: true,
@ -1476,7 +1477,7 @@ func TestSSHPolicyRules(t *testing.T) {
require.NoError(t, err)
got, err := pm.SSHPolicy(tt.targetNode.View())
got, err := pm.SSHPolicy("unused-url", tt.targetNode.View())
require.NoError(t, err)
if diff := cmp.Diff(tt.wantSSH, got); diff != "" {

View File

@ -12,7 +12,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
@ -349,6 +348,7 @@ func sshCheck(baseURL string, duration time.Duration) tailcfg.SSHAction {
//nolint:gocyclo // complex SSH policy compilation logic
func (pol *Policy) compileSSHPolicy(
baseURL string,
users types.Users,
node types.NodeView,
nodes views.Slice[types.NodeView],
@ -392,14 +392,11 @@ func (pol *Policy) compileSSHPolicy(
var action tailcfg.SSHAction
// HACK HACK HACK
serverURL := viper.GetString("server_url")
switch rule.Action {
case SSHActionAccept:
action = sshAccept
case SSHActionCheck:
action = sshCheck(serverURL, time.Duration(rule.CheckPeriod))
action = sshCheck(baseURL, time.Duration(rule.CheckPeriod))
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}

View File

@ -615,7 +615,7 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
require.NoError(t, err)
// Compile SSH policy
sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice())
sshPolicy, err := tt.policy.compileSSHPolicy("unused-server-url", users, tt.targetNode.View(), nodes.ViewSlice())
require.NoError(t, err)
if tt.wantEmpty {
@ -691,7 +691,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, nodeTaggedServer.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -704,8 +704,11 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
}
assert.Equal(t, expectedUsers, rule.SSHUsers)
// Verify check action with session duration
assert.True(t, rule.Action.Accept)
// Verify check action: Accept is false, HoldAndDelegate is set
assert.False(t, rule.Action.Accept)
assert.False(t, rule.Action.Reject)
assert.NotEmpty(t, rule.Action.HoldAndDelegate)
assert.Contains(t, rule.Action.HoldAndDelegate, "/machine/ssh/action/")
assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration)
}
@ -756,7 +759,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
require.NoError(t, err)
// Test SSH policy compilation for node2 (owned by user2, who is in the group)
sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node2.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -806,7 +809,7 @@ func TestSSHJSONSerialization(t *testing.T) {
err := policy.validate()
require.NoError(t, err)
sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node.View(), nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
@ -1413,7 +1416,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user1's first node
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1432,7 +1435,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for user2's first node
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy2)
require.Len(t, sshPolicy2.Rules, 1)
@ -1451,7 +1454,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// Test for tagged node (should have no SSH rules)
node5 := nodes[4].View()
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
sshPolicy3, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy3 != nil {
@ -1491,7 +1494,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user1's node: should allow SSH from user1's devices
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1508,7 +1511,7 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
// For user2's node: should have no rules (user1's devices can't match user2's self)
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1551,7 +1554,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user1's node: should allow SSH from user1's devices only (not user2's)
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1568,7 +1571,7 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
// For user3's node: should have no rules (not in group:admins)
node5 := nodes[4].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1610,7 +1613,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For untagged node: should only get principals from other untagged nodes
node1 := nodes[0].View()
sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy)
require.Len(t, sshPolicy.Rules, 1)
@ -1628,7 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
// For tagged node: should get no SSH rules
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
sshPolicy2, err := policy.compileSSHPolicy("unused-server-url", users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
@ -1671,7 +1674,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 1: Compile for user1's device (should only match autogroup:self destination)
node1 := nodes[0].View()
sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice())
sshPolicy1, err := policy.compileSSHPolicy("unused-server-url", users, node1, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicy1)
require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)")
@ -1690,7 +1693,7 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Test 2: Compile for router (should only match tag:router destination)
routerNode := nodes[3].View() // user2-router
sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice())
sshPolicyRouter, err := policy.compileSSHPolicy("unused-server-url", users, routerNode, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, sshPolicyRouter)
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")

View File

@ -222,7 +222,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil
}
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
func (pm *PolicyManager) SSHPolicy(baseURL string, node types.NodeView) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@ -230,7 +230,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
return sshPol, nil
}
sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
sshPol, err := pm.pol.compileSSHPolicy(baseURL, pm.users, node, pm.nodes)
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}

View File

@ -851,7 +851,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
// SSHPolicy returns the SSH access policy for a node.
func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
return s.polMan.SSHPolicy(node)
return s.polMan.SSHPolicy(s.cfg.ServerURL, node)
}
// Filter returns the current network filter rules and matches.