mirror of
https://github.com/juanfont/headscale.git
synced 2026-02-07 20:04:00 +01:00
auth: add /auth dummy, tighten AuthRequest, generalise
This commit generalise the "Registration" pipeline to a more general auth pipeline supporting both registrations and general auth requests. This means we have renamed the RegistrationID to AuthID. Fields from AuthRequest has been unexported and made read only. Added dummy /auth endpoints to be filled. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
d1364194ef
commit
e45cf30867
@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
|
||||
r.Get("/health", h.HealthHandler)
|
||||
r.Get("/version", h.VersionHandler)
|
||||
r.Get("/key", h.KeyHandler)
|
||||
r.Get("/register/{registration_id}", h.authProvider.RegisterHandler)
|
||||
r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
|
||||
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
|
||||
|
||||
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
|
||||
r.Get("/oidc/callback", provider.OIDCCallbackHandler)
|
||||
|
||||
@ -20,7 +20,9 @@ import (
|
||||
|
||||
type AuthProvider interface {
|
||||
RegisterHandler(w http.ResponseWriter, r *http.Request)
|
||||
AuthURL(regID types.RegistrationID) string
|
||||
AuthHandler(w http.ResponseWriter, r *http.Request)
|
||||
RegisterURL(authID types.AuthID) string
|
||||
AuthURL(authID types.AuthID) string
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegister(
|
||||
@ -261,22 +263,22 @@ func (h *Headscale) waitForFollowup(
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
|
||||
}
|
||||
|
||||
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
||||
followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
|
||||
}
|
||||
|
||||
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok {
|
||||
if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
|
||||
case node := <-reg.Registered:
|
||||
if node == nil {
|
||||
case node := <-reg.WaitForRegistration():
|
||||
if !node.Valid() {
|
||||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(req, machineKey)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node.View()), nil
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -291,14 +293,14 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
req tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
newRegID, err := types.NewRegistrationID()
|
||||
newAuthID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname
|
||||
hostname := util.EnsureHostname(
|
||||
req.Hostinfo,
|
||||
req.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
req.NodeKey.String(),
|
||||
)
|
||||
@ -307,25 +309,25 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = &req.Expiry
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
}
|
||||
|
||||
log.Info().Msgf("new followup node registration using key: %s", newRegID)
|
||||
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = &req.Expiry
|
||||
}
|
||||
|
||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||
|
||||
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
|
||||
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(newRegID),
|
||||
AuthURL: h.authProvider.RegisterURL(newAuthID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -376,13 +378,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// Send both changes. Empty changes are ignored by Change().
|
||||
h.Change(changed, routesChange)
|
||||
|
||||
// TODO(kradalby): I think this is covered above, but we need to validate that.
|
||||
// // If policy changed due to node registration, send a separate policy change
|
||||
// if policyChanged {
|
||||
// policyChange := change.PolicyChange()
|
||||
// h.Change(policyChange)
|
||||
// }
|
||||
|
||||
resp := &tailcfg.RegisterResponse{
|
||||
MachineAuthorized: true,
|
||||
NodeKeyExpired: node.IsExpired(),
|
||||
@ -404,14 +399,14 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
req tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
registrationId, err := types.NewRegistrationID()
|
||||
authID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname
|
||||
hostname := util.EnsureHostname(
|
||||
req.Hostinfo,
|
||||
req.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
req.NodeKey.String(),
|
||||
)
|
||||
@ -434,28 +429,28 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Node.Expiry = &req.Expiry
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
}
|
||||
|
||||
h.state.SetRegistrationCacheEntry(
|
||||
registrationId,
|
||||
nodeToRegister,
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = &req.Expiry
|
||||
}
|
||||
|
||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||
|
||||
h.state.SetAuthCacheEntry(
|
||||
authID,
|
||||
authRegReq,
|
||||
)
|
||||
|
||||
log.Info().Msgf("starting node registration using key: %s", registrationId)
|
||||
log.Info().Msgf("starting node registration using key: %s", authID)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
AuthURL: h.authProvider.AuthURL(registrationId),
|
||||
AuthURL: h.authProvider.RegisterURL(authID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
|
||||
// Step 1: Create user-owned node WITH expiry set
|
||||
clientExpiry := time.Now().Add(24 * time.Hour)
|
||||
registrationID1 := types.MustRegistrationID()
|
||||
regEntry1 := types.NewRegisterNode(types.Node{
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "personal-to-tagged",
|
||||
@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
},
|
||||
Expiry: &clientExpiry,
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
||||
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||
|
||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||
registrationID1, types.UserID(user.ID), nil, "webauth",
|
||||
@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
|
||||
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
||||
nodeKey2 := key.NewNode()
|
||||
registrationID2 := types.MustRegistrationID()
|
||||
regEntry2 := types.NewRegisterNode(types.Node{
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "personal-to-tagged",
|
||||
@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
},
|
||||
Expiry: &clientExpiry, // Client still sends expiry
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
||||
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||
|
||||
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
||||
registrationID2, types.UserID(user.ID), nil, "webauth",
|
||||
@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
nodeKey1 := key.NewNode()
|
||||
|
||||
// Step 1: Create tagged node (expiry should be nil)
|
||||
registrationID1 := types.MustRegistrationID()
|
||||
regEntry1 := types.NewRegisterNode(types.Node{
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "tagged-to-personal",
|
||||
@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
RequestTags: []string{"tag:server"}, // Tagged node
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
||||
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||
|
||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||
registrationID1, types.UserID(user.ID), nil, "webauth",
|
||||
@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
// Step 2: Re-auth with empty tags (Tagged → Personal conversion)
|
||||
nodeKey2 := key.NewNode()
|
||||
clientExpiry := time.Now().Add(48 * time.Hour)
|
||||
registrationID2 := types.MustRegistrationID()
|
||||
regEntry2 := types.NewRegisterNode(types.Node{
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "tagged-to-personal",
|
||||
@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
},
|
||||
Expiry: &clientExpiry, // Client requests expiry
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
||||
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||
|
||||
nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
|
||||
registrationID2, types.UserID(user.ID), nil, "webauth",
|
||||
|
||||
@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_success",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "followup-success-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "followup-success-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
|
||||
// Simulate successful registration - send to buffered channel
|
||||
// The channel is buffered (size 1), so this can complete immediately
|
||||
// and handleRegister will receive the value when it starts waiting
|
||||
// Simulate successful registration
|
||||
// handleRegister will receive the value when it starts waiting
|
||||
go func() {
|
||||
user := app.state.CreateUserForTest("followup-user")
|
||||
|
||||
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
||||
registered <- node
|
||||
nodeToRegister.FinishRegistration(node.View())
|
||||
}()
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_timeout",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "followup-timeout-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
// Don't send anything on channel - will timeout
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "followup-timeout-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
// Don't call FinishRegistration - will timeout
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
},
|
||||
@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_node_nil_response",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "nil-response-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "nil-response-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
|
||||
// Simulate registration that returns nil (cache expired during auth)
|
||||
// The channel is buffered (size 1), so this can complete immediately
|
||||
// Simulate registration that returns empty NodeView (cache expired during auth)
|
||||
go func() {
|
||||
registered <- nil // Nil indicates cache expiry
|
||||
nodeToRegister.FinishRegistration(types.NodeView{}) // Empty view indicates cache expiry
|
||||
}()
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||
// Generate a registration ID that doesn't exist in cache
|
||||
// This simulates an expired/missing cache entry
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
|
||||
// Extract and validate the new registration ID exists in cache
|
||||
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
|
||||
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
|
||||
newRegID, err := types.AuthIDFromString(newRegIDStr)
|
||||
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
|
||||
|
||||
// Verify new registration entry exists in cache
|
||||
_, found := app.state.GetRegistrationCacheEntry(newRegID)
|
||||
_, found := app.state.GetAuthCacheEntry(newRegID)
|
||||
assert.True(t, found, "new registration should exist in cache")
|
||||
},
|
||||
},
|
||||
@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify cache entry exists
|
||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
assert.True(t, found, "registration cache entry should exist initially")
|
||||
assert.NotNil(t, cacheEntry)
|
||||
|
||||
@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
|
||||
|
||||
// Cache entry should still exist after auth error (for retry scenarios)
|
||||
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
|
||||
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
|
||||
},
|
||||
},
|
||||
@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
||||
|
||||
// Both cache entries should exist simultaneously
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first registration cache entry should exist")
|
||||
assert.True(t, found2, "second registration cache entry should exist")
|
||||
@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both exist
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first cache entry should exist")
|
||||
assert.True(t, found2, "second cache entry should exist")
|
||||
@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
}
|
||||
|
||||
// First registration should still be in cache (not completed)
|
||||
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, stillFound := app.state.GetAuthCacheEntry(regID1)
|
||||
assert.True(t, stillFound, "first registration should still be pending")
|
||||
},
|
||||
},
|
||||
@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
var (
|
||||
initialResp *tailcfg.RegisterResponse
|
||||
authURL string
|
||||
registrationID types.RegistrationID
|
||||
registrationID types.AuthID
|
||||
finalResp *tailcfg.RegisterResponse
|
||||
err error
|
||||
)
|
||||
@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
|
||||
if step.expectCacheEntry {
|
||||
// Verify registration cache entry was created
|
||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
require.True(t, found, "registration cache entry should exist")
|
||||
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
|
||||
}
|
||||
|
||||
case stepTypeAuthCompletion:
|
||||
@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
// Check cache cleanup expectation for this step
|
||||
if step.expectCacheEntry == false && registrationID != "" {
|
||||
// Verify cache entry was cleaned up
|
||||
_, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
_, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
|
||||
}
|
||||
}
|
||||
@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
}
|
||||
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
|
||||
// AuthURL format: "http://localhost/register/abc123"
|
||||
const registerPrefix = "/register/"
|
||||
|
||||
@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
|
||||
|
||||
idStr := authURL[idx+len(registerPrefix):]
|
||||
|
||||
return types.RegistrationIDFromString(idStr)
|
||||
return types.AuthIDFromString(idStr)
|
||||
}
|
||||
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
||||
@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
// Simulate a registration cache entry (as would be created during web auth)
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "webauth-tags-node",
|
||||
@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
||||
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// Complete the web auth - should fail because tag is unauthorized
|
||||
_, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
nodeKey1 := key.NewNode()
|
||||
|
||||
// Step 1: Initial registration with tags
|
||||
registrationID1 := types.MustRegistrationID()
|
||||
regEntry1 := types.NewRegisterNode(types.Node{
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "reauth-untag-node",
|
||||
@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
RequestTags: []string{"tag:valid-owned", "tag:second"},
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
||||
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||
|
||||
// Complete initial registration with tags
|
||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
|
||||
// Step 2: Reauth with EMPTY tags to untag
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID2 := types.MustRegistrationID()
|
||||
regEntry2 := types.NewRegisterNode(types.Node{
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "reauth-untag-node",
|
||||
@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
RequestTags: []string{}, // EMPTY - should untag
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
||||
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||
|
||||
// Complete reauth with empty tags
|
||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
||||
|
||||
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "authkey-tagged-node",
|
||||
@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
||||
RequestTags: []string{}, // EMPTY - should untag
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// Complete reauth with empty tags
|
||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
||||
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
|
||||
// This is what happens when running: headscale nodes register --user alice --key ...
|
||||
nodeKey2 := key.NewNode()
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "tagged-orphan-node",
|
||||
@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
||||
RequestTags: []string{}, // Empty - transition to user-owned
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// This should NOT panic - before the fix, this would panic with:
|
||||
// panic: runtime error: invalid memory address or nil pointer dereference
|
||||
|
||||
@ -47,7 +47,7 @@ const (
|
||||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
cfg *types.Config
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
}
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
@ -56,7 +56,7 @@ type HSDatabase struct {
|
||||
//nolint:gocyclo // complex database initialization with many migrations
|
||||
func NewHeadscaleDatabase(
|
||||
cfg *types.Config,
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(cfg.Database)
|
||||
if err != nil {
|
||||
|
||||
@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
|
||||
@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
Str(zf.RegistrationKey, registrationKey).
|
||||
Msg("registering node")
|
||||
|
||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
||||
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -780,33 +780,32 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
Hostname: request.GetName(),
|
||||
}
|
||||
|
||||
registrationId, err := types.RegistrationIDFromString(request.GetKey())
|
||||
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNode := types.NewRegisterNode(
|
||||
types.Node{
|
||||
NodeKey: key.NewNode().Public(),
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
Hostname: request.GetName(),
|
||||
User: user,
|
||||
newNode := types.Node{
|
||||
NodeKey: key.NewNode().Public(),
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
Hostname: request.GetName(),
|
||||
User: user,
|
||||
|
||||
Expiry: &time.Time{},
|
||||
LastSeen: &time.Time{},
|
||||
Expiry: &time.Time{},
|
||||
LastSeen: &time.Time{},
|
||||
|
||||
Hostinfo: &hostinfo,
|
||||
},
|
||||
)
|
||||
Hostinfo: &hostinfo,
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("registration_id", registrationId.String()).
|
||||
Msg("adding debug machine via CLI, appending to registration cache")
|
||||
|
||||
api.h.state.SetRegistrationCacheEntry(registrationId, newNode)
|
||||
authRegReq := types.NewRegisterAuthRequest(newNode)
|
||||
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
|
||||
|
||||
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
|
||||
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) Health(
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/assets"
|
||||
"github.com/juanfont/headscale/hscontrol/templates"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
|
||||
func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
registrationId.String())
|
||||
authID.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/auth/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
authID.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderWeb) AuthHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
}
|
||||
|
||||
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
|
||||
registrationId, err := urlParam[types.AuthID](req, "auth_id")
|
||||
if err != nil {
|
||||
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
|
||||
}
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
err = registrationId.Validate()
|
||||
if err != nil {
|
||||
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
|
||||
}
|
||||
|
||||
return registrationId, nil
|
||||
}
|
||||
|
||||
// RegisterHandler shows a simple message in the browser to point to the CLI
|
||||
@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
registrationIdStr := vars["registration_id"]
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
registrationId, err := authIDFromRequest(req)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
|
||||
}
|
||||
|
||||
// emptyCache creates an empty registration cache for testing.
|
||||
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
|
||||
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
|
||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
// Test configuration constants.
|
||||
|
||||
@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/templates"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@ -26,8 +25,8 @@ import (
|
||||
const (
|
||||
randomByteSize = 16
|
||||
defaultOAuthOptionsCount = 3
|
||||
registerCacheExpiration = time.Minute * 15
|
||||
registerCacheCleanup = time.Minute * 20
|
||||
authCacheExpiration = time.Minute * 15
|
||||
authCacheCleanup = time.Minute * 20
|
||||
)
|
||||
|
||||
var (
|
||||
@ -44,17 +43,21 @@ var (
|
||||
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
|
||||
)
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
RegistrationID types.RegistrationID
|
||||
Verifier *string
|
||||
// AuthInfo contains both auth ID and verifier information for OIDC validation.
|
||||
type AuthInfo struct {
|
||||
AuthID types.AuthID
|
||||
Verifier *string
|
||||
Registration bool
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
|
||||
// authCache holds auth information between
|
||||
// the auth and the callback steps.
|
||||
authCache *zcache.Cache[string, AuthInfo]
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[string, RegistrationInfo](
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
authCache := zcache.New[string, AuthInfo](
|
||||
authCacheExpiration,
|
||||
authCacheCleanup,
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
registrationCache: registrationCache,
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
authCache: authCache,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
|
||||
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/auth/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
authID.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) AuthHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
a.authHandler(writer, req, false)
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
registrationID.String())
|
||||
authID.String())
|
||||
}
|
||||
|
||||
// RegisterHandler registers the OIDC callback handler with the given router.
|
||||
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
|
||||
// Listens in /register/:registration_id.
|
||||
// Listens in /register/:auth_id.
|
||||
func (a *AuthProviderOIDC) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
registrationIdStr := vars["registration_id"]
|
||||
a.authHandler(writer, req, true)
|
||||
}
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
// authHandler takes an incoming request that needs to be authenticated and
|
||||
// validates and prepares it for the OIDC flow.
|
||||
func (a *AuthProviderOIDC) authHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registration bool,
|
||||
) {
|
||||
authID, err := authIDFromRequest(req)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
RegistrationID: registrationId,
|
||||
registrationInfo := AuthInfo{
|
||||
AuthID: authID,
|
||||
Registration: registration,
|
||||
}
|
||||
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||
@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(state, registrationInfo)
|
||||
a.authCache.Set(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||
@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// If the node exists, then the node should be reauthenticated,
|
||||
// if the node does not exist, and the machine key exists, then
|
||||
// this is a new node that should be registered.
|
||||
registrationId := a.getRegistrationIDFromState(state)
|
||||
authInfo := a.getAuthInfoFromState(state)
|
||||
if authInfo == nil {
|
||||
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// If this is a registration flow, then we need to register the node.
|
||||
if authInfo.Registration {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
||||
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||
|
||||
return
|
||||
@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
// TODO(kradalby): handle login flow (without registration) if needed.
|
||||
// We need to send an update here to whatever might be waiting for this auth flow.
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
||||
@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
var exchangeOpts []oauth2.AuthCodeOption
|
||||
|
||||
if a.cfg.PKCE.Enabled {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
regInfo, ok := a.authCache.Get(state)
|
||||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
@ -507,14 +533,14 @@ func doOIDCAuthorization(
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRegistrationIDFromState retrieves the registration ID from the state.
|
||||
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
// getAuthInfoFromState retrieves the registration ID from the state.
|
||||
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
|
||||
authInfo, ok := a.authCache.Get(state)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ®Info.RegistrationID
|
||||
return &authInfo
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
|
||||
func (a *AuthProviderOIDC) handleRegistration(
|
||||
user *types.User,
|
||||
registrationID types.RegistrationID,
|
||||
registrationID types.AuthID,
|
||||
expiry time.Time,
|
||||
) (bool, error) {
|
||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||
|
||||
@ -82,8 +82,10 @@ type State struct {
|
||||
derpMap atomic.Pointer[tailcfg.DERPMap]
|
||||
// polMan handles policy evaluation and management
|
||||
polMan policy.PolicyManager
|
||||
// registrationCache caches node registration data to reduce database load
|
||||
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
|
||||
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
||||
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
}
|
||||
@ -101,20 +103,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
||||
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
||||
cacheExpiration,
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
registrationCache.OnEvicted(
|
||||
func(id types.RegistrationID, rn types.RegisterNode) {
|
||||
rn.SendAndClose(nil)
|
||||
authCache.OnEvicted(
|
||||
func(id types.AuthID, rn types.AuthRequest) {
|
||||
rn.FinishRegistration(types.NodeView{})
|
||||
},
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
cfg,
|
||||
registrationCache,
|
||||
authCache,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing database: %w", err)
|
||||
@ -178,12 +180,12 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
polMan: polMan,
|
||||
registrationCache: registrationCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
polMan: polMan,
|
||||
authCache: authCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -1042,9 +1044,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetRegistrationCacheEntry retrieves a node registration from cache.
|
||||
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
|
||||
entry, found := s.registrationCache.Get(id)
|
||||
// GetAuthCacheEntry retrieves a node registration from cache.
|
||||
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||
entry, found := s.authCache.Get(id)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
@ -1052,26 +1054,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
// SetRegistrationCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
|
||||
s.registrationCache.Set(id, entry)
|
||||
// SetAuthCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||
s.authCache.Set(id, entry)
|
||||
}
|
||||
|
||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
|
||||
if hostinfo == nil {
|
||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||
if !nv.Hostinfo().Valid() {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Str(zf.MachineKey, machineKey).
|
||||
Str(zf.NodeKey, nodeKey).
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had nil hostinfo, generated default hostname")
|
||||
} else if hostinfo.Hostname == "" {
|
||||
} else if nv.Hostinfo().Hostname() == "" {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Str(zf.MachineKey, machineKey).
|
||||
Str(zf.NodeKey, nodeKey).
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had empty hostname, generated default")
|
||||
@ -1113,7 +1113,7 @@ type authNodeUpdateParams struct {
|
||||
// Node to update; must be valid and in NodeStore.
|
||||
ExistingNode types.NodeView
|
||||
// Client data: keys, hostinfo, endpoints.
|
||||
RegEntry *types.RegisterNode
|
||||
RegEntry *types.AuthRequest
|
||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||
ValidHostinfo *tailcfg.Hostinfo
|
||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||
@ -1132,6 +1132,7 @@ type authNodeUpdateParams struct {
|
||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||
// persists changes to the database.
|
||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||
regNv := params.RegEntry.Node()
|
||||
// Log the operation type
|
||||
if params.IsConvertFromTag {
|
||||
log.Info().
|
||||
@ -1140,16 +1141,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
Msg("Converting tagged node to user-owned node")
|
||||
} else {
|
||||
log.Info().
|
||||
EmbedObject(params.ExistingNode).
|
||||
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
|
||||
Object("existing", params.ExistingNode).
|
||||
Object("incoming", regNv).
|
||||
Msg("Updating existing node registration via reauth")
|
||||
}
|
||||
|
||||
// Process RequestTags during reauth (#2979)
|
||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
||||
var requestTags []string
|
||||
if params.RegEntry.Node.Hostinfo != nil {
|
||||
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
|
||||
if regNv.Hostinfo().Valid() {
|
||||
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
||||
}
|
||||
|
||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||
@ -1167,8 +1168,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
|
||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||
node.NodeKey = params.RegEntry.Node.NodeKey
|
||||
node.DiscoKey = params.RegEntry.Node.DiscoKey
|
||||
node.NodeKey = regNv.NodeKey()
|
||||
node.DiscoKey = regNv.DiscoKey()
|
||||
node.Hostname = params.Hostname
|
||||
|
||||
// Preserve NetInfo from existing node when re-registering
|
||||
@ -1179,7 +1180,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
params.ValidHostinfo,
|
||||
)
|
||||
|
||||
node.Endpoints = params.RegEntry.Node.Endpoints
|
||||
node.Endpoints = regNv.Endpoints().AsSlice()
|
||||
node.IsOnline = new(false)
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
@ -1188,7 +1189,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.IsConvertFromTag {
|
||||
node.RegisterMethod = params.RegisterMethod
|
||||
} else {
|
||||
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
|
||||
node.RegisterMethod = regNv.RegisterMethod()
|
||||
}
|
||||
|
||||
// Track tagged status BEFORE processing tags
|
||||
@ -1208,7 +1209,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
case !wasTagged && isTagged:
|
||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||
@ -1218,14 +1219,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
case !isTagged:
|
||||
// Personal → Personal: update expiry from client
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
}
|
||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||
@ -1511,13 +1512,13 @@ func (s *State) processReauthTags(
|
||||
|
||||
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
|
||||
func (s *State) HandleNodeFromAuthPath(
|
||||
registrationID types.RegistrationID,
|
||||
authID types.AuthID,
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (types.NodeView, change.Change, error) {
|
||||
// Get the registration entry from cache
|
||||
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
||||
regEntry, ok := s.GetAuthCacheEntry(authID)
|
||||
if !ok {
|
||||
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
@ -1530,25 +1531,27 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Ensure we have a valid hostname from the registration cache entry
|
||||
hostname := util.EnsureHostname(
|
||||
regEntry.Node.Hostinfo,
|
||||
regEntry.Node.MachineKey.String(),
|
||||
regEntry.Node.NodeKey.String(),
|
||||
regEntry.Node().Hostinfo(),
|
||||
regEntry.Node().MachineKey().String(),
|
||||
regEntry.Node().NodeKey().String(),
|
||||
)
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
|
||||
validHostinfo.Hostname = hostname
|
||||
hostinfo := &tailcfg.Hostinfo{}
|
||||
if regEntry.Node().Hostinfo().Valid() {
|
||||
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
regEntry.Node.MachineKey.ShortString(),
|
||||
regEntry.Node.NodeKey.String(),
|
||||
regEntry.Node(),
|
||||
user.Name,
|
||||
hostname,
|
||||
regEntry.Node.Hostinfo,
|
||||
)
|
||||
|
||||
// Lookup existing nodes
|
||||
machineKey := regEntry.Node.MachineKey
|
||||
machineKey := regEntry.Node().MachineKey()
|
||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||
|
||||
@ -1562,7 +1565,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Create logger with common fields for all auth operations
|
||||
logger := log.With().
|
||||
Str(zf.RegistrationID, registrationID.String()).
|
||||
Str(zf.RegistrationID, authID.String()).
|
||||
Str(zf.UserName, user.Name).
|
||||
Str(zf.MachineKey, machineKey.ShortString()).
|
||||
Str(zf.Method, registrationMethod).
|
||||
@ -1571,7 +1574,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
// Common params for update operations
|
||||
updateParams := authNodeUpdateParams{
|
||||
RegEntry: regEntry,
|
||||
ValidHostinfo: validHostinfo,
|
||||
ValidHostinfo: hostinfo,
|
||||
Hostname: hostname,
|
||||
User: user,
|
||||
Expiry: expiry,
|
||||
@ -1605,7 +1608,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, validHostinfo,
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
expiry, registrationMethod, existingNodeAnyUser,
|
||||
)
|
||||
if err != nil {
|
||||
@ -1613,7 +1616,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
} else {
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, validHostinfo,
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
expiry, registrationMethod, types.NodeView{},
|
||||
)
|
||||
if err != nil {
|
||||
@ -1622,10 +1625,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
|
||||
// Signal to waiting clients
|
||||
regEntry.SendAndClose(finalNode.AsStruct())
|
||||
regEntry.FinishRegistration(finalNode)
|
||||
|
||||
// Delete from registration cache
|
||||
s.registrationCache.Delete(registrationID)
|
||||
s.authCache.Delete(authID)
|
||||
|
||||
// Update policy managers
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
@ -1654,7 +1657,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) createNewNodeFromAuth(
|
||||
logger zerolog.Logger,
|
||||
user *types.User,
|
||||
regEntry *types.RegisterNode,
|
||||
regEntry *types.AuthRequest,
|
||||
hostname string,
|
||||
validHostinfo *tailcfg.Hostinfo,
|
||||
expiry *time.Time,
|
||||
@ -1667,13 +1670,13 @@ func (s *State) createNewNodeFromAuth(
|
||||
|
||||
return s.createAndSaveNewNode(newNodeParams{
|
||||
User: *user,
|
||||
MachineKey: regEntry.Node.MachineKey,
|
||||
NodeKey: regEntry.Node.NodeKey,
|
||||
DiscoKey: regEntry.Node.DiscoKey,
|
||||
MachineKey: regEntry.Node().MachineKey(),
|
||||
NodeKey: regEntry.Node().NodeKey(),
|
||||
DiscoKey: regEntry.Node().DiscoKey(),
|
||||
Hostname: hostname,
|
||||
Hostinfo: validHostinfo,
|
||||
Endpoints: regEntry.Node.Endpoints,
|
||||
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
|
||||
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
||||
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
||||
RegisterMethod: registrationMethod,
|
||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||
})
|
||||
@ -1759,7 +1762,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
|
||||
// Ensure we have a valid hostname - handle nil/empty cases
|
||||
hostname := util.EnsureHostname(
|
||||
regReq.Hostinfo,
|
||||
regReq.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
regReq.NodeKey.String(),
|
||||
)
|
||||
@ -1768,14 +1771,6 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
|
||||
validHostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
machineKey.ShortString(),
|
||||
regReq.NodeKey.ShortString(),
|
||||
pakUsername(),
|
||||
hostname,
|
||||
regReq.Hostinfo,
|
||||
)
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str(zf.NodeName, hostname).
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
)
|
||||
|
||||
func RegisterWeb(registrationID types.RegistrationID) *elem.Element {
|
||||
func RegisterWeb(registrationID types.AuthID) *elem.Element {
|
||||
return HtmlStructure(
|
||||
elem.Title(nil, elem.Text("Registration - Headscale")),
|
||||
mdTypesetBody(
|
||||
|
||||
@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Register Web",
|
||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
||||
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||
},
|
||||
{
|
||||
name: "Windows Config",
|
||||
@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Register Web",
|
||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
||||
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||
},
|
||||
{
|
||||
name: "Windows Config",
|
||||
@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Register Web",
|
||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
||||
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||
externalURLs: []string{}, // No external links
|
||||
},
|
||||
{
|
||||
@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Register Web",
|
||||
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(),
|
||||
html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
|
||||
},
|
||||
{
|
||||
name: "Windows Config",
|
||||
|
||||
@ -22,8 +22,8 @@ const (
|
||||
|
||||
// Common errors.
|
||||
var (
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length")
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
|
||||
)
|
||||
|
||||
type StateUpdateType int
|
||||
@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||
}
|
||||
}
|
||||
|
||||
const RegistrationIDLength = 24
|
||||
const AuthIDLength = 24
|
||||
|
||||
type RegistrationID string
|
||||
type AuthID string
|
||||
|
||||
func NewRegistrationID() (RegistrationID, error) {
|
||||
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength)
|
||||
func NewAuthID() (AuthID, error) {
|
||||
rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return RegistrationID(rid), nil
|
||||
return AuthID(rid), nil
|
||||
}
|
||||
|
||||
func MustRegistrationID() RegistrationID {
|
||||
rid, err := NewRegistrationID()
|
||||
func MustAuthID() AuthID {
|
||||
rid, err := NewAuthID()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -181,43 +181,87 @@ func MustRegistrationID() RegistrationID {
|
||||
return rid
|
||||
}
|
||||
|
||||
func RegistrationIDFromString(str string) (RegistrationID, error) {
|
||||
if len(str) != RegistrationIDLength {
|
||||
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
|
||||
func AuthIDFromString(str string) (AuthID, error) {
|
||||
r := AuthID(str)
|
||||
|
||||
err := r.Validate()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return RegistrationID(str), nil
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r RegistrationID) String() string {
|
||||
func (r AuthID) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
closed *atomic.Bool
|
||||
func (r AuthID) Validate() error {
|
||||
if len(r) != AuthIDLength {
|
||||
return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRegisterNode(node Node) RegisterNode {
|
||||
return RegisterNode{
|
||||
Node: node,
|
||||
Registered: make(chan *Node),
|
||||
closed: &atomic.Bool{},
|
||||
// AuthRequest represent a pending authentication request from a user or a node.
|
||||
// If it is a registration request, the node field will be populate with the node that is trying to register.
|
||||
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
|
||||
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
|
||||
type AuthRequest struct {
|
||||
node *Node
|
||||
finished chan NodeView
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewRegisterAuthRequest(node Node) AuthRequest {
|
||||
return AuthRequest{
|
||||
node: &node,
|
||||
finished: make(chan NodeView),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rn *RegisterNode) SendAndClose(node *Node) {
|
||||
// Node returns the node that is trying to register.
|
||||
// It will panic if the AuthRequest is not a registration request.
|
||||
// Can _only_ be used in the registration path.
|
||||
func (rn *AuthRequest) Node() NodeView {
|
||||
if rn.node == nil {
|
||||
panic("Node can only be used in registration requests")
|
||||
}
|
||||
|
||||
return rn.node.View()
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) FinishAuth() {
|
||||
rn.FinishRegistration(NodeView{})
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) FinishRegistration(node NodeView) {
|
||||
if rn.closed.Swap(true) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case rn.Registered <- node:
|
||||
default:
|
||||
if node.Valid() {
|
||||
select {
|
||||
case rn.finished <- node:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
close(rn.Registered)
|
||||
close(rn.finished)
|
||||
}
|
||||
|
||||
// WaitForRegistration waits for the authentication process to finish
|
||||
// and returns the authenticated node.
|
||||
// Can _only_ be used in the registration path.
|
||||
func (rn *AuthRequest) WaitForRegistration() <-chan NodeView {
|
||||
return rn.finished
|
||||
}
|
||||
|
||||
// WaitForAuth waits until a authentication request has been finished.
|
||||
func (rn *AuthRequest) WaitForAuth() {
|
||||
<-rn.WaitForRegistration()
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
|
||||
@ -295,8 +295,8 @@ func IsCI() bool {
|
||||
// 3. If normalisation fails → generate invalid-<random> replacement
|
||||
//
|
||||
// Returns the guaranteed-valid hostname to use.
|
||||
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string {
|
||||
if hostinfo == nil || hostinfo.Hostname == "" {
|
||||
func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
|
||||
if !hostinfo.Valid() || hostinfo.Hostname() == "" {
|
||||
key := cmp.Or(machineKey, nodeKey)
|
||||
if key == "" {
|
||||
return "unknown-node"
|
||||
@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
|
||||
return "node-" + keyPrefix
|
||||
}
|
||||
|
||||
lowercased := strings.ToLower(hostinfo.Hostname)
|
||||
lowercased := strings.ToLower(hostinfo.Hostname())
|
||||
|
||||
err := ValidateHostname(lowercased)
|
||||
if err == nil {
|
||||
|
||||
@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
|
||||
// For invalid hostnames, we just check the prefix since the random part varies
|
||||
if strings.HasPrefix(tt.want, "invalid-") {
|
||||
if !strings.HasPrefix(got, "invalid-") {
|
||||
@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
|
||||
// For invalid hostnames, we just check the prefix since the random part varies
|
||||
if strings.HasPrefix(tt.wantHostname, "invalid-") {
|
||||
if !strings.HasPrefix(gotHostname, "invalid-") {
|
||||
@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
|
||||
|
||||
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
|
||||
|
||||
result := EnsureHostname(hostinfo, "mkey", "nkey")
|
||||
result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
|
||||
if len(result) > 63 {
|
||||
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
|
||||
}
|
||||
@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
|
||||
OS: "linux",
|
||||
}
|
||||
|
||||
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey")
|
||||
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey")
|
||||
hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
|
||||
hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
|
||||
|
||||
if hostname1 != hostname2 {
|
||||
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)
|
||||
|
||||
@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
}
|
||||
nodes := make([]*v1.Node, len(regIDs))
|
||||
|
||||
@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
|
||||
assert.Equal(t, "node-5", listAll[4].GetName())
|
||||
|
||||
otherUserRegIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
}
|
||||
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
|
||||
|
||||
@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
}
|
||||
nodes := make([]*v1.Node, len(regIDs))
|
||||
|
||||
@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
regIDs := []string{
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustRegistrationID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
types.MustAuthID().String(),
|
||||
}
|
||||
nodes := make([]*v1.Node, len(regIDs))
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user