1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

golangci-lint --fix

This commit is contained in:
Kristoffer Dalby 2021-11-13 08:39:04 +00:00
parent dae34ca8c5
commit 2634215f12
17 changed files with 62 additions and 77 deletions

View File

@ -9,7 +9,6 @@ import (
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -25,7 +24,7 @@ const (
errorInvalidPortFormat = Error("invalid port format") errorInvalidPortFormat = Error("invalid port format")
) )
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules // LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicy(path string) error { func (h *Headscale) LoadACLPolicy(path string) error {
policyFile, err := os.Open(path) policyFile, err := os.Open(path)
if err != nil { if err != nil {

View File

@ -8,7 +8,7 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
) )
// ACLPolicy represents a Tailscale ACL Policy // ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct { type ACLPolicy struct {
Groups Groups `json:"Groups"` Groups Groups `json:"Groups"`
Hosts Hosts `json:"Hosts"` Hosts Hosts `json:"Hosts"`
@ -17,30 +17,30 @@ type ACLPolicy struct {
Tests []ACLTest `json:"Tests"` Tests []ACLTest `json:"Tests"`
} }
// ACL is a basic rule for the ACL Policy // ACL is a basic rule for the ACL Policy.
type ACL struct { type ACL struct {
Action string `json:"Action"` Action string `json:"Action"`
Users []string `json:"Users"` Users []string `json:"Users"`
Ports []string `json:"Ports"` Ports []string `json:"Ports"`
} }
// Groups references a series of alias in the ACL rules // Groups references a series of alias in the ACL rules.
type Groups map[string][]string type Groups map[string][]string
// Hosts are alias for IP addresses or subnets // Hosts are alias for IP addresses or subnets.
type Hosts map[string]netaddr.IPPrefix type Hosts map[string]netaddr.IPPrefix
// TagOwners specify what users (namespaces?) are allow to use certain tags // TagOwners specify what users (namespaces?) are allow to use certain tags.
type TagOwners map[string][]string type TagOwners map[string][]string
// ACLTest is not implemented, but should be use to check if a certain rule is allowed // ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct { type ACLTest struct {
User string `json:"User"` User string `json:"User"`
Allow []string `json:"Allow"` Allow []string `json:"Allow"`
Deny []string `json:"Deny,omitempty"` Deny []string `json:"Deny,omitempty"`
} }
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects // UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (h *Hosts) UnmarshalJSON(data []byte) error { func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{} hosts := Hosts{}
hs := make(map[string]string) hs := make(map[string]string)
@ -68,7 +68,7 @@ func (h *Hosts) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// IsZero is perhaps a bit naive here // IsZero is perhaps a bit naive here.
func (p ACLPolicy) IsZero() bool { func (p ACLPolicy) IsZero() bool {
if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
return true return true

10
api.go
View File

@ -10,23 +10,22 @@ import (
"strings" "strings"
"time" "time"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key // Listens in /key.
func (h *Headscale) KeyHandler(c *gin.Context) { func (h *Headscale) KeyHandler(c *gin.Context) {
c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString())) c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
} }
// RegisterWebAPI shows a simple message in the browser to point to the CLI // RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register // Listens in /register.
func (h *Headscale) RegisterWebAPI(c *gin.Context) { func (h *Headscale) RegisterWebAPI(c *gin.Context) {
mKeyStr := c.Query("key") mKeyStr := c.Query("key")
if mKeyStr == "" { if mKeyStr == "" {
@ -55,7 +54,7 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
} }
// RegistrationHandler handles the actual registration process of a machine // RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:id // Endpoint /machine/:id.
func (h *Headscale) RegistrationHandler(c *gin.Context) { func (h *Headscale) RegistrationHandler(c *gin.Context) {
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id") mKeyStr := c.Param("id")
@ -111,7 +110,6 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {

10
app.go
View File

@ -18,20 +18,19 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/philip-bui/grpc-zerolog" "github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog"
zl "github.com/rs/zerolog" zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/soheilhy/cmux" "github.com/soheilhy/cmux"
ginprometheus "github.com/zsais/go-gin-prometheus" ginprometheus "github.com/zsais/go-gin-prometheus"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -280,7 +279,6 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{}, req interface{},
info *grpc.UnaryServerInfo, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) { handler grpc.UnaryHandler) (interface{}, error) {
// Check if the request is coming from the on-server client. // Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability // This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client // with the "legacy" database-based client

View File

@ -5,14 +5,13 @@ import (
"net/http" "net/http"
"text/template" "text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/rs/zerolog/log"
) )
// AppleMobileConfig shows a simple message in the browser to point to the CLI // AppleMobileConfig shows a simple message in the browser to point to the CLI
// Listens in /register // Listens in /register.
func (h *Headscale) AppleMobileConfig(c *gin.Context) { func (h *Headscale) AppleMobileConfig(c *gin.Context) {
t := template.Must(template.New("apple").Parse(` t := template.Must(template.New("apple").Parse(`
<html> <html>

View File

@ -171,7 +171,7 @@ omit the route you do not want to enable.
}, },
} }
// routesToPtables converts the list of routes to a nice table // routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData { func routesToPtables(routes *v1.Routes) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}} d := pterm.TableData{{"Route", "Enabled"}}

View File

@ -355,7 +355,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
// If the address is not set, we assume that we are on the server hosting headscale. // If the address is not set, we assume that we are on the server hosting headscale.
if address == "" { if address == "" {
log.Debug(). log.Debug().
Str("socket", cfg.UnixSocket). Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.") Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")

4
db.go
View File

@ -84,7 +84,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil return db, nil
} }
// getValue returns the value for the given key in KV // getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) { func (h *Headscale) getValue(key string) (string, error) {
var row KV var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is( if result := h.db.First(&row, "key = ?", key); errors.Is(
@ -96,7 +96,7 @@ func (h *Headscale) getValue(key string) (string, error) {
return row.Value, nil return row.Value, nil
} }
// setValue sets value for the given key in KV // setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error { func (h *Headscale) setValue(key string, value string) error {
kv := KV{ kv := KV{
Key: key, Key: key,

View File

@ -10,9 +10,7 @@ import (
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -55,7 +53,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
// DERPMap, it will _only_ look at the Regions, an integer. // DERPMap, it will _only_ look at the Regions, an integer.
// If a region exists in two of the given DERPMaps, the region // If a region exists in two of the given DERPMaps, the region
// form the _last_ DERPMap will be preserved. // form the _last_ DERPMap will be preserved.
// An empty DERPMap list will result in a DERPMap with no regions // An empty DERPMap list will result in a DERPMap with no regions.
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
result := tailcfg.DERPMap{ result := tailcfg.DERPMap{
OmitDefaultRegions: false, OmitDefaultRegions: false,

View File

@ -10,10 +10,9 @@ import (
"time" "time"
"github.com/fatih/set" "github.com/fatih/set"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"gorm.io/datatypes" "gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
@ -21,7 +20,7 @@ import (
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
// Machine is a Headscale client // Machine is a Headscale client.
type Machine struct { type Machine struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"` MachineKey string `gorm:"type:varchar(64);unique_index"`
@ -56,12 +55,12 @@ type (
MachinesP []*Machine MachinesP []*Machine
) )
// For the time being this method is rather naive // For the time being this method is rather naive.
func (m Machine) isAlreadyRegistered() bool { func (m Machine) isAlreadyRegistered() bool {
return m.Registered return m.Registered
} }
// isExpired returns whether the machine registration has expired // isExpired returns whether the machine registration has expired.
func (m Machine) isExpired() bool { func (m Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry) return time.Now().UTC().After(*m.Expiry)
} }
@ -119,7 +118,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
return machines, nil return machines, nil
} }
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for // getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
func (h *Headscale) getShared(m *Machine) (Machines, error) { func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
@ -146,7 +145,7 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
return peers, nil return peers, nil
} }
// getSharedTo fetches the machines of the namespaces this machine is shared in // getSharedTo fetches the machines of the namespaces this machine is shared in.
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Caller(). Caller().
@ -228,7 +227,7 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
return machines, nil return machines, nil
} }
// GetMachine finds a Machine by name and namespace and returns the Machine struct // GetMachine finds a Machine by name and namespace and returns the Machine struct.
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
machines, err := h.ListMachinesInNamespace(namespace) machines, err := h.ListMachinesInNamespace(namespace)
if err != nil { if err != nil {
@ -243,7 +242,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return nil, fmt.Errorf("machine not found") return nil, fmt.Errorf("machine not found")
} }
// GetMachineByID finds a Machine by ID and returns the Machine struct // GetMachineByID finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{} m := Machine{}
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
@ -252,7 +251,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil return &m, nil
} }
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct // GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
m := Machine{} m := Machine{}
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
@ -270,7 +269,7 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
return nil return nil
} }
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared { if err != nil && err != errorMachineNotShared {
@ -287,7 +286,7 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
return h.RequestMapUpdates(namespaceID) return h.RequestMapUpdates(namespaceID)
} }
// HardDeleteMachine hard deletes a Machine from the database // HardDeleteMachine hard deletes a Machine from the database.
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m) err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared { if err != nil && err != errorMachineNotShared {
@ -302,7 +301,7 @@ func (h *Headscale) HardDeleteMachine(m *Machine) error {
return h.RequestMapUpdates(namespaceID) return h.RequestMapUpdates(namespaceID)
} }
// GetHostInfo returns a Hostinfo struct for the machine // GetHostInfo returns a Hostinfo struct for the machine.
func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{} hostinfo := tailcfg.Hostinfo{}
if len(m.HostInfo) != 0 { if len(m.HostInfo) != 0 {
@ -397,7 +396,7 @@ func (ms Machines) toNodes(
} }
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS // as per the expected behaviour in the official SaaS.
func (m Machine) toNode( func (m Machine) toNode(
baseDomain string, baseDomain string,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
@ -572,7 +571,7 @@ func (m *Machine) toProto() *v1.Machine {
return machine return machine
} }
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
ns, err := h.GetNamespace(namespace) ns, err := h.GetNamespace(namespace)
if err != nil { if err != nil {

View File

@ -30,7 +30,7 @@ type Namespace struct {
} }
// CreateNamespace creates a new Namespace. Returns error if could not be created // CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists // or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
n := Namespace{} n := Namespace{}
if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
@ -99,7 +99,7 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return nil return nil
} }
// GetNamespace fetches a namespace by name // GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) { func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
n := Namespace{} n := Namespace{}
if result := h.db.First(&n, "name = ?", name); errors.Is( if result := h.db.First(&n, "name = ?", name); errors.Is(
@ -111,7 +111,7 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
return &n, nil return &n, nil
} }
// ListNamespaces gets all the existing namespaces // ListNamespaces gets all the existing namespaces.
func (h *Headscale) ListNamespaces() ([]Namespace, error) { func (h *Headscale) ListNamespaces() ([]Namespace, error) {
namespaces := []Namespace{} namespaces := []Namespace{}
if err := h.db.Find(&namespaces).Error; err != nil { if err := h.db.Find(&namespaces).Error; err != nil {
@ -120,7 +120,7 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
return namespaces, nil return namespaces, nil
} }
// ListMachinesInNamespace gets all the nodes in a given namespace // ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
n, err := h.GetNamespace(name) n, err := h.GetNamespace(name)
if err != nil { if err != nil {
@ -134,7 +134,7 @@ func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
return machines, nil return machines, nil
} }
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace // ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
namespace, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
@ -158,7 +158,7 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
return machines, nil return machines, nil
} }
// SetMachineNamespace assigns a Machine to a namespace // SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error { func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName) n, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@ -169,7 +169,7 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
return nil return nil
} }
// RequestMapUpdates signals the KV worker to update the maps for this namespace // RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error { func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{} namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil { if err := h.db.First(&namespace, namespaceID).Error; err != nil {

View File

@ -57,7 +57,7 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(c *gin.Context) { func (h *Headscale) RegisterOIDC(c *gin.Context) {
mKeyStr := c.Param("mkey") mKeyStr := c.Param("mkey")
if mKeyStr == "" { if mKeyStr == "" {
@ -88,7 +88,7 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// Retrieves the mkey from the state cache and adds the machine to the users email namespace // Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(c *gin.Context) { func (h *Headscale) OIDCCallback(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
@ -170,7 +170,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new // register the machine if it's new
if !m.Registered { if !m.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName) ns, err := h.GetNamespace(nsName)
@ -218,7 +217,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
</html> </html>
`, claims.Email))) `, claims.Email)))
} }
log.Error(). log.Error().

