mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge pull request #725 from juanfont/switch-to-db-d
Improve registration protocol implementation and switch to NodeKey as main identifier
This commit is contained in:
		
						commit
						09cd7ba304
					
				| @ -4,6 +4,7 @@ | |||||||
| 
 | 
 | ||||||
| - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) | - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) | ||||||
| - Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) | - Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) | ||||||
|  | - Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725) | ||||||
| 
 | 
 | ||||||
| ## 0.16.0 (2022-07-25) | ## 0.16.0 (2022-07-25) | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										62
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								api.go
									
									
									
									
									
								
							| @ -21,6 +21,8 @@ import ( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | 	// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
 | ||||||
|  | 	registrationHoldoff                      = time.Second * 5 | ||||||
| 	reservedResponseHeaderSize               = 4 | 	reservedResponseHeaderSize               = 4 | ||||||
| 	RegisterMethodAuthKey                    = "authkey" | 	RegisterMethodAuthKey                    = "authkey" | ||||||
| 	RegisterMethodOIDC                       = "oidc" | 	RegisterMethodOIDC                       = "oidc" | ||||||
| @ -107,13 +109,17 @@ var registerWebAPITemplate = template.Must( | |||||||
| `)) | `)) | ||||||
| 
 | 
 | ||||||
| // RegisterWebAPI shows a simple message in the browser to point to the CLI
 | // RegisterWebAPI shows a simple message in the browser to point to the CLI
 | ||||||
| // Listens in /register.
 | // Listens in /register/:nkey.
 | ||||||
|  | //
 | ||||||
|  | // This is not part of the Tailscale control API, as we could send whatever URL
 | ||||||
|  | // in the RegisterResponse.AuthURL field.
 | ||||||
