1
0
mirror of https://github.com/juanfont/headscale.git synced 2025-01-22 00:11:47 +01:00

Merge pull request #72 from kradalby/ip-pool

Make IP Prefix configurable and available ip deterministic
This commit is contained in:
Juan Font 2021-08-03 20:27:42 +02:00 committed by GitHub
commit 3879120967
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 227 additions and 42 deletions

View File

@ -113,9 +113,10 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal
``` ```
"server_url": "http://192.168.1.12:8080", "server_url": "http://192.168.1.12:8080",
"listen_addr": "0.0.0.0:8080", "listen_addr": "0.0.0.0:8080",
"ip_prefix": "100.64.0.0/10"
``` ```
`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `ip_prefix` is the IP prefix (range) in which IP addresses for nodes will be allocated.
``` ```
"private_key_path": "private.key", "private_key_path": "private.key",

1
api.go
View File

@ -445,6 +445,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
log.Println(err) log.Println(err)
return return
} }
log.Printf("Assigning %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID) m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String() m.IPAddress = ip.String()

2
app.go
View File

@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@ -24,6 +25,7 @@ type Config struct {
PrivateKeyPath string PrivateKeyPath string
DerpMap *tailcfg.DERPMap DerpMap *tailcfg.DERPMap
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefix netaddr.IPPrefix
DBtype string DBtype string
DBpath string DBpath string

View File

@ -6,6 +6,7 @@ import (
"testing" "testing"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr"
) )
func Test(t *testing.T) { func Test(t *testing.T) {
@ -36,7 +37,9 @@ func (s *Suite) ResetDB(c *check.C) {
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{} cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}
h = Headscale{ h = Headscale{
cfg: cfg, cfg: cfg,

View File

@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: n.ID,
IPAddress: "10.0.0.1",
} }
h.db.Save(&m) h.db.Save(&m)

View File

@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale" "github.com/juanfont/headscale"
"github.com/spf13/viper" "github.com/spf13/viper"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -36,6 +37,8 @@ func LoadConfig(path string) error {
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
viper.SetDefault("ip_prefix", "100.64.0.0/10")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
return fmt.Errorf("Fatal error reading config file: %s \n", err) return fmt.Errorf("Fatal error reading config file: %s \n", err)
@ -97,6 +100,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
Addr: viper.GetString("listen_addr"), Addr: viper.GetString("listen_addr"),
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
DerpMap: derpMap, DerpMap: derpMap,
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),

View File

@ -7,18 +7,12 @@ package headscale
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"time"
mathrand "math/rand"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"gorm.io/gorm" "inet.af/netaddr"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@ -77,47 +71,71 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil return msg, nil
} }
func (h *Headscale) getAvailableIP() (*net.IP, error) { func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
i := 0 ipPrefix := h.cfg.IPPrefix
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
// Get the first IP in our prefix
ip := ipPrefix.IP()
for { for {
ip, err := getRandomIP() if !ipPrefix.Contains(ip) {
if err != nil { return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
return nil, err
} }
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { // Some OS (including Linux) does not like when IPs ends with 0 or 255, which
return ip, nil // is typically called network or broadcast. Lets avoid them and continue
// to look when we get one of those traditionally reserved IPs.
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
continue
} }
i++
if i == 100 { // really random number if ip.IsZero() &&
break ip.IsLoopback() {
ip = ip.Next()
continue
} }
if !containsIPs(usedIps, ip) {
return &ip, nil
}
ip = ip.Next()
} }
return nil, errors.New("Could not find an available IP address in 100.64.0.0/10")
} }
func getRandomIP() (*net.IP, error) { func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
mathrand.Seed(time.Now().Unix()) var addresses []string
ipo, ipnet, err := net.ParseCIDR("100.64.0.0/10") h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
if err == nil {
ip := ipo.To4() ips := make([]netaddr.IP, len(addresses))
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) for index, addr := range addresses {
// fmt.Println("Final address is ", ip) if addr != "" {
// fmt.Println("Broadcast address is ", ipb) ip, err := netaddr.ParseIP(addr)
// fmt.Println("Network address is ", ipn) if err != nil {
r := mathrand.Uint32() return nil, fmt.Errorf("failed to parse ip from database, %w", err)
ipRaw := make([]byte, 4) }
binary.LittleEndian.PutUint32(ipRaw, r)
// ipRaw[3] = 254 ips[index] = ip
// fmt.Println("ipRaw is ", ipRaw)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
// fmt.Println("IP After: ", ip[i])
} }
// fmt.Println("FINAL IP: ", ip.String())
return &ip, nil
} }
return nil, err return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
}
return false
} }

155
utils_test.go Normal file
View File

@ -0,0 +1,155 @@
package headscale
import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ips[0], check.Equals, expected)
m1, err := h.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, expected.String())
}
func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)
for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: uint64(i),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
}
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips), check.Equals, 350)
c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
// Check that we can read back the IPs
m1, err := h.GetMachineByID(1)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
m50, err := h.GetMachineByID(50)
c.Assert(err, check.IsNil)
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
nextIP, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
ip2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String())
}