diff --git a/hscontrol/auth.go b/hscontrol/auth.go index fd1b231b..ee301242 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -272,13 +272,15 @@ func (h *Headscale) waitForFollowup( select { case <-ctx.Done(): return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) - case node := <-reg.WaitForRegistration(): - if !node.Valid() { - // registration is expired in the cache, instruct the client to try a new registration - return h.reqToNewRegisterResponse(req, machineKey) - } + case verdict := <-reg.WaitForAuth(): + if verdict.Accept() { + if !verdict.Node.Valid() { + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) + } - return nodeToRegisterResponse(node), nil + return nodeToRegisterResponse(verdict.Node), nil + } } } diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 8215b07c..321b55fa 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -692,7 +692,7 @@ func TestAuthenticationFlows(t *testing.T) { user := app.state.CreateUserForTest("followup-user") node := app.state.CreateNodeForTest(user, "followup-success-node") - nodeToRegister.FinishRegistration(node.View()) + nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()}) }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil @@ -1348,7 +1348,7 @@ func TestAuthenticationFlows(t *testing.T) { // Simulate registration that returns empty NodeView (cache expired during auth) go func() { - nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry + nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry }() return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index eb927750..1ec3eedf 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") // ErrNodeNameNotUnique is returned when a node name is not unique. var ErrNodeNameNotUnique = errors.New("node name is not unique") +// ErrRegistrationExpired is returned when a registration has expired. +var ErrRegistrationExpired = errors.New("registration expired") + // State manages Headscale's core state, coordinating between database, policy management, // IP allocation, and DERP routing. All methods are thread-safe. type State struct { @@ -110,7 +113,7 @@ func NewState(cfg *types.Config) (*State, error) { authCache.OnEvicted( func(id types.AuthID, rn types.AuthRequest) { - rn.FinishRegistration(types.NodeView{}) + rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) }, ) @@ -1625,7 +1628,7 @@ func (s *State) HandleNodeFromAuthPath( } // Signal to waiting clients - regEntry.FinishRegistration(finalNode) + regEntry.FinishAuth(types.AuthVerdict{Node: finalNode}) // Delete from registration cache s.authCache.Delete(authID) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 66bbf619..891969d3 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -210,14 +210,14 @@ func (r AuthID) Validate() error { // The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. type AuthRequest struct { node *Node - finished chan NodeView + finished chan AuthVerdict closed *atomic.Bool } func NewRegisterAuthRequest(node Node) AuthRequest { return AuthRequest{ node: &node, - finished: make(chan NodeView), + finished: make(chan AuthVerdict), closed: &atomic.Bool{}, } } @@ -233,35 +233,37 @@ func (rn *AuthRequest) Node() NodeView { return rn.node.View() } -func (rn *AuthRequest) FinishAuth() { - rn.FinishRegistration(NodeView{}) -} - -func (rn *AuthRequest) FinishRegistration(node NodeView) { +func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return } - if node.Valid() { - select { - case rn.finished <- node: - default: - } + select { + case rn.finished <- verdict: + default: } close(rn.finished) } -// WaitForRegistration waits for the authentication process to finish -// and returns the authenticated node. -// Can _only_ be used in the registration path. -func (rn *AuthRequest) WaitForRegistration() <-chan NodeView { +func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict { return rn.finished } -// WaitForAuth waits until a authentication request has been finished. -func (rn *AuthRequest) WaitForAuth() { - <-rn.WaitForRegistration() +type AuthVerdict struct { + // Err is the error that occurred during the authentication process, if any. + // If Err is nil, the authentication process has succeeded. + // If Err is not nil, the authentication process has failed and the node should not be authenticated. + Err error + + // Node is the node that has been authenticated. + // Node is only valid if the auth request was a registration request + // and the authentication process has succeeded. + Node NodeView +} + +func (v AuthVerdict) Accept() bool { + return v.Err == nil } // DefaultBatcherWorkers returns the default number of batcher workers.