1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-09-25 17:51:11 +02:00

policy: make SSH action strict types

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-09-10 11:20:49 +02:00
parent 64ad05f1c5
commit 9a37717726
No known key found for this signature in database
3 changed files with 36 additions and 8 deletions

View File

@ -1875,7 +1875,7 @@ func TestSSHPolicyRules(t *testing.T) {
] ]
}`, }`,
expectErr: true, expectErr: true,
errorMessage: `SSH action "invalid" is not valid, must be accept or check`, errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
}, },
{ {
name: "invalid-check-period", name: "invalid-check-period",

View File

@ -126,9 +126,9 @@ func (pol *Policy) compileSSHPolicy(
var action tailcfg.SSHAction var action tailcfg.SSHAction
switch rule.Action { switch rule.Action {
case "accept": case SSHActionAccept:
action = sshAction(true, 0) action = sshAction(true, 0)
case "check": case SSHActionCheck:
action = sshAction(true, time.Duration(rule.CheckPeriod)) action = sshAction(true, time.Duration(rule.CheckPeriod))
default: default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)

View File

@ -1249,6 +1249,14 @@ const (
ActionAccept Action = "accept" ActionAccept Action = "accept"
) )
// SSHAction represents the action to take for an SSH rule.
type SSHAction string
const (
SSHActionAccept SSHAction = "accept"
SSHActionCheck SSHAction = "check"
)
// String returns the string representation of the Action. // String returns the string representation of the Action.
func (a Action) String() string { func (a Action) String() string {
return string(a) return string(a)
@ -1271,6 +1279,30 @@ func (a Action) MarshalJSON() ([]byte, error) {
return json.Marshal(string(a)) return json.Marshal(string(a))
} }
// String returns the string representation of the SSHAction.
func (a SSHAction) String() string {
return string(a)
}
// UnmarshalJSON implements JSON unmarshaling for SSHAction.
func (a *SSHAction) UnmarshalJSON(b []byte) error {
str := strings.Trim(string(b), `"`)
switch str {
case "accept":
*a = SSHActionAccept
case "check":
*a = SSHActionCheck
default:
return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str)
}
return nil
}
// MarshalJSON implements JSON marshaling for SSHAction.
func (a SSHAction) MarshalJSON() ([]byte, error) {
return json.Marshal(string(a))
}
// Protocol represents a network protocol with its IANA number and descriptions. // Protocol represents a network protocol with its IANA number and descriptions.
type Protocol string type Protocol string
@ -1691,10 +1723,6 @@ func (p *Policy) validate() error {
} }
for _, ssh := range p.SSHs { for _, ssh := range p.SSHs {
if ssh.Action != "accept" && ssh.Action != "check" {
errs = append(errs, fmt.Errorf("SSH action %q is not valid, must be accept or check", ssh.Action))
}
for _, user := range ssh.Users { for _, user := range ssh.Users {
if strings.HasPrefix(string(user), "autogroup:") { if strings.HasPrefix(string(user), "autogroup:") {
maybeAuto := AutoGroup(user) maybeAuto := AutoGroup(user)
@ -1808,7 +1836,7 @@ func (p *Policy) validate() error {
// SSH controls who can ssh into which machines. // SSH controls who can ssh into which machines.
type SSH struct { type SSH struct {
Action string `json:"action"` Action SSHAction `json:"action"`
Sources SSHSrcAliases `json:"src"` Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"` Destinations SSHDstAliases `json:"dst"`
Users []SSHUser `json:"users"` Users []SSHUser `json:"users"`