mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-20 19:09:07 +01:00
Merge branch 'main' into update-deps-20220904
This commit is contained in:
commit
3c73cbe92b
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@ -26,7 +26,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
version: v1.46.1
|
||||
version: v1.49.0
|
||||
|
||||
# Only block PRs on new problems.
|
||||
# If this is not enabled, we will end up having PRs
|
||||
|
@ -825,7 +825,6 @@ func Test_listMachinesInNamespace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// nolint
|
||||
func Test_expandAlias(t *testing.T) {
|
||||
type args struct {
|
||||
machines []Machine
|
||||
|
2
api.go
2
api.go
@ -52,7 +52,7 @@ func (h *Headscale) HealthHandler(
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.pingDB(); err != nil {
|
||||
if err := h.pingDB(req.Context()); err != nil {
|
||||
respond(err)
|
||||
|
||||
return
|
||||
|
31
app.go
31
app.go
@ -18,7 +18,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/patrickmn/go-cache"
|
||||
@ -601,7 +601,7 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
grpcOptions := []grpc.ServerOption{
|
||||
grpc.UnaryInterceptor(
|
||||
grpc_middleware.ChainUnaryServer(
|
||||
grpcMiddleware.ChainUnaryServer(
|
||||
h.grpcAuthenticationInterceptor,
|
||||
zerolog.NewUnaryServerInterceptor(),
|
||||
),
|
||||
@ -820,10 +820,19 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
// Configuration via autocert with HTTP-01. This requires listening on
|
||||
// port 80 for the certificate validation in addition to the headscale
|
||||
// service, which can be configured to run on any other port.
|
||||
|
||||
server := &http.Server{
|
||||
Addr: h.cfg.TLS.LetsEncrypt.Listen,
|
||||
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
|
||||
ReadTimeout: HTTPReadTimeout,
|
||||
}
|
||||
|
||||
err := server.ListenAndServe()
|
||||
|
||||
go func() {
|
||||
log.Fatal().
|
||||
Caller().
|
||||
Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Err(err).
|
||||
Msg("failed to set up a HTTP server")
|
||||
}()
|
||||
|
||||
@ -860,19 +869,17 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) setLastStateChangeToNow(namespaces ...string) {
|
||||
func (h *Headscale) setLastStateChangeToNow() {
|
||||
var err error
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
if len(namespaces) == 0 {
|
||||
namespaces, err = h.ListNamespacesStr()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("failed to fetch all namespaces, failing to update last changed state.")
|
||||
}
|
||||
namespaces, err := h.ListNamespacesStr()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("failed to fetch all namespaces, failing to update last changed state.")
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
|
@ -5,12 +5,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
4
db.go
4
db.go
@ -221,8 +221,8 @@ func (h *Headscale) setValue(key string, value string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Headscale) pingDB() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
func (h *Headscale) pingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
db, err := h.db.DB()
|
||||
if err != nil {
|
||||
|
2
derp.go
2
derp.go
@ -34,7 +34,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func (h *Headscale) DERPHandler(
|
||||
|
||||
if !fastStart {
|
||||
pubKey := h.privateKey.Public()
|
||||
pubKeyStr := pubKey.UntypedHexString() // nolint
|
||||
pubKeyStr := pubKey.UntypedHexString() //nolint
|
||||
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
|
||||
"Upgrade: DERP\r\n"+
|
||||
"Connection: Upgrade\r\n"+
|
||||
@ -174,7 +174,7 @@ func (h *Headscale) DERPProbeHandler(
|
||||
req *http.Request,
|
||||
) {
|
||||
switch req.Method {
|
||||
case "HEAD", "GET":
|
||||
case http.MethodHead, http.MethodGet:
|
||||
writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
@ -202,7 +202,7 @@ func (h *Headscale) DERPBootstrapDNSHandler(
|
||||
) {
|
||||
dnsEntries := make(map[string][]net.IP)
|
||||
|
||||
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||
defer cancel()
|
||||
var resolver net.Resolver
|
||||
for _, region := range h.DERPMap.Regions {
|
||||
|
@ -540,7 +540,6 @@ func Test_getTags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// nolint
|
||||
func Test_getFilteredByACLPeers(t *testing.T) {
|
||||
type args struct {
|
||||
machines []Machine
|
||||
|
4
noise.go
4
noise.go
@ -31,7 +31,9 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||
return
|
||||
}
|
||||
|
||||
server := http.Server{}
|
||||
server := http.Server{
|
||||
ReadTimeout: HTTPReadTimeout,
|
||||
}
|
||||
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{})
|
||||
err = server.Serve(netutil.NewOneConnListener(noiseConn, nil))
|
||||
if err != nil {
|
||||
|
10
oidc.go
10
oidc.go
@ -148,12 +148,12 @@ func (h *Headscale) OIDCCallback(
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state)
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
|
||||
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -240,10 +240,11 @@ func validateOIDCCallbackParams(
|
||||
}
|
||||
|
||||
func (h *Headscale) getIDTokenForOIDCCallback(
|
||||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
code, state string,
|
||||
) (string, error) {
|
||||
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
@ -287,11 +288,12 @@ func (h *Headscale) getIDTokenForOIDCCallback(
|
||||
}
|
||||
|
||||
func (h *Headscale) verifyIDTokenForOIDCCallback(
|
||||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
rawIDToken string,
|
||||
) (*oidc.IDToken, error) {
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
|
@ -105,7 +105,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// If the machine has AuthKey set, handle registration via PreAuthKeys
|
||||
if registerRequest.Auth.AuthKey != "" {
|
||||
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey)
|
||||
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
@ -134,7 +134,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
case <-req.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.handleNewMachineCommon(writer, req, registerRequest, machineKey)
|
||||
h.handleNewMachineCommon(writer, registerRequest, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
@ -190,7 +190,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
h.handleNewMachineCommon(writer, req, registerRequest, machineKey)
|
||||
h.handleNewMachineCommon(writer, registerRequest, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
@ -207,7 +207,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
||||
if !registerRequest.Expiry.IsZero() &&
|
||||
registerRequest.Expiry.UTC().Before(now) {
|
||||
h.handleMachineLogOutCommon(writer, req, *machine, machineKey)
|
||||
h.handleMachineLogOutCommon(writer, *machine, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
@ -215,7 +215,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
// If machine is not expired, and is register, we have a already accepted this machine,
|
||||
// let it proceed with a valid registration
|
||||
if !machine.isExpired() {
|
||||
h.handleMachineValidRegistrationCommon(writer, req, *machine, machineKey)
|
||||
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
@ -226,7 +226,6 @@ func (h *Headscale) handleRegisterCommon(
|
||||
!machine.isExpired() {
|
||||
h.handleMachineRefreshKeyCommon(
|
||||
writer,
|
||||
req,
|
||||
registerRequest,
|
||||
*machine,
|
||||
machineKey,
|
||||
@ -236,7 +235,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
}
|
||||
|
||||
// The machine has expired
|
||||
h.handleMachineExpiredCommon(writer, req, registerRequest, *machine, machineKey)
|
||||
h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey)
|
||||
|
||||
machine.Expiry = &time.Time{}
|
||||
h.registrationCache.Set(
|
||||
@ -256,7 +255,6 @@ func (h *Headscale) handleRegisterCommon(
|
||||
// TODO: check if any locks are needed around IP allocation.
|
||||
func (h *Headscale) handleAuthKeyCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
@ -455,7 +453,6 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
// for authorizing the machine. This url is then showed to the user by the local Tailscale client.
|
||||
func (h *Headscale) handleNewMachineCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
@ -511,7 +508,6 @@ func (h *Headscale) handleNewMachineCommon(
|
||||
|
||||
func (h *Headscale) handleMachineLogOutCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
machine Machine,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
@ -570,7 +566,6 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||
|
||||
func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
machine Machine,
|
||||
machineKey key.MachinePublic,
|
||||
) {
|
||||
@ -624,7 +619,6 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||
|
||||
func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
machineKey key.MachinePublic,
|
||||
@ -684,7 +678,6 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
|
||||
func (h *Headscale) handleMachineExpiredCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
machineKey key.MachinePublic,
|
||||
@ -699,7 +692,7 @@ func (h *Headscale) handleMachineExpiredCommon(
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if registerRequest.Auth.AuthKey != "" {
|
||||
h.handleAuthKeyCommon(writer, req, registerRequest, machineKey)
|
||||
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ const machineNameContextKey = contextKey("machineName")
|
||||
// managed the poll loop.
|
||||
func (h *Headscale) handlePollCommon(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
ctx context.Context,
|
||||
machine *Machine,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
isNoise bool,
|
||||
@ -201,7 +201,7 @@ func (h *Headscale) handlePollCommon(
|
||||
|
||||
h.pollNetMapStream(
|
||||
writer,
|
||||
req,
|
||||
ctx,
|
||||
machine,
|
||||
mapRequest,
|
||||
pollDataChan,
|
||||
@ -221,7 +221,7 @@ func (h *Headscale) handlePollCommon(
|
||||
// ensuring we communicate updates and data to the connected clients.
|
||||
func (h *Headscale) pollNetMapStream(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
ctxReq context.Context,
|
||||
machine *Machine,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
pollDataChan chan []byte,
|
||||
@ -232,7 +232,7 @@ func (h *Headscale) pollNetMapStream(
|
||||
h.pollNetMapStreamWG.Add(1)
|
||||
defer h.pollNetMapStreamWG.Done()
|
||||
|
||||
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
|
||||
ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
@ -75,6 +75,8 @@ func (h *Headscale) marshalResponse(
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot marshal response")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if machineKey.IsZero() { // if Noise
|
||||
|
@ -90,5 +90,5 @@ func (h *Headscale) PollNetMapHandler(
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("A machine is entering polling via the legacy protocol")
|
||||
|
||||
h.handlePollCommon(writer, req, machine, mapRequest, false)
|
||||
h.handlePollCommon(writer, req.Context(), machine, mapRequest, false)
|
||||
}
|
||||
|
@ -63,5 +63,5 @@ func (h *Headscale) NoisePollNetMapHandler(
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("A machine is entering polling via the Noise protocol")
|
||||
|
||||
h.handlePollCommon(writer, req, machine, mapRequest, true)
|
||||
h.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user