1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-04-30 01:19:47 +02:00

Prevent race condition between rules and matchers

This commit is contained in:
Aras Ergus 2025-04-02 08:16:14 +02:00
parent 42f71b9c06
commit d1812eeec9
No known key found for this signature in database
GPG Key ID: 06334F046D945E11
8 changed files with 13 additions and 25 deletions

View File

@ -40,7 +40,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Write(pol) w.Write(pol)
})) }))
debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { debug.Handle("filter", "Current filter", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
filter := h.polMan.Filter() filter, _ := h.polMan.Filter()
filterJSON, err := json.MarshalIndent(filter, "", " ") filterJSON, err := json.MarshalIndent(filter, "", " ")
if err != nil { if err != nil {

View File

@ -519,7 +519,7 @@ func appendPeerChanges(
changed types.Nodes, changed types.Nodes,
cfg *types.Config, cfg *types.Config,
) error { ) error {
filter := polMan.Filter() filter, matchers := polMan.Filter()
sshPolicy, err := polMan.SSHPolicy(node) sshPolicy, err := polMan.SSHPolicy(node)
if err != nil { if err != nil {
@ -529,7 +529,6 @@ func appendPeerChanges(
// If there are filter rules present, see if there are any nodes that cannot // If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers. // access each-other at all and remove them from the peers.
if len(filter) > 0 { if len(filter) > 0 {
matchers := polMan.Matchers()
changed = policy.FilterNodesByACL(node, changed, matchers) changed = policy.FilterNodesByACL(node, changed, matchers)
} }

View File

@ -16,9 +16,8 @@ var (
) )
type PolicyManager interface { type PolicyManager interface {
Filter() []tailcfg.FilterRule // Filter returns the current filter rules for the entire tailnet and the associated matchers.
// Matchers returns the matchers for the current filter rules. Filter() ([]tailcfg.FilterRule, []matcher.Match)
Matchers() []matcher.Match
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error) SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error) SetUsers(users []types.User) (bool, error)

View File

@ -770,7 +770,7 @@ func TestReduceFilterRules(t *testing.T) {
var err error var err error
pm, err = pmf(users, append(tt.peers, tt.node)) pm, err = pmf(users, append(tt.peers, tt.node))
require.NoError(t, err) require.NoError(t, err)
got := pm.Filter() got, _ := pm.Filter()
got = ReduceFilterRules(tt.node, got) got = ReduceFilterRules(tt.node, got)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {

View File

@ -87,15 +87,10 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil return true, nil
} }
func (pm *PolicyManager) Filter() []tailcfg.FilterRule { func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
return pm.filter return pm.filter, matcher.MatchesFromFilterRules(pm.filter)
}
func (pm *PolicyManager) Matchers() []matcher.Match {
filter := pm.Filter()
return matcher.MatchesFromFilterRules(filter)
} }
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {

View File

@ -150,7 +150,8 @@ func TestPolicySetChange(t *testing.T) {
assert.Equal(t, tt.wantNodesChange, change) assert.Equal(t, tt.wantNodesChange, change)
} }
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { filter, _ := pm.Filter()
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
} }
}) })

View File

@ -147,17 +147,11 @@ func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
return pm.updateLocked() return pm.updateLocked()
} }
// Filter returns the current filter rules for the entire tailnet. // Filter returns the current filter rules for the entire tailnet and the associated matchers.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule { func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
return pm.filter return pm.filter, pm.matchers
}
func (pm *PolicyManager) Matchers() []matcher.Match {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.matchers
} }
// SetUsers updates the users in the policy manager and updates the filter rules. // SetUsers updates the users in the policy manager and updates the filter rules.

View File

@ -47,7 +47,7 @@ func TestPolicyManager(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes) pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err) require.NoError(t, err)
filter := pm.Filter() filter, _ := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" { if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff) t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
} }