diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index f19ac3d3..353fd2c1 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1875,7 +1875,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: `SSH action "invalid" is not valid, must be accept or check`, + errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, }, { name: "invalid-check-period", diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 8dac82ae..17d4c16e 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -126,9 +126,9 @@ func (pol *Policy) compileSSHPolicy( var action tailcfg.SSHAction switch rule.Action { - case "accept": + case SSHActionAccept: action = sshAction(true, 0) - case "check": + case SSHActionCheck: action = sshAction(true, time.Duration(rule.CheckPeriod)) default: return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index f400b9c5..c16c1349 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1249,6 +1249,14 @@ const ( ActionAccept Action = "accept" ) +// SSHAction represents the action to take for an SSH rule. +type SSHAction string + +const ( + SSHActionAccept SSHAction = "accept" + SSHActionCheck SSHAction = "check" +) + // String returns the string representation of the Action. func (a Action) String() string { return string(a) @@ -1271,6 +1279,30 @@ func (a Action) MarshalJSON() ([]byte, error) { return json.Marshal(string(a)) } +// String returns the string representation of the SSHAction. +func (a SSHAction) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for SSHAction. +func (a *SSHAction) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = SSHActionAccept + case "check": + *a = SSHActionCheck + default: + return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + } + return nil +} + +// MarshalJSON implements JSON marshaling for SSHAction. +func (a SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + // Protocol represents a network protocol with its IANA number and descriptions. type Protocol string @@ -1691,10 +1723,6 @@ func (p *Policy) validate() error { } for _, ssh := range p.SSHs { - if ssh.Action != "accept" && ssh.Action != "check" { - errs = append(errs, fmt.Errorf("SSH action %q is not valid, must be accept or check", ssh.Action)) - } - for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) @@ -1808,7 +1836,7 @@ func (p *Policy) validate() error { // SSH controls who can ssh into which machines. type SSH struct { - Action string `json:"action"` + Action SSHAction `json:"action"` Sources SSHSrcAliases `json:"src"` Destinations SSHDstAliases `json:"dst"` Users []SSHUser `json:"users"`