diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 96a2253b..511e19bb 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -383,6 +383,12 @@ type AutoGroup string const ( AutoGroupInternet AutoGroup = "autogroup:internet" + AutoGroupNonRoot AutoGroup = "autogroup:nonroot" + + // These are not yet implemented. + AutoGroupSelf AutoGroup = "autogroup:self" + AutoGroupMember AutoGroup = "autogroup:member" + AutoGroupTagged AutoGroup = "autogroup:tagged" ) var autogroups = []AutoGroup{AutoGroupInternet} @@ -915,6 +921,99 @@ type Policy struct { SSHs []SSH `json:"ssh"` } +var ( + autogroupForSrc = []AutoGroup{} + autogroupForDst = []AutoGroup{AutoGroupInternet} + autogroupForSSHSrc = []AutoGroup{} + autogroupForSSHDst = []AutoGroup{} + autogroupForSSHUser = []AutoGroup{AutoGroupNonRoot} + autogroupNotSupported = []AutoGroup{AutoGroupSelf, AutoGroupMember, AutoGroupTagged} +) + +func validateAutogroupSupported(ag *AutoGroup) error { + if ag == nil { + return nil + } + + if slices.Contains(autogroupNotSupported, *ag) { + return fmt.Errorf("autogroup %q is not supported in headscale", *ag) + } + + return nil +} + +func validateAutogroupForSrc(src *AutoGroup) error { + if src == nil { + return nil + } + + if src.Is(AutoGroupInternet) { + return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + } + + if !slices.Contains(autogroupForSrc, *src) { + return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc) + } + + return nil +} + +func validateAutogroupForDst(dst *AutoGroup) error { + if dst == nil { + return nil + } + + if !slices.Contains(autogroupForDst, *dst) { + return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst) + } + + return nil +} + +func validateAutogroupForSSHSrc(src *AutoGroup) error { + if src == nil { + return nil + } + + if src.Is(AutoGroupInternet) { + return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + } + + if !slices.Contains(autogroupForSSHSrc, *src) { + return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc) + } + + return nil +} + +func validateAutogroupForSSHDst(dst *AutoGroup) error { + if dst == nil { + return nil + } + + if dst.Is(AutoGroupInternet) { + return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + } + + if !slices.Contains(autogroupForSSHDst, *dst) { + return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst) + } + + return nil +} + +func validateAutogroupForSSHUser(user *AutoGroup) error { + if user == nil { + return nil + } + + if !slices.Contains(autogroupForSSHUser, *user) { + return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser) + } + + return nil +} + // validate reports if there are any errors in a policy after // the unmarshaling process. // It runs through all rules and checks if there are any inconsistencies @@ -938,20 +1037,70 @@ func (p *Policy) validate() error { } case *AutoGroup: ag := src.(*AutoGroup) - if ag.Is(AutoGroupInternet) { - errs = append(errs, fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)) + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSrc(ag); err != nil { + errs = append(errs, err) + continue + } + } + } + + for _, dst := range acl.Destinations { + switch dst.Alias.(type) { + case *Host: + h := dst.Alias.(*Host) + if !p.Hosts.exist(*h) { + errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + } + case *AutoGroup: + ag := dst.Alias.(*AutoGroup) + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForDst(ag); err != nil { + errs = append(errs, err) + continue } } } } for _, ssh := range p.SSHs { + if ssh.Action != "accept" && ssh.Action != "check" { + errs = append(errs, fmt.Errorf("SSH action %q is not valid, must be accept or check", ssh.Action)) + } + + for _, user := range ssh.Users { + if strings.HasPrefix(string(user), "autogroup:") { + maybeAuto := AutoGroup(user) + if err := validateAutogroupForSSHUser(&maybeAuto); err != nil { + errs = append(errs, err) + continue + } + } + } + for _, src := range ssh.Sources { switch src.(type) { case *AutoGroup: ag := src.(*AutoGroup) - if ag.Is(AutoGroupInternet) { - errs = append(errs, fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)) + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSSHSrc(ag); err != nil { + errs = append(errs, err) + continue } } } @@ -959,8 +1108,14 @@ func (p *Policy) validate() error { switch dst.(type) { case *AutoGroup: ag := dst.(*AutoGroup) - if ag.Is(AutoGroupInternet) { - errs = append(errs, fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)) + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSSHDst(ag); err != nil { + errs = append(errs, err) + continue } } } @@ -997,7 +1152,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { *a = make([]Alias, len(aliases)) for i, alias := range aliases { switch alias.Alias.(type) { - case *Username, *Group, *Tag, *AutoGroup: + case *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: return fmt.Errorf("type %T not supported", alias.Alias) @@ -1042,8 +1197,8 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { // so we will leave it in as there is no other option // to dynamically give all access // https://tailscale.com/kb/1193/tailscale-ssh#dst - Asterix, - *Group: + // TODO(kradalby): remove this when we support autogroup:tagged and autogroup:member + Asterix: (*a)[i] = alias.Alias default: return fmt.Errorf("type %T not supported", alias.Alias)