mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge pull request #725 from juanfont/switch-to-db-d
Improve registration protocol implementation and switch to NodeKey as main identifier
This commit is contained in:
		
						commit
						09cd7ba304
					
				| @ -4,6 +4,7 @@ | ||||
| 
 | ||||
| - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) | ||||
| - Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) | ||||
| - Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725) | ||||
| 
 | ||||
| ## 0.16.0 (2022-07-25) | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										62
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								api.go
									
									
									
									
									
								
							| @ -21,6 +21,8 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
 | ||||
| 	registrationHoldoff                      = time.Second * 5 | ||||
| 	reservedResponseHeaderSize               = 4 | ||||
| 	RegisterMethodAuthKey                    = "authkey" | ||||
| 	RegisterMethodOIDC                       = "oidc" | ||||
| @ -107,13 +109,17 @@ var registerWebAPITemplate = template.Must( | ||||
| `)) | ||||
| 
 | ||||
| // RegisterWebAPI shows a simple message in the browser to point to the CLI
 | ||||
| // Listens in /register.
 | ||||
| // Listens in /register/:nkey.
 | ||||
| //
 | ||||
| // This is not part of the Tailscale control API, as we could send whatever URL
 | ||||
| // in the RegisterResponse.AuthURL field.
 | ||||
| func (h *Headscale) RegisterWebAPI( | ||||
| 	writer http.ResponseWriter, | ||||
| 	req *http.Request, | ||||
| ) { | ||||
| 	machineKeyStr := req.URL.Query().Get("key") | ||||
| 	if machineKeyStr == "" { | ||||
| 	vars := mux.Vars(req) | ||||
| 	nodeKeyStr, ok := vars["nkey"] | ||||
| 	if !ok || nodeKeyStr == "" { | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, err := writer.Write([]byte("Wrong params")) | ||||
| @ -129,7 +135,7 @@ func (h *Headscale) RegisterWebAPI( | ||||
| 
 | ||||
| 	var content bytes.Buffer | ||||
| 	if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ | ||||
| 		Key: machineKeyStr, | ||||
| 		Key: nodeKeyStr, | ||||
| 	}); err != nil { | ||||
| 		log.Error(). | ||||
| 			Str("func", "RegisterWebAPI"). | ||||
| @ -206,8 +212,6 @@ func (h *Headscale) RegistrationHandler( | ||||
| 	now := time.Now().UTC() | ||||
| 	machine, err := h.GetMachineByMachineKey(machineKey) | ||||
| 	if errors.Is(err, gorm.ErrRecordNotFound) { | ||||
| 		log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine") | ||||
| 
 | ||||
| 		machineKeyStr := MachinePublicKeyStripPrefix(machineKey) | ||||
| 
 | ||||
| 		// If the machine has AuthKey set, handle registration via PreAuthKeys
 | ||||
| @ -217,6 +221,44 @@ func (h *Headscale) RegistrationHandler( | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// Check if the node is waiting for interactive login.
 | ||||
| 		//
 | ||||
| 		// TODO(juan): We could use this field to improve our protocol implementation,
 | ||||
| 		// and hold the request until the client closes it, or the interactive
 | ||||
| 		// login is completed (i.e., the user registers the machine).
 | ||||
| 		// This is not implemented yet, as it is no strictly required. The only side-effect
 | ||||
| 		// is that the client will hammer headscale with requests until it gets a
 | ||||
| 		// successful RegisterResponse.
 | ||||
| 		if registerRequest.Followup != "" { | ||||
| 			if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { | ||||
| 				log.Debug(). | ||||
| 					Caller(). | ||||
| 					Str("machine", registerRequest.Hostinfo.Hostname). | ||||
| 					Str("node_key", registerRequest.NodeKey.ShortString()). | ||||
| 					Str("node_key_old", registerRequest.OldNodeKey.ShortString()). | ||||
| 					Str("follow_up", registerRequest.Followup). | ||||
| 					Msg("Machine is waiting for interactive login") | ||||
| 
 | ||||
| 				ticker := time.NewTicker(registrationHoldoff) | ||||
| 				select { | ||||
| 				case <-req.Context().Done(): | ||||
| 					return | ||||
| 				case <-ticker.C: | ||||
| 					h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) | ||||
| 
 | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		log.Info(). | ||||
| 			Caller(). | ||||
| 			Str("machine", registerRequest.Hostinfo.Hostname). | ||||
| 			Str("node_key", registerRequest.NodeKey.ShortString()). | ||||
| 			Str("node_key_old", registerRequest.OldNodeKey.ShortString()). | ||||
| 			Str("follow_up", registerRequest.Followup). | ||||
| 			Msg("New machine not yet in the database") | ||||
| 
 | ||||
| 		givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| @ -251,7 +293,7 @@ func (h *Headscale) RegistrationHandler( | ||||
| 		} | ||||
| 
 | ||||
| 		h.registrationCache.Set( | ||||
| 			machineKeyStr, | ||||
| 			newMachine.NodeKey, | ||||
| 			newMachine, | ||||
| 			registerCacheExpiration, | ||||
| 		) | ||||
| @ -652,7 +694,7 @@ func (h *Headscale) handleMachineRegistrationNew( | ||||
| 	// The machine registration is new, redirect the client to the registration URL
 | ||||
| 	log.Debug(). | ||||
| 		Str("machine", registerRequest.Hostinfo.Hostname). | ||||
| 		Msg("The node is sending us a new NodeKey, sending auth url") | ||||
| 		Msg("The node seems to be new, sending auth url") | ||||
| 	if h.cfg.OIDC.Issuer != "" { | ||||
| 		resp.AuthURL = fmt.Sprintf( | ||||
| 			"%s/oidc/register/%s", | ||||
| @ -660,8 +702,8 @@ func (h *Headscale) handleMachineRegistrationNew( | ||||
| 			machineKey.String(), | ||||
| 		) | ||||
| 	} else { | ||||
| 		resp.AuthURL = fmt.Sprintf("%s/register?key=%s", | ||||
| 			strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) | ||||
| 		resp.AuthURL = fmt.Sprintf("%s/register/%s", | ||||
| 			strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) | ||||
| 	} | ||||
| 
 | ||||
| 	respBody, err := encode(resp, &machineKey, h.privateKey) | ||||
|  | ||||
							
								
								
									
										16
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								app.go
									
									
									
									
									
								
							| @ -417,21 +417,17 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { | ||||
| 
 | ||||
| 	router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler). | ||||
| 		Methods(http.MethodPost) | ||||
| 	router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) | ||||
| 	router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) | ||||
| 	router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). | ||||
| 		Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig). | ||||
| 		Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1). | ||||
| 		Methods(http.MethodGet) | ||||
| 	router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) | ||||
| 
 | ||||
| 	if h.cfg.DERP.ServerEnabled { | ||||
| 		router.HandleFunc("/derp", h.DERPHandler) | ||||
|  | ||||
| @ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{ | ||||
| 		if err != nil { | ||||
| 			ErrorOutput( | ||||
| 				err, | ||||
| 				fmt.Sprintf("Error getting machine key from flag: %s", err), | ||||
| 				fmt.Sprintf("Error getting node key from flag: %s", err), | ||||
| 				output, | ||||
| 			) | ||||
| 
 | ||||
|  | ||||
| @ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine( | ||||
| ) (*v1.RegisterMachineResponse, error) { | ||||
| 	log.Trace(). | ||||
| 		Str("namespace", request.GetNamespace()). | ||||
| 		Str("machine_key", request.GetKey()). | ||||
| 		Str("node_key", request.GetKey()). | ||||
| 		Msg("Registering machine") | ||||
| 
 | ||||
| 	machine, err := api.h.RegisterMachineFromAuthCallback( | ||||
|  | ||||
							
								
								
									
										19
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								machine.go
									
									
									
									
									
								
							| @ -350,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { | ||||
| 	return &m, nil | ||||
| } | ||||
| 
 | ||||
| // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
 | ||||
| // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
 | ||||
| func (h *Headscale) GetMachineByMachineKey( | ||||
| 	machineKey key.MachinePublic, | ||||
| ) (*Machine, error) { | ||||
| @ -362,6 +362,19 @@ func (h *Headscale) GetMachineByMachineKey( | ||||
| 	return &m, nil | ||||
| } | ||||
| 
 | ||||
| // GetMachineByNodeKey finds a Machine by its current NodeKey.
 | ||||
| func (h *Headscale) GetMachineByNodeKey( | ||||
| 	nodeKey key.NodePublic, | ||||
| ) (*Machine, error) { | ||||
| 	machine := Machine{} | ||||
| 	if result := h.db.Preload("Namespace").First(&machine, "node_key = ?", | ||||
| 		NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { | ||||
| 		return nil, result.Error | ||||
| 	} | ||||
| 
 | ||||
| 	return &machine, nil | ||||
| } | ||||
| 
 | ||||
| // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
 | ||||
| // and updates it with the latest data from the database.
 | ||||
| func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { | ||||
| @ -762,11 +775,11 @@ func getTags( | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) RegisterMachineFromAuthCallback( | ||||
| 	machineKeyStr string, | ||||
| 	nodeKeyStr string, | ||||
| 	namespaceName string, | ||||
| 	registrationMethod string, | ||||
| ) (*Machine, error) { | ||||
| 	if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { | ||||
| 	if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { | ||||
| 		if registrationMachine, ok := machineInterface.(Machine); ok { | ||||
| 			namespace, err := h.GetNamespace(namespaceName) | ||||
| 			if err != nil { | ||||
|  | ||||
							
								
								
									
										56
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								oidc.go
									
									
									
									
									
								
							| @ -27,7 +27,7 @@ const ( | ||||
| 	errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain") | ||||
| 	errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user") | ||||
| 	errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") | ||||
| 	errOIDCMachineKeyMissing   = Error("could not get machine key from cache") | ||||
| 	errOIDCNodeKeyMissing      = Error("could not get node key from cache") | ||||
| ) | ||||
| 
 | ||||
| type IDTokenClaims struct { | ||||
| @ -68,26 +68,26 @@ func (h *Headscale) initOIDC() error { | ||||
| } | ||||
| 
 | ||||
| // RegisterOIDC redirects to the OIDC provider for authentication
 | ||||
| // Puts machine key in cache so the callback can retrieve it using the oidc state param
 | ||||
| // Listens in /oidc/register/:mKey.
 | ||||
| // Puts NodeKey in cache so the callback can retrieve it using the oidc state param
 | ||||
| // Listens in /oidc/register/:nKey.
 | ||||
| func (h *Headscale) RegisterOIDC( | ||||
| 	writer http.ResponseWriter, | ||||
| 	req *http.Request, | ||||
| ) { | ||||
| 	vars := mux.Vars(req) | ||||
| 	machineKeyStr, ok := vars["mkey"] | ||||
| 	if !ok || machineKeyStr == "" { | ||||
| 	nodeKeyStr, ok := vars["nkey"] | ||||
| 	if !ok || nodeKeyStr == "" { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Msg("Missing machine key in URL") | ||||
| 		http.Error(writer, "Missing machine key in URL", http.StatusBadRequest) | ||||
| 			Msg("Missing node key in URL") | ||||
| 		http.Error(writer, "Missing node key in URL", http.StatusBadRequest) | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| 		Caller(). | ||||
| 		Str("machine_key", machineKeyStr). | ||||
| 		Str("node_key", nodeKeyStr). | ||||
| 		Msg("Received oidc register call") | ||||
| 
 | ||||
| 	randomBlob := make([]byte, randomByteSize) | ||||
| @ -102,8 +102,8 @@ func (h *Headscale) RegisterOIDC( | ||||
| 
 | ||||
| 	stateStr := hex.EncodeToString(randomBlob)[:32] | ||||
| 
 | ||||
| 	// place the machine key into the state cache, so it can be retrieved later
 | ||||
| 	h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) | ||||
| 	// place the node key into the state cache, so it can be retrieved later
 | ||||
| 	h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) | ||||
| 
 | ||||
| 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
 | ||||
| 	extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) | ||||
| @ -135,7 +135,7 @@ var oidcCallbackTemplate = template.Must( | ||||
| ) | ||||
| 
 | ||||
| // OIDCCallback handles the callback from the OIDC endpoint
 | ||||
| // Retrieves the mkey from the state cache and adds the machine to the users email namespace
 | ||||
| // Retrieves the nkey from the state cache and adds the machine to the users email namespace
 | ||||
| // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
 | ||||
| // TODO: Add groups information from OIDC tokens into machine HostInfo
 | ||||
| // Listens in /oidc/callback.
 | ||||
| @ -178,7 +178,7 @@ func (h *Headscale) OIDCCallback( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) | ||||
| 	nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) | ||||
| 	if err != nil || machineExists { | ||||
| 		return | ||||
| 	} | ||||
| @ -196,7 +196,7 @@ func (h *Headscale) OIDCCallback( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { | ||||
| 	if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| @ -401,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	state string, | ||||
| 	claims *IDTokenClaims, | ||||
| ) (*key.MachinePublic, bool, error) { | ||||
| ) (*key.NodePublic, bool, error) { | ||||
| 	// retrieve machinekey from state cache
 | ||||
| 	machineKeyIf, machineKeyFound := h.registrationCache.Get(state) | ||||
| 	if !machineKeyFound { | ||||
| @ -420,14 +420,14 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 		return nil, false, errOIDCInvalidMachineState | ||||
| 	} | ||||
| 
 | ||||
| 	var machineKey key.MachinePublic | ||||
| 	machineKeyFromCache, machineKeyOK := machineKeyIf.(string) | ||||
| 	err := machineKey.UnmarshalText( | ||||
| 		[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), | ||||
| 	var nodeKey key.NodePublic | ||||
| 	nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) | ||||
| 	err := nodeKey.UnmarshalText( | ||||
| 		[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Msg("could not parse machine public key") | ||||
| 			Msg("could not parse node public key") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, werr := writer.Write([]byte("could not parse public key")) | ||||
| @ -441,11 +441,11 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 		return nil, false, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !machineKeyOK { | ||||
| 		log.Error().Msg("could not get machine key from cache") | ||||
| 	if !nodeKeyOK { | ||||
| 		log.Error().Msg("could not get node key from cache") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusInternalServerError) | ||||
| 		_, err := writer.Write([]byte("could not get machine key from cache")) | ||||
| 		_, err := writer.Write([]byte("could not get node key from cache")) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| @ -453,14 +453,14 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false, errOIDCMachineKeyMissing | ||||
| 		return nil, false, errOIDCNodeKeyMissing | ||||
| 	} | ||||
| 
 | ||||
| 	// retrieve machine information if it exist
 | ||||
| 	// The error is not important, because if it does not
 | ||||
| 	// exist, then this is a new machine and we will move
 | ||||
| 	// on to registration.
 | ||||
| 	machine, _ := h.GetMachineByMachineKey(machineKey) | ||||
| 	machine, _ := h.GetMachineByNodeKey(nodeKey) | ||||
| 
 | ||||
| 	if machine != nil { | ||||
| 		log.Trace(). | ||||
| @ -520,7 +520,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 		return nil, true, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return &machineKey, false, nil | ||||
| 	return &nodeKey, false, nil | ||||
| } | ||||
| 
 | ||||
| func getNamespaceName( | ||||
| @ -600,12 +600,12 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( | ||||
| func (h *Headscale) registerMachineForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	namespace *Namespace, | ||||
| 	machineKey *key.MachinePublic, | ||||
| 	nodeKey *key.NodePublic, | ||||
| ) error { | ||||
| 	machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) | ||||
| 	nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey) | ||||
| 
 | ||||
| 	if _, err := h.RegisterMachineFromAuthCallback( | ||||
| 		machineKeyStr, | ||||
| 		nodeKeyStr, | ||||
| 		namespace.Name, | ||||
| 		RegisterMethodOIDC, | ||||
| 	); err != nil { | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user