1
0
mirror of https://github.com/juanfont/headscale.git synced 2026-02-23 13:50:36 +01:00
juanfont.headscale/cmd/headscale/cli/policy.go
2026-02-07 16:11:05 +02:00

570 lines
16 KiB
Go

package cli
import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tailscale.com/types/views"
)
// bypassFlag names the flag that makes the policy subcommands use the
// headscale config to open the database directly instead of calling a running
// server over gRPC.
const bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential

// separatorWidth is the number of characters in each rule line drawn by the
// human-readable "policy test" report.
const separatorWidth = 50

// Output format values for which "policy test" prints machine-readable results
// instead of the human-readable report.
const (
	outputFormatJSON     = "json"
	outputFormatJSONLine = "json-line"
)
func init() {
rootCmd.AddCommand(policyCmd)
getPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(getPolicy)
setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
err := setPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
err = checkPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
policyCmd.AddCommand(checkPolicy)
// Test command flags
testPolicy.Flags().StringP("src", "s", "", "Source alias to test from (user, group, tag, host, or IP)")
testPolicy.Flags().StringSliceP("accept", "a", nil, "Destinations that should be allowed (repeatable, format: host:port)")
testPolicy.Flags().StringSliceP("deny", "d", nil, "Destinations that should be denied (repeatable, format: host:port)")
testPolicy.Flags().StringP("proto", "p", "", "Protocol to test (tcp, udp, icmp)")
testPolicy.Flags().StringP("file", "f", "", "Path to a JSON file with test definitions")
testPolicy.Flags().StringP("policy-file", "", "", "Test against a proposed policy file instead of current policy")
testPolicy.Flags().BoolP("embedded", "e", false, "Run tests embedded in the current policy")
testPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(testPolicy)
}
// policyCmd is the parent "policy" command. It has no action of its own; the
// get, set, check and test subcommands are attached to it in init.
var policyCmd = &cobra.Command{
	Short: "Manage the Headscale ACL Policy",
	Use:   "policy",
}
// getPolicy implements "headscale policy get". It prints the stored policy
// verbatim, either fetched from the server over gRPC or, with the bypass flag,
// read directly from the database configured in the headscale config.
var getPolicy = &cobra.Command{
	Use:     "get",
	Short:   "Print the current ACL Policy",
	Aliases: []string{"show", "view", "fetch"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")

		// Named policyData (not policy) so it does not shadow the imported
		// policy package.
		var policyData string

		if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
			// Direct database access is unsafe alongside a running server:
			// ask for confirmation unless --force was given.
			force, _ := cmd.Flags().GetBool("force")
			if !force && !util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") {
				ErrorOutput(nil, "Aborting command", output)
				return
			}

			// Every failure below returns, so nothing is ever called on a nil
			// config, database handle or policy record.
			cfg, err := types.LoadServerConfig()
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output)
				return
			}

			d, err := db.NewHeadscaleDatabase(
				cfg,
				nil,
			)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output)
				return
			}

			pol, err := d.GetPolicy()
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed loading Policy from database: %s", err), output)
				return
			}

			policyData = pol.Data
		} else {
			ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
			defer cancel()
			defer conn.Close()

			request := &v1.GetPolicyRequest{}

			response, err := client.GetPolicy(ctx, request)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
				return
			}

			policyData = response.GetPolicy()
		}

		// TODO(pallabpain): Maybe print this better?
		// This does not pass output as we don't support yaml, json or json-line
		// output for this command. It is HuJSON already.
		SuccessOutput("", policyData, "")
	},
}
// setPolicy implements "headscale policy set". It reads the HuJSON policy
// named by --file and stores it in one of two ways:
//   - through the server's gRPC SetPolicy call, or
//   - with the bypass flag, by validating it against the users in the
//     database and then writing it to the database directly.
var setPolicy = &cobra.Command{
	Use:   "set",
	Short: "Updates the ACL Policy",
	Long: `
Updates the existing ACL Policy with the provided policy. The policy must be a valid HuJSON object.
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
	Aliases: []string{"put", "update"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		policyPath, _ := cmd.Flags().GetString("file")

		// Every error path returns, so "Policy updated." is only printed
		// after a successful write.
		f, err := os.Open(policyPath)
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
			return
		}
		defer f.Close()

		policyBytes, err := io.ReadAll(f)
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
			return
		}

		if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
			// Direct database access is unsafe alongside a running server:
			// ask for confirmation unless --force was given.
			force, _ := cmd.Flags().GetBool("force")
			if !force && !util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") {
				ErrorOutput(nil, "Aborting command", output)
				return
			}

			cfg, err := types.LoadServerConfig()
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output)
				return
			}

			d, err := db.NewHeadscaleDatabase(
				cfg,
				nil,
			)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output)
				return
			}

			users, err := d.ListUsers()
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to load users for policy validation: %s", err), output)
				return
			}

			// Build a policy manager purely as a validation step, so an
			// invalid policy is never written to the database.
			_, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{})
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output)
				return
			}

			_, err = d.SetPolicy(string(policyBytes))
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
				return
			}
		} else {
			request := &v1.SetPolicyRequest{Policy: string(policyBytes)}

			ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
			defer cancel()
			defer conn.Close()

			_, err = client.SetPolicy(ctx, request)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
				return
			}
		}

		SuccessOutput(nil, "Policy updated.", "")
	},
}
// checkPolicy implements "headscale policy check". It parses the HuJSON
// policy named by --file and reports whether it is valid. The policy manager
// is built with no users and an empty node list, so the file is never
// checked against database contents.
var checkPolicy = &cobra.Command{
	Use:   "check",
	Short: "Check the Policy file for errors",
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		policyPath, _ := cmd.Flags().GetString("file")

		// Each failure returns, so "Policy is valid" can only be printed
		// when the file was read and parsed successfully.
		f, err := os.Open(policyPath)
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
			return
		}
		defer f.Close()

		policyBytes, err := io.ReadAll(f)
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
			return
		}

		_, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{})
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output)
			return
		}

		SuccessOutput(nil, "Policy is valid", "")
	},
}
// testPolicy implements "headscale policy test". Tests are gathered from up
// to three places:
//   - an ad-hoc test built from --src/--accept/--deny/--proto,
//   - a JSON test file (--file),
//   - the policy's own embedded tests (--embedded).
//
// The tests are run either through the server (gRPC) or locally against the
// database (bypass flag), optionally against a proposed policy from
// --policy-file. Results are printed as JSON/JSON-lines when that output
// format is requested, otherwise as a human-readable report.
var testPolicy = &cobra.Command{
	Use:   "test",
	Short: "Test ACL rules",
	Long: `Test ACL rules to verify access between sources and destinations.
Examples:
# Test if user can access server
headscale policy test --src "alice@example.com" --accept "tag:server:22"
# Test with deny rules
headscale policy test --src "alice@" --accept "10.0.0.1:80" --deny "10.0.0.2:443"
# Run tests from a JSON file
headscale policy test --file tests.json
# Run embedded tests from current policy
headscale policy test --embedded
# Test against a proposed policy file
headscale policy test --src "alice@" --accept "10.0.0.1:22" --policy-file new-policy.json`,
	Run: func(cmd *cobra.Command, args []string) {
		flags := cmd.Flags()

		output, _ := flags.GetString("output")
		src, _ := flags.GetString("src")
		accept, _ := flags.GetStringSlice("accept")
		deny, _ := flags.GetStringSlice("deny")
		proto, _ := flags.GetString("proto")
		testFile, _ := flags.GetString("file")
		policyFile, _ := flags.GetString("policy-file")
		embedded, _ := flags.GetBool("embedded")
		bypass, _ := flags.GetBool(bypassFlag)

		var tests []policyv2.ACLTest

		// The command-line test only exists when a source was given;
		// --accept/--deny/--proto on their own are not turned into a test.
		if src != "" {
			adHoc := policyv2.ACLTest{
				Src:    src,
				Proto:  policyv2.Protocol(proto),
				Accept: accept,
				Deny:   deny,
			}
			tests = append(tests, adHoc)
		}

		if testFile != "" {
			fromFile, err := loadTestsFromFile(testFile)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Error loading tests from file: %s", err), output)
				return
			}

			tests = append(tests, fromFile...)
		}

		// A nil policyBytes means "test against the current policy".
		var policyBytes []byte

		if policyFile != "" {
			pf, err := os.Open(policyFile)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Error opening policy file: %s", err), output)
				return
			}
			defer pf.Close()

			policyBytes, err = io.ReadAll(pf)
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Error reading policy file: %s", err), output)
				return
			}
		}

		// Both runners share a signature; pick one by transport.
		run := runTestsGRPC
		if bypass {
			run = runTestsBypass
		}

		results := run(cmd, output, tests, policyBytes, embedded)

		switch output {
		case outputFormatJSON, outputFormatJSONLine:
			SuccessOutput(results, "", output)
		default:
			printHumanReadableResults(results)
		}
	},
}
// loadTestsFromFile reads a JSON array of ACL test definitions from path.
//
// The file must contain exactly one JSON value. Trailing content after the
// array, such as a second concatenated array or stray text, is reported as an
// error instead of being silently ignored. Parse errors are wrapped with the
// file path so the caller's message identifies the offending file.
func loadTestsFromFile(path string) ([]policyv2.ACLTest, error) {
	// Test definition files are expected to be small. Reading the whole file
	// lets json.Unmarshal reject anything that follows the top-level value,
	// which json.Decoder.Decode would not do.
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	var tests []policyv2.ACLTest

	err = json.Unmarshal(data, &tests)
	if err != nil {
		return nil, fmt.Errorf("parsing %s: %w", path, err)
	}

	return tests, nil
}
// runTestsBypass evaluates tests locally, loading users, nodes and (unless a
// proposed policy is supplied) the stored policy directly from the database
// named in the headscale config, without talking to a server.
//
// Steps, in order:
//   - Prompt for confirmation unless --force is set.
//   - Test against policyBytes when it is non-empty, otherwise against the
//     stored policy.
//   - When embedded is set, append that policy's own tests to tests.
//
// Any failure, or having no tests at all, is reported through ErrorOutput and
// yields zero-value results.
func runTestsBypass(cmd *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults {
	var none policyv2.ACLTestResults

	force, _ := cmd.Flags().GetBool("force")
	if !force && !util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") {
		ErrorOutput(nil, "Aborting command", output)
		return none
	}

	cfg, err := types.LoadServerConfig()
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output)
		return none
	}

	hsdb, err := db.NewHeadscaleDatabase(cfg, nil)
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output)
		return none
	}

	users, err := hsdb.ListUsers()
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed to load users: %s", err), output)
		return none
	}

	nodes, err := hsdb.ListNodes()
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed to load nodes: %s", err), output)
		return none
	}

	// The policy manager takes read-only node views.
	nodeViews := make([]types.NodeView, 0, len(nodes))
	for _, node := range nodes {
		nodeViews = append(nodeViews, node.View())
	}

	// Prefer the proposed policy; fall back to what is stored.
	polBytes := policyBytes
	if len(polBytes) == 0 {
		stored, err := hsdb.GetPolicy()
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Failed to load policy: %s", err), output)
			return none
		}

		polBytes = []byte(stored.Data)
	}

	pm, err := policyv2.NewPolicyManager(polBytes, users, views.SliceOf(nodeViews))
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output)
		return none
	}

	if embedded {
		if pol := pm.Policy(); pol != nil {
			tests = append(tests, pol.Tests...)
		}
	}

	if len(tests) == 0 {
		ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output)
		return none
	}

	return pm.RunTests(tests)
}
// runTestsGRPC runs tests through the server's TestACL RPC and converts the
// response into policyv2.ACLTestResults.
//
// policyBytes is sent as the policy to test against. The --policy-file help
// text implies that an empty value makes the server use its current policy.
//
// When embedded is set, the tests embedded in the policy under test are
// appended to tests. That is the proposed policy in policyBytes when one was
// supplied, otherwise the server's current policy. This matches
// runTestsBypass, which takes embedded tests from the policy it evaluates, so
// both transports run the same tests.
//
// Any failure, or having no tests at all, is reported through ErrorOutput and
// yields zero-value results.
func runTestsGRPC(_ *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults {
	ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
	defer cancel()
	defer conn.Close()

	if embedded {
		embeddedSrc := policyBytes
		if len(embeddedSrc) == 0 {
			policyResp, err := client.GetPolicy(ctx, &v1.GetPolicyRequest{})
			if err != nil {
				ErrorOutput(err, fmt.Sprintf("Failed to get current policy: %s", err), output)
				return policyv2.ACLTestResults{}
			}

			embeddedSrc = []byte(policyResp.GetPolicy())
		}

		// Parsed only to read its Tests section, so no users or nodes are
		// supplied.
		pm, err := policyv2.NewPolicyManager(embeddedSrc, nil, views.Slice[types.NodeView]{})
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output)
			return policyv2.ACLTestResults{}
		}

		if pol := pm.Policy(); pol != nil && len(pol.Tests) > 0 {
			tests = append(tests, pol.Tests...)
		}
	}

	if len(tests) == 0 {
		ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output)
		return policyv2.ACLTestResults{}
	}

	// Convert tests to proto format.
	protoTests := make([]*v1.ACLTest, len(tests))
	for i, t := range tests {
		protoTests[i] = &v1.ACLTest{
			Src:    t.Src,
			Proto:  string(t.Proto),
			Accept: t.Accept,
			Deny:   t.Deny,
		}
	}

	request := &v1.TestACLRequest{
		Tests:  protoTests,
		Policy: string(policyBytes),
	}

	response, err := client.TestACL(ctx, request)
	if err != nil {
		ErrorOutput(err, fmt.Sprintf("Failed to run ACL tests: %s", err), output)
		return policyv2.ACLTestResults{}
	}

	// Convert the proto response to the internal result format.
	results := policyv2.ACLTestResults{
		AllPassed: response.GetAllPassed(),
		Results:   make([]policyv2.ACLTestResult, len(response.GetResults())),
	}

	for i, r := range response.GetResults() {
		results.Results[i] = policyv2.ACLTestResult{
			Src:        r.GetSrc(),
			Passed:     r.GetPassed(),
			Errors:     r.GetErrors(),
			AcceptOK:   r.GetAcceptOk(),
			AcceptFail: r.GetAcceptFail(),
			DenyOK:     r.GetDenyOk(),
			DenyFail:   r.GetDenyFail(),
		}
	}

	return results
}
// printHumanReadableResults writes a plain-text test report to stdout. Each
// tested source gets a section with:
//   - any errors,
//   - each accept/deny destination marked PASS or FAIL,
//   - the source's own verdict.
//
// The report ends with an overall line. Its PASSED/FAILED verdict comes from
// results.AllPassed, while the "x/y" tally counts results whose Passed field
// is set.
func printHumanReadableResults(results policyv2.ACLTestResults) {
	rule := func(ch string) {
		fmt.Println(strings.Repeat(ch, separatorWidth))
	}

	fmt.Println("ACL Test Results")
	rule("=")
	fmt.Println()

	passed := 0

	for _, res := range results.Results {
		fmt.Printf("Source: %s\n", res.Src)
		fmt.Println()

		if len(res.Errors) != 0 {
			fmt.Println(" Errors:")
			for _, msg := range res.Errors {
				fmt.Printf(" ! %s\n", msg)
			}
			fmt.Println()
		}

		if len(res.AcceptOK)+len(res.AcceptFail) != 0 {
			fmt.Println(" Accept Tests:")
			for _, dst := range res.AcceptOK {
				fmt.Printf(" [PASS] %s - ALLOWED (expected)\n", dst)
			}
			for _, dst := range res.AcceptFail {
				fmt.Printf(" [FAIL] %s - DENIED (expected ALLOWED)\n", dst)
			}
			fmt.Println()
		}

		if len(res.DenyOK)+len(res.DenyFail) != 0 {
			fmt.Println(" Deny Tests:")
			for _, dst := range res.DenyOK {
				fmt.Printf(" [PASS] %s - DENIED (expected)\n", dst)
			}
			for _, dst := range res.DenyFail {
				fmt.Printf(" [FAIL] %s - ALLOWED (expected DENIED)\n", dst)
			}
			fmt.Println()
		}

		if !res.Passed {
			fmt.Println(" Result: FAILED")
		} else {
			passed++
			fmt.Println(" Result: PASSED")
		}

		fmt.Println()
		rule("-")
		fmt.Println()
	}

	total := len(results.Results)

	switch {
	case results.AllPassed:
		fmt.Printf("Overall: PASSED (%d/%d tests passed)\n", passed, total)
	default:
		fmt.Printf("Overall: FAILED (%d/%d tests passed)\n", passed, total)
	}
}