mirror of
https://github.com/juanfont/headscale.git
synced 2025-02-20 00:18:41 +01:00
Block ability to use '*'. replace origins string to a list to config. enable cors for specific endpoints (#2)
This commit is contained in:
parent
68ce697562
commit
cbf6a43e0d
@ -40,12 +40,15 @@ grpc_listen_addr: 127.0.0.1:50443
|
||||
# are doing.
|
||||
grpc_allow_insecure: false
|
||||
|
||||
# The Access-Control-Allow-Origin header specifies which origins are allowed to access resources.
|
||||
# The allow_origins list will allow you to set the Access-Control-Allow-Origin header to the origin in the list.
|
||||
# This will allow you to enable cors and set headscale without a reverse proxy.
|
||||
# Multiple origins can be set in the allow_origins list.
|
||||
# Options:
|
||||
# - "*" to allow access from any origin (not recommended for sensitive data).
|
||||
# - "http://example.com" to only allow access from a specific origin.
|
||||
# - "" to disable Cross-Origin Resource Sharing (CORS).
|
||||
access_control_allow_origin: ""
|
||||
# - "*" is disabled (due to security risks).
|
||||
# - "https://example.com" to only allow access from a specific origin.
|
||||
# - "https://example.com:1234" to allow access from a specific origin with a port.
|
||||
cors:
|
||||
allow_origins: []
|
||||
|
||||
# The Noise section includes specific configuration for the
|
||||
# TS2021 Noise protocol
|
||||
|
@ -455,18 +455,63 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||
return os.Remove(h.cfg.UnixSocket)
|
||||
}
|
||||
|
||||
// corsHeaderMiddleware will add an "Access-Control-Allow-Origin" to enable CORS.
|
||||
func (h *Headscale) corsHeadersMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", h.cfg.AccessControlAllowOrigins)
|
||||
// skip disabled CORS endpoints
|
||||
if !h.enabledCorsRoutes(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
origin := r.Header.Get("Origin")
|
||||
// we compare origin from the allowed Origins list. Then add the header with origin
|
||||
for _, allowedOrigin := range h.cfg.AllowedOrigins.Origins {
|
||||
if allowedOrigin == origin {
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
break
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Headscale) enabledCorsRoutes(routerPath string) bool {
|
||||
// enable all api endpoints
|
||||
if strings.HasPrefix(routerPath, "/api/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// A list of enabled CORS endpoints
|
||||
enabledRoutes := []string{
|
||||
"/health",
|
||||
"/key",
|
||||
"/register/{registration_id}",
|
||||
"/oidc/callback",
|
||||
"/verify",
|
||||
"/derp",
|
||||
"/derp/probe",
|
||||
"/derp/latency-check",
|
||||
"/bootstrap-dns",
|
||||
"/machine/register",
|
||||
"/machine/map",
|
||||
}
|
||||
|
||||
for _, routes := range enabledRoutes {
|
||||
if routes == routerPath {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
router.Use(prometheusMiddleware)
|
||||
|
||||
if h.cfg.AccessControlAllowOrigins != "" {
|
||||
if len(h.cfg.AllowedOrigins.Origins) != 0 {
|
||||
router.Use(h.corsHeadersMiddleware)
|
||||
}
|
||||
|
||||
|
@ -66,7 +66,7 @@ type Config struct {
|
||||
Log LogConfig
|
||||
DisableUpdateCheck bool
|
||||
|
||||
AccessControlAllowOrigins string
|
||||
AllowedOrigins CorsConfig
|
||||
|
||||
Database DatabaseConfig
|
||||
|
||||
@ -210,6 +210,10 @@ type LogTailConfig struct {
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
type CorsConfig struct {
|
||||
Origins []string
|
||||
}
|
||||
|
||||
type CLIConfig struct {
|
||||
Address string
|
||||
APIKey string
|
||||
@ -534,6 +538,14 @@ func logtailConfig() LogTailConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func corsConfig() CorsConfig {
|
||||
allowedOrigins := viper.GetStringSlice("cors.allowed_origins")
|
||||
|
||||
return CorsConfig{
|
||||
Origins: allowedOrigins,
|
||||
}
|
||||
}
|
||||
|
||||
func policyConfig() PolicyConfig {
|
||||
policyPath := viper.GetString("policy.path")
|
||||
policyMode := viper.GetString("policy.mode")
|
||||
@ -907,7 +919,7 @@ func LoadServerConfig() (*Config, error) {
|
||||
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
||||
DisableUpdateCheck: false,
|
||||
|
||||
AccessControlAllowOrigins: viper.GetString("access_control_allow_origin"),
|
||||
AllowedOrigins: corsConfig(),
|
||||
|
||||
PrefixV4: prefix4,
|
||||
PrefixV6: prefix6,
|
||||
|
Loading…
Reference in New Issue
Block a user