| func (h *Headscale) RegisterWebAPI( | func (h *Headscale) RegisterWebAPI( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	req *http.Request, | 	req *http.Request, | ||||||
| ) { | ) { | ||||||
| 	machineKeyStr := req.URL.Query().Get("key") | 	vars := mux.Vars(req) | ||||||
| 	if machineKeyStr == "" { | 	nodeKeyStr, ok := vars["nkey"] | ||||||
|  | 	if !ok || nodeKeyStr == "" { | ||||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||||
| 		writer.WriteHeader(http.StatusBadRequest) | 		writer.WriteHeader(http.StatusBadRequest) | ||||||
| 		_, err := writer.Write([]byte("Wrong params")) | 		_, err := writer.Write([]byte("Wrong params")) | ||||||
| @ -129,7 +135,7 @@ func (h *Headscale) RegisterWebAPI( | |||||||
| 
 | 
 | ||||||
| 	var content bytes.Buffer | 	var content bytes.Buffer | ||||||
| 	if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ | 	if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ | ||||||
| 		Key: machineKeyStr, | 		Key: nodeKeyStr, | ||||||
| 	}); err != nil { | 	}); err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Str("func", "RegisterWebAPI"). | 			Str("func", "RegisterWebAPI"). | ||||||
| @ -206,8 +212,6 @@ func (h *Headscale) RegistrationHandler( | |||||||
| 	now := time.Now().UTC() | 	now := time.Now().UTC() | ||||||
| 	machine, err := h.GetMachineByMachineKey(machineKey) | 	machine, err := h.GetMachineByMachineKey(machineKey) | ||||||
| 	if errors.Is(err, gorm.ErrRecordNotFound) { | 	if errors.Is(err, gorm.ErrRecordNotFound) { | ||||||
| 		log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine") |  | ||||||
| 
 |  | ||||||
| 		machineKeyStr := MachinePublicKeyStripPrefix(machineKey) | 		machineKeyStr := MachinePublicKeyStripPrefix(machineKey) | ||||||
| 
 | 
 | ||||||
| 		// If the machine has AuthKey set, handle registration via PreAuthKeys
 | 		// If the machine has AuthKey set, handle registration via PreAuthKeys
 | ||||||
| @ -217,6 +221,44 @@ func (h *Headscale) RegistrationHandler( | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		// Check if the node is waiting for interactive login.
 | ||||||
|  | 		//
 | ||||||
|  | 		// TODO(juan): We could use this field to improve our protocol implementation,
 | ||||||
|  | 		// and hold the request until the client closes it, or the interactive
 | ||||||
|  | 		// login is completed (i.e., the user registers the machine).
 | ||||||
|  | 		// This is not implemented yet, as it is no strictly required. The only side-effect
 | ||||||
|  | 		// is that the client will hammer headscale with requests until it gets a
 | ||||||
|  | 		// successful RegisterResponse.
 | ||||||
|  | 		if registerRequest.Followup != "" { | ||||||
|  | 			if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { | ||||||
|  | 				log.Debug(). | ||||||
|  | 					Caller(). | ||||||
|  | 					Str("machine", registerRequest.Hostinfo.Hostname). | ||||||
|  | 					Str("node_key", registerRequest.NodeKey.ShortString()). | ||||||
|  | 					Str("node_key_old", registerRequest.OldNodeKey.ShortString()). | ||||||
|  | 					Str("follow_up", registerRequest.Followup). | ||||||
|  | 					Msg("Machine is waiting for interactive login") | ||||||
|  | 
 | ||||||
|  | 				ticker := time.NewTicker(registrationHoldoff) | ||||||
|  | 				select { | ||||||
|  | 				case <-req.Context().Done(): | ||||||
|  | 					return | ||||||
|  | 				case <-ticker.C: | ||||||
|  | 					h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) | ||||||
|  | 
 | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		log.Info(). | ||||||
|  | 			Caller(). | ||||||
|  | 			Str("machine", registerRequest.Hostinfo.Hostname). | ||||||
|  | 			Str("node_key", registerRequest.NodeKey.ShortString()). | ||||||
|  | 			Str("node_key_old", registerRequest.OldNodeKey.ShortString()). | ||||||
|  | 			Str("follow_up", registerRequest.Followup). | ||||||
|  | 			Msg("New machine not yet in the database") | ||||||
|  | 
 | ||||||
| 		givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) | 		givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| @ -251,7 +293,7 @@ func (h *Headscale) RegistrationHandler( | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		h.registrationCache.Set( | 		h.registrationCache.Set( | ||||||
| 			machineKeyStr, | 			newMachine.NodeKey, | ||||||
| 			newMachine, | 			newMachine, | ||||||
| 			registerCacheExpiration, | 			registerCacheExpiration, | ||||||
| 		) | 		) | ||||||
| @ -652,7 +694,7 @@ func (h *Headscale) handleMachineRegistrationNew( | |||||||
| 	// The machine registration is new, redirect the client to the registration URL
 | 	// The machine registration is new, redirect the client to the registration URL
 | ||||||
| 	log.Debug(). | 	log.Debug(). | ||||||
| 		Str("machine", registerRequest.Hostinfo.Hostname). | 		Str("machine", registerRequest.Hostinfo.Hostname). | ||||||
| 		Msg("The node is sending us a new NodeKey, sending auth url") | 		Msg("The node seems to be new, sending auth url") | ||||||
| 	if h.cfg.OIDC.Issuer != "" { | 	if h.cfg.OIDC.Issuer != "" { | ||||||
| 		resp.AuthURL = fmt.Sprintf( | 		resp.AuthURL = fmt.Sprintf( | ||||||
| 			"%s/oidc/register/%s", | 			"%s/oidc/register/%s", | ||||||
| @ -660,8 +702,8 @@ func (h *Headscale) handleMachineRegistrationNew( | |||||||
| 			machineKey.String(), | 			machineKey.String(), | ||||||
| 		) | 		) | ||||||
| 	} else { | 	} else { | ||||||
| 		resp.AuthURL = fmt.Sprintf("%s/register?key=%s", | 		resp.AuthURL = fmt.Sprintf("%s/register/%s", | ||||||
| 			strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) | 			strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	respBody, err := encode(resp, &machineKey, h.privateKey) | 	respBody, err := encode(resp, &machineKey, h.privateKey) | ||||||
|  | |||||||
							
								
								
									
										16
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								app.go
									
									
									
									
									
								
							| @ -417,21 +417,17 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { | |||||||
| 
 | 
 | ||||||
| 	router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) | 	router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) | 	router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) | 	router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler). | 	router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) | ||||||
| 		Methods(http.MethodPost) |  | ||||||
| 	router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) | 	router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) | ||||||
| 	router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) | 	router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) | 	router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) | 	router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). | 	router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet) | ||||||
| 		Methods(http.MethodGet) |  | ||||||
| 	router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) | 	router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig). | 	router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) | ||||||
| 		Methods(http.MethodGet) |  | ||||||
| 	router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) | 	router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1). | 	router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) | ||||||
| 		Methods(http.MethodGet) |  | ||||||
| 
 | 
 | ||||||
| 	if h.cfg.DERP.ServerEnabled { | 	if h.cfg.DERP.ServerEnabled { | ||||||
| 		router.HandleFunc("/derp", h.DERPHandler) | 		router.HandleFunc("/derp", h.DERPHandler) | ||||||
|  | |||||||
| @ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{ | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			ErrorOutput( | 			ErrorOutput( | ||||||
| 				err, | 				err, | ||||||
| 				fmt.Sprintf("Error getting machine key from flag: %s", err), | 				fmt.Sprintf("Error getting node key from flag: %s", err), | ||||||
| 				output, | 				output, | ||||||
| 			) | 			) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine( | |||||||
| ) (*v1.RegisterMachineResponse, error) { | ) (*v1.RegisterMachineResponse, error) { | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Str("namespace", request.GetNamespace()). | 		Str("namespace", request.GetNamespace()). | ||||||
| 		Str("machine_key", request.GetKey()). | 		Str("node_key", request.GetKey()). | ||||||
| 		Msg("Registering machine") | 		Msg("Registering machine") | ||||||
| 
 | 
 | ||||||
| 	machine, err := api.h.RegisterMachineFromAuthCallback( | 	machine, err := api.h.RegisterMachineFromAuthCallback( | ||||||
|  | |||||||
							
								
								
									
										19
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								machine.go
									
									
									
									
									
								
							| @ -350,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { | |||||||
| 	return &m, nil | 	return &m, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
 | // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
 | ||||||
| func (h *Headscale) GetMachineByMachineKey( | func (h *Headscale) GetMachineByMachineKey( | ||||||
| 	machineKey key.MachinePublic, | 	machineKey key.MachinePublic, | ||||||
| ) (*Machine, error) { | ) (*Machine, error) { | ||||||
| @ -362,6 +362,19 @@ func (h *Headscale) GetMachineByMachineKey( | |||||||
| 	return &m, nil | 	return &m, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetMachineByNodeKey finds a Machine by its current NodeKey.
 | ||||||
|  | func (h *Headscale) GetMachineByNodeKey( | ||||||
|  | 	nodeKey key.NodePublic, | ||||||
|  | ) (*Machine, error) { | ||||||
|  | 	machine := Machine{} | ||||||
|  | 	if result := h.db.Preload("Namespace").First(&machine, "node_key = ?", | ||||||
|  | 		NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { | ||||||
|  | 		return nil, result.Error | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &machine, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
 | // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
 | ||||||
| // and updates it with the latest data from the database.
 | // and updates it with the latest data from the database.
 | ||||||
| func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { | func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { | ||||||
| @ -762,11 +775,11 @@ func getTags( | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *Headscale) RegisterMachineFromAuthCallback( | func (h *Headscale) RegisterMachineFromAuthCallback( | ||||||
| 	machineKeyStr string, | 	nodeKeyStr string, | ||||||
| 	namespaceName string, | 	namespaceName string, | ||||||
| 	registrationMethod string, | 	registrationMethod string, | ||||||
| ) (*Machine, error) { | ) (*Machine, error) { | ||||||
| 	if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { | 	if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { | ||||||
| 		if registrationMachine, ok := machineInterface.(Machine); ok { | 		if registrationMachine, ok := machineInterface.(Machine); ok { | ||||||
| 			namespace, err := h.GetNamespace(namespaceName) | 			namespace, err := h.GetNamespace(namespaceName) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
|  | |||||||
							
								
								
									
										56
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								oidc.go
									
									
									
									
									
								
							| @ -27,7 +27,7 @@ const ( | |||||||
| 	errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain") | 	errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain") | ||||||
| 	errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user") | 	errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user") | ||||||
| 	errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") | 	errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") | ||||||
| 	errOIDCMachineKeyMissing   = Error("could not get machine key from cache") | 	errOIDCNodeKeyMissing      = Error("could not get node key from cache") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type IDTokenClaims struct { | type IDTokenClaims struct { | ||||||
| @ -68,26 +68,26 @@ func (h *Headscale) initOIDC() error { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RegisterOIDC redirects to the OIDC provider for authentication
 | // RegisterOIDC redirects to the OIDC provider for authentication
 | ||||||
| // Puts machine key in cache so the callback can retrieve it using the oidc state param
 | // Puts NodeKey in cache so the callback can retrieve it using the oidc state param
 | ||||||
| // Listens in /oidc/register/:mKey.
 | // Listens in /oidc/register/:nKey.
 | ||||||
| func (h *Headscale) RegisterOIDC( | func (h *Headscale) RegisterOIDC( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	req *http.Request, | 	req *http.Request, | ||||||
| ) { | ) { | ||||||
| 	vars := mux.Vars(req) | 	vars := mux.Vars(req) | ||||||
| 	machineKeyStr, ok := vars["mkey"] | 	nodeKeyStr, ok := vars["nkey"] | ||||||
| 	if !ok || machineKeyStr == "" { | 	if !ok || nodeKeyStr == "" { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Caller(). | 			Caller(). | ||||||
| 			Msg("Missing machine key in URL") | 			Msg("Missing node key in URL") | ||||||
| 		http.Error(writer, "Missing machine key in URL", http.StatusBadRequest) | 		http.Error(writer, "Missing node key in URL", http.StatusBadRequest) | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Trace(). | 	log.Trace(). | ||||||
| 		Caller(). | 		Caller(). | ||||||
| 		Str("machine_key", machineKeyStr). | 		Str("node_key", nodeKeyStr). | ||||||
| 		Msg("Received oidc register call") | 		Msg("Received oidc register call") | ||||||
| 
 | 
 | ||||||
| 	randomBlob := make([]byte, randomByteSize) | 	randomBlob := make([]byte, randomByteSize) | ||||||
| @ -102,8 +102,8 @@ func (h *Headscale) RegisterOIDC( | |||||||
| 
 | 
 | ||||||
| 	stateStr := hex.EncodeToString(randomBlob)[:32] | 	stateStr := hex.EncodeToString(randomBlob)[:32] | ||||||
| 
 | 
 | ||||||
| 	// place the machine key into the state cache, so it can be retrieved later
 | 	// place the node key into the state cache, so it can be retrieved later
 | ||||||
| 	h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) | 	h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) | ||||||
| 
 | 
 | ||||||
| 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
 | 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
 | ||||||
| 	extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) | 	extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) | ||||||
| @ -135,7 +135,7 @@ var oidcCallbackTemplate = template.Must( | |||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // OIDCCallback handles the callback from the OIDC endpoint
 | // OIDCCallback handles the callback from the OIDC endpoint
 | ||||||
| // Retrieves the mkey from the state cache and adds the machine to the users email namespace
 | // Retrieves the nkey from the state cache and adds the machine to the users email namespace
 | ||||||
| // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
 | // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
 | ||||||
| // TODO: Add groups information from OIDC tokens into machine HostInfo
 | // TODO: Add groups information from OIDC tokens into machine HostInfo
 | ||||||
| // Listens in /oidc/callback.
 | // Listens in /oidc/callback.
 | ||||||
| @ -178,7 +178,7 @@ func (h *Headscale) OIDCCallback( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) | 	nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) | ||||||
| 	if err != nil || machineExists { | 	if err != nil || machineExists { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @ -196,7 +196,7 @@ func (h *Headscale) OIDCCallback( | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { | 	if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -401,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	state string, | 	state string, | ||||||
| 	claims *IDTokenClaims, | 	claims *IDTokenClaims, | ||||||
| ) (*key.MachinePublic, bool, error) { | ) (*key.NodePublic, bool, error) { | ||||||
| 	// retrieve machinekey from state cache
 | 	// retrieve machinekey from state cache
 | ||||||
| 	machineKeyIf, machineKeyFound := h.registrationCache.Get(state) | 	machineKeyIf, machineKeyFound := h.registrationCache.Get(state) | ||||||
| 	if !machineKeyFound { | 	if !machineKeyFound { | ||||||
| @ -420,14 +420,14 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 		return nil, false, errOIDCInvalidMachineState | 		return nil, false, errOIDCInvalidMachineState | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var machineKey key.MachinePublic | 	var nodeKey key.NodePublic | ||||||
| 	machineKeyFromCache, machineKeyOK := machineKeyIf.(string) | 	nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) | ||||||
| 	err := machineKey.UnmarshalText( | 	err := nodeKey.UnmarshalText( | ||||||
| 		[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), | 		[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), | ||||||
| 	) | 	) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error(). | 		log.Error(). | ||||||
| 			Msg("could not parse machine public key") | 			Msg("could not parse node public key") | ||||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||||
| 		writer.WriteHeader(http.StatusBadRequest) | 		writer.WriteHeader(http.StatusBadRequest) | ||||||
| 		_, werr := writer.Write([]byte("could not parse public key")) | 		_, werr := writer.Write([]byte("could not parse public key")) | ||||||
| @ -441,11 +441,11 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 		return nil, false, err | 		return nil, false, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !machineKeyOK { | 	if !nodeKeyOK { | ||||||
| 		log.Error().Msg("could not get machine key from cache") | 		log.Error().Msg("could not get node key from cache") | ||||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||||
| 		writer.WriteHeader(http.StatusInternalServerError) | 		writer.WriteHeader(http.StatusInternalServerError) | ||||||
| 		_, err := writer.Write([]byte("could not get machine key from cache")) | 		_, err := writer.Write([]byte("could not get node key from cache")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error(). | 			log.Error(). | ||||||
| 				Caller(). | 				Caller(). | ||||||
| @ -453,14 +453,14 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 				Msg("Failed to write response") | 				Msg("Failed to write response") | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return nil, false, errOIDCMachineKeyMissing | 		return nil, false, errOIDCNodeKeyMissing | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// retrieve machine information if it exist
 | 	// retrieve machine information if it exist
 | ||||||
| 	// The error is not important, because if it does not
 | 	// The error is not important, because if it does not
 | ||||||
| 	// exist, then this is a new machine and we will move
 | 	// exist, then this is a new machine and we will move
 | ||||||
| 	// on to registration.
 | 	// on to registration.
 | ||||||
| 	machine, _ := h.GetMachineByMachineKey(machineKey) | 	machine, _ := h.GetMachineByNodeKey(nodeKey) | ||||||
| 
 | 
 | ||||||
| 	if machine != nil { | 	if machine != nil { | ||||||
| 		log.Trace(). | 		log.Trace(). | ||||||
| @ -520,7 +520,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | |||||||
| 		return nil, true, nil | 		return nil, true, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &machineKey, false, nil | 	return &nodeKey, false, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getNamespaceName( | func getNamespaceName( | ||||||
| @ -600,12 +600,12 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( | |||||||
| func (h *Headscale) registerMachineForOIDCCallback( | func (h *Headscale) registerMachineForOIDCCallback( | ||||||
| 	writer http.ResponseWriter, | 	writer http.ResponseWriter, | ||||||
| 	namespace *Namespace, | 	namespace *Namespace, | ||||||
| 	machineKey *key.MachinePublic, | 	nodeKey *key.NodePublic, | ||||||
| ) error { | ) error { | ||||||
| 	machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) | 	nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey) | ||||||
| 
 | 
 | ||||||
| 	if _, err := h.RegisterMachineFromAuthCallback( | 	if _, err := h.RegisterMachineFromAuthCallback( | ||||||
| 		machineKeyStr, | 		nodeKeyStr, | ||||||
| 		namespace.Name, | 		namespace.Name, | ||||||
| 		RegisterMethodOIDC, | 		RegisterMethodOIDC, | ||||||
| 	); err != nil { | 	); err != nil { | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user