View File

@ -7,10 +7,9 @@ import (
"strconv" "strconv"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
) )
const ( const (
@ -19,7 +18,7 @@ const (
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
) )
// PreAuthKey describes a pre-authorization key usable in a particular namespace // PreAuthKey describes a pre-authorization key usable in a particular namespace.
type PreAuthKey struct { type PreAuthKey struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
Key string Key string
@ -33,7 +32,7 @@ type PreAuthKey struct {
Expiration *time.Time Expiration *time.Time
} }
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it // CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
func (h *Headscale) CreatePreAuthKey( func (h *Headscale) CreatePreAuthKey(
namespaceName string, namespaceName string,
reusable bool, reusable bool,
@ -65,7 +64,7 @@ func (h *Headscale) CreatePreAuthKey(
return &k, nil return &k, nil
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace // ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) { func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName) n, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@ -79,7 +78,7 @@ func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error)
return keys, nil return keys, nil
} }
// GetPreAuthKey returns a PreAuthKey for a given key // GetPreAuthKey returns a PreAuthKey for a given key.
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) { func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
pak, err := h.checkKeyValidity(key) pak, err := h.checkKeyValidity(key)
if err != nil { if err != nil {
@ -93,7 +92,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
return pak, nil return pak, nil
} }
// MarkExpirePreAuthKey marks a PreAuthKey as expired // MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err return err
@ -102,7 +101,7 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
} }
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used // If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{} pak := PreAuthKey{}
if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is( if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(

View File

@ -10,7 +10,7 @@ import (
// Deprecated: use machine function instead // Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) GetAdvertisedNodeRoutes( func (h *Headscale) GetAdvertisedNodeRoutes(
namespace string, namespace string,
nodeName string, nodeName string,
@ -29,7 +29,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(
// Deprecated: use machine function instead // Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) GetEnabledNodeRoutes( func (h *Headscale) GetEnabledNodeRoutes(
namespace string, namespace string,
nodeName string, nodeName string,
@ -63,7 +63,7 @@ func (h *Headscale) GetEnabledNodeRoutes(
} }
// Deprecated: use machine function instead // Deprecated: use machine function instead
// IsNodeRouteEnabled checks if a certain route has been enabled // IsNodeRouteEnabled checks if a certain route has been enabled.
func (h *Headscale) IsNodeRouteEnabled( func (h *Headscale) IsNodeRouteEnabled(
namespace string, namespace string,
nodeName string, nodeName string,
@ -89,7 +89,7 @@ func (h *Headscale) IsNodeRouteEnabled(
// Deprecated: use EnableRoute in machine.go // Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by // EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name) // namespace and node name).
func (h *Headscale) EnableNodeRoute( func (h *Headscale) EnableNodeRoute(
namespace string, namespace string,
nodeName string, nodeName string,

View File

@ -8,7 +8,7 @@ const (
errorMachineNotShared = Error("Machine not shared to this namespace") errorMachineNotShared = Error("Machine not shared to this namespace")
) )
// SharedMachine is a join table to support sharing nodes between namespaces // SharedMachine is a join table to support sharing nodes between namespaces.
type SharedMachine struct { type SharedMachine struct {
gorm.Model gorm.Model
MachineID uint64 MachineID uint64
@ -17,7 +17,7 @@ type SharedMachine struct {
Namespace Namespace Namespace Namespace
} }
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace // AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID { if m.NamespaceID == ns.ID {
return errorSameNamespace return errorSameNamespace
@ -42,7 +42,7 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
return nil return nil
} }
// RemoveSharedMachineFromNamespace removes a shared machine from a namespace // RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error { func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID { if m.NamespaceID == ns.ID {
// Can't unshare from primary namespace // Can't unshare from primary namespace
@ -69,7 +69,7 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return nil return nil
} }
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces // RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
sharedMachine := SharedMachine{} sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {

View File

@ -6,9 +6,8 @@ import (
"net/http" "net/http"
"text/template" "text/template"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
) )
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json

View File

@ -113,7 +113,6 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
if ip.IsZero() && if ip.IsZero() &&
ip.IsLoopback() { ip.IsLoopback() {
ip = ip.Next() ip = ip.Next()
continue continue
} }