From c923f461abbcdc81b360ce59d9d251e9f83046bb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 1 May 2025 15:30:52 +0300 Subject: [PATCH] error on undefined host in policy (#2490) * add testcases Signed-off-by: Kristoffer Dalby * policy/v2: add validate to do post marshal validation Signed-off-by: Kristoffer Dalby --------- Signed-off-by: Kristoffer Dalby --- hscontrol/policy/v2/filter_test.go | 2 +- hscontrol/policy/v2/policy.go | 4 +- hscontrol/policy/v2/types.go | 56 ++++++++++++++++++++++++--- hscontrol/policy/v2/types_test.go | 61 +++++++++++++++++++++++++++++- 4 files changed, 113 insertions(+), 10 deletions(-) diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index e0b12520..b5f08164 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -336,7 +336,7 @@ func TestParsing(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pol, err := policyFromBytes([]byte(tt.acl)) + pol, err := unmarshalPolicy([]byte(tt.acl)) if tt.wantErr && err == nil { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 2bc04dbc..ec4b7737 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -42,7 +42,7 @@ type PolicyManager struct { // It returns an error if the policy file is invalid. // The policy manager will update the filter rules based on the users and nodes. func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) { - policy, err := policyFromBytes(b) + policy, err := unmarshalPolicy(b) if err != nil { return nil, fmt.Errorf("parsing policy: %w", err) } @@ -137,7 +137,7 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { return false, nil } - pol, err := policyFromBytes(polB) + pol, err := unmarshalPolicy(polB) if err != nil { return false, fmt.Errorf("parsing policy: %w", err) } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 55376b97..0e292f3a 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -532,7 +532,7 @@ Please check the format and try again.`, vs) type AliasEnc struct{ Alias } func (ve *AliasEnc) UnmarshalJSON(b []byte) error { - ptr, err := unmarshalPointer[Alias]( + ptr, err := unmarshalPointer( b, parseAlias, ) @@ -639,7 +639,7 @@ Please check the format and try again.`, s) type AutoApproverEnc struct{ AutoApprover } func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { - ptr, err := unmarshalPointer[AutoApprover]( + ptr, err := unmarshalPointer( b, parseAutoApprover, ) @@ -659,7 +659,7 @@ type Owner interface { type OwnerEnc struct{ Owner } func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { - ptr, err := unmarshalPointer[Owner]( + ptr, err := unmarshalPointer( b, parseOwner, ) @@ -769,6 +769,11 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { return nil } +func (h Hosts) exist(name Host) bool { + _, ok := h[name] + return ok +} + // TagOwners are a map of Tag to a list of the UserEntities that own the tag. type TagOwners map[Tag]Owners @@ -902,6 +907,39 @@ type Policy struct { SSHs []SSH `json:"ssh"` } +// validate reports if there are any errors in a policy after +// the unmarshaling process. +// It runs through all rules and checks if there are any inconsistencies +// in the policy that needs to be addressed before it can be used. +func (p *Policy) validate() error { + if p == nil { + panic("passed nil policy") + } + + // All errors are collected and presented to the user, + // when adding more validation, please add to the list of errors. + var errs []error + + for _, acl := range p.ACLs { + for _, src := range acl.Sources { + switch src.(type) { + case *Host: + h := src.(*Host) + if !p.Hosts.exist(*h) { + errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + } + } + } + } + + if len(errs) > 0 { + return multierr.New(errs...) + } + + p.validated = true + return nil +} + // SSH controls who can ssh into which machines. type SSH struct { Action string `json:"action"` // TODO(kradalby): add strict type @@ -986,7 +1024,10 @@ func (u SSHUser) String() string { return string(u) } -func policyFromBytes(b []byte) (*Policy, error) { +// unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct. +// In addition to unmarshalling, it will also validate the policy. +// This is the only entrypoint of reading a policy from a file or other source. +func unmarshalPolicy(b []byte) (*Policy, error) { if b == nil || len(b) == 0 { return nil, nil } @@ -1000,11 +1041,14 @@ func policyFromBytes(b []byte) (*Policy, error) { ast.Standardize() acl := ast.Pack() - err = json.Unmarshal(acl, &policy) - if err != nil { + if err = json.Unmarshal(acl, &policy); err != nil { return nil, fmt.Errorf("parsing policy from bytes: %w", err) } + if err := policy.validate(); err != nil { + return nil, err + } + return &policy, nil } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 2218685e..6a89efd3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -361,6 +361,65 @@ func TestUnmarshalPolicy(t *testing.T) { `, wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`, }, + { + name: "undefined-hostname-errors-2490", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "user1" + ], + "dst": [ + "user1:*" + ] + } + ] +} +`, + wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "defined-hostname-does-not-err-2490", + input: ` +{ + "hosts": { + "user1": "100.100.100.100", + }, + "acls": [ + { + "action": "accept", + "src": [ + "user1" + ], + "dst": [ + "user1:*" + ] + } + ] +} +`, + want: &Policy{ + Hosts: Hosts{ + "user1": Prefix(mp("100.100.100.100/32")), + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + hp("user1"), + }, + Destinations: []AliasWithPorts{ + { + Alias: hp("user1"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, } cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool { @@ -370,7 +429,7 @@ func TestUnmarshalPolicy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - policy, err := policyFromBytes([]byte(tt.input)) + policy, err := unmarshalPolicy([]byte(tt.input)) if tt.wantErr == "" { if err != nil { t.Fatalf("got %v; want no error", err)