mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-28 10:51:44 +01:00 
			
		
		
		
	Merge branch 'main' into topic/docker-release
This commit is contained in:
		
						commit
						1e93347a26
					
				
							
								
								
									
										41
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					name: CI
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					on: [push, pull_request]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					jobs:
 | 
				
			||||||
 | 
					  # The "build" workflow
 | 
				
			||||||
 | 
					  lint:
 | 
				
			||||||
 | 
					    # The type of runner that the job will run on
 | 
				
			||||||
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps represent a sequence of tasks that will be executed as part of the job
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
 | 
				
			||||||
 | 
					      - uses: actions/checkout@v2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Install and run golangci-lint as a separate step, it's much faster this
 | 
				
			||||||
 | 
					      # way because this action has caching. It'll get run again in `make lint`
 | 
				
			||||||
 | 
					      # below, but it's still much faster in the end than installing
 | 
				
			||||||
 | 
					      # golangci-lint manually in the `Run lint` step.
 | 
				
			||||||
 | 
					      - uses: golangci/golangci-lint-action@v2
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          args: --timeout 4m
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Setup Go
 | 
				
			||||||
 | 
					      - name: Setup Go
 | 
				
			||||||
 | 
					        uses: actions/setup-go@v2
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          go-version: "1.16.3" # The Go version to download (if necessary) and use.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Install all the dependencies
 | 
				
			||||||
 | 
					      - name: Install dependencies
 | 
				
			||||||
 | 
					        run: |
 | 
				
			||||||
 | 
					          go version
 | 
				
			||||||
 | 
					          go install golang.org/x/lint/golint@latest
 | 
				
			||||||
 | 
					          sudo apt update
 | 
				
			||||||
 | 
					          sudo apt install -y make
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      - name: Run lint
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          args: --timeout 4m
 | 
				
			||||||
 | 
					        run: make lint
 | 
				
			||||||
							
								
								
									
										23
									
								
								.github/workflows/test-integration.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								.github/workflows/test-integration.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					name: CI
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					on: [pull_request]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					jobs:
 | 
				
			||||||
 | 
					  # The "build" workflow
 | 
				
			||||||
 | 
					  integration-test:
 | 
				
			||||||
 | 
					    # The type of runner that the job will run on
 | 
				
			||||||
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps represent a sequence of tasks that will be executed as part of the job
 | 
				
			||||||
 | 
					    steps:
 | 
				
			||||||
 | 
					      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
 | 
				
			||||||
 | 
					      - uses: actions/checkout@v2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      # Setup Go
 | 
				
			||||||
 | 
					      - name: Setup Go
 | 
				
			||||||
 | 
					        uses: actions/setup-go@v2
 | 
				
			||||||
 | 
					        with:
 | 
				
			||||||
 | 
					          go-version: "1.16.3"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      - name: Run Integration tests
 | 
				
			||||||
 | 
					        run: go test -tags integration -timeout 30m
 | 
				
			||||||
							
								
								
									
										46
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										46
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							@ -10,36 +10,24 @@ jobs:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Steps represent a sequence of tasks that will be executed as part of the job
 | 
					    # Steps represent a sequence of tasks that will be executed as part of the job
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
    # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
 | 
					      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
 | 
				
			||||||
    - uses: actions/checkout@v2
 | 
					      - uses: actions/checkout@v2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Install and run golangci-lint as a separate step, it's much faster this
 | 
					      # Setup Go
 | 
				
			||||||
    # way because this action has caching. It'll get run again in `make lint`
 | 
					      - name: Setup Go
 | 
				
			||||||
    # below, but it's still much faster in the end than installing
 | 
					        uses: actions/setup-go@v2
 | 
				
			||||||
    # golangci-lint manually in the `Run lint` step.
 | 
					        with:
 | 
				
			||||||
    - uses: golangci/golangci-lint-action@v2
 | 
					          go-version: "1.16.3" # The Go version to download (if necessary) and use.
 | 
				
			||||||
      with:
 | 
					 | 
				
			||||||
        args: --timeout 2m
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup Go
 | 
					      # Install all the dependencies
 | 
				
			||||||
    - name: Setup Go
 | 
					      - name: Install dependencies
 | 
				
			||||||
      uses: actions/setup-go@v2
 | 
					        run: |
 | 
				
			||||||
      with:
 | 
					          go version
 | 
				
			||||||
        go-version: '1.16.3' # The Go version to download (if necessary) and use.
 | 
					          sudo apt update
 | 
				
			||||||
 | 
					          sudo apt install -y make
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Install all the dependencies
 | 
					      - name: Run tests
 | 
				
			||||||
    - name: Install dependencies
 | 
					        run: make test
 | 
				
			||||||
      run: |
 | 
					 | 
				
			||||||
        go version
 | 
					 | 
				
			||||||
        go install golang.org/x/lint/golint@latest
 | 
					 | 
				
			||||||
        sudo apt update
 | 
					 | 
				
			||||||
        sudo apt install -y make
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    - name: Run tests
 | 
					      - name: Run build
 | 
				
			||||||
      run: make test
 | 
					        run: make
 | 
				
			||||||
 | 
					 | 
				
			||||||
    - name: Run lint
 | 
					 | 
				
			||||||
      run: make lint
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    - name: Run build
 | 
					 | 
				
			||||||
      run: make
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -22,3 +22,5 @@ config.json
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Exclude Jetbrains Editors
 | 
					# Exclude Jetbrains Editors
 | 
				
			||||||
.idea
 | 
					.idea
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					test_output/ 
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ COPY . /go/src/headscale
 | 
				
			|||||||
RUN go install -a -ldflags="-extldflags=-static" -tags netgo,sqlite_omit_load_extension ./cmd/headscale
 | 
					RUN go install -a -ldflags="-extldflags=-static" -tags netgo,sqlite_omit_load_extension ./cmd/headscale
 | 
				
			||||||
RUN test -e /go/bin/headscale
 | 
					RUN test -e /go/bin/headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
FROM ubuntu:latest
 | 
					FROM ubuntu:20.04
 | 
				
			||||||
 | 
					
 | 
				
			||||||
COPY --from=build /go/bin/headscale /usr/local/bin/headscale
 | 
					COPY --from=build /go/bin/headscale /usr/local/bin/headscale
 | 
				
			||||||
ENV TZ UTC
 | 
					ENV TZ UTC
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							@ -2,7 +2,7 @@
 | 
				
			|||||||
version = $(shell ./scripts/version-at-commit.sh)
 | 
					version = $(shell ./scripts/version-at-commit.sh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
build:
 | 
					build:
 | 
				
			||||||
	go build -ldflags "-s -w -X main.version=$(version)" cmd/headscale/headscale.go
 | 
						go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.version=$(version)" cmd/headscale/headscale.go
 | 
				
			||||||
 | 
					
 | 
				
			||||||
dev: lint test build
 | 
					dev: lint test build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -25,14 +25,13 @@ Headscale implements this coordination server.
 | 
				
			|||||||
- [X] JSON-formatted output
 | 
					- [X] JSON-formatted output
 | 
				
			||||||
- [X] ACLs
 | 
					- [X] ACLs
 | 
				
			||||||
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
 | 
					- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
 | 
				
			||||||
- [ ] Share nodes between ~~users~~ namespaces 
 | 
					- [X] DNS (passing DNS servers to nodes)
 | 
				
			||||||
- [ ] DNS
 | 
					- [X] Share nodes between ~~users~~ namespaces 
 | 
				
			||||||
 | 
					- [ ] MagicDNS / Smart DNS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Roadmap 🤷
 | 
					## Roadmap 🤷
 | 
				
			||||||
 | 
					
 | 
				
			||||||
We are now focusing on adding integration tests with the official clients. 
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Suggestions/PRs welcomed!
 | 
					Suggestions/PRs welcomed!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										222
									
								
								api.go
									
									
									
									
									
								
							
							
						
						
									
										222
									
								
								api.go
									
									
									
									
									
								
							@ -13,9 +13,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/klauspost/compress/zstd"
 | 
						"github.com/klauspost/compress/zstd"
 | 
				
			||||||
	"gorm.io/datatypes"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"inet.af/netaddr"
 | 
					 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
	"tailscale.com/types/wgkey"
 | 
						"tailscale.com/types/wgkey"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -35,8 +33,6 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// spew.Dump(c.Params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
 | 
						c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
 | 
				
			||||||
	<html>
 | 
						<html>
 | 
				
			||||||
	<body>
 | 
						<body>
 | 
				
			||||||
@ -82,14 +78,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now().UTC()
 | 
				
			||||||
	var m Machine
 | 
						var m Machine
 | 
				
			||||||
	if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
						if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
				
			||||||
		log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
 | 
							log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
 | 
				
			||||||
		m = Machine{
 | 
							m = Machine{
 | 
				
			||||||
			Expiry:     &req.Expiry,
 | 
								Expiry:               &req.Expiry,
 | 
				
			||||||
			MachineKey: mKey.HexString(),
 | 
								MachineKey:           mKey.HexString(),
 | 
				
			||||||
			Name:       req.Hostinfo.Hostname,
 | 
								Name:                 req.Hostinfo.Hostname,
 | 
				
			||||||
			NodeKey:    wgkey.Key(req.NodeKey).HexString(),
 | 
								NodeKey:              wgkey.Key(req.NodeKey).HexString(),
 | 
				
			||||||
 | 
								LastSuccessfulUpdate: &now,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if err := h.db.Create(&m).Error; err != nil {
 | 
							if err := h.db.Create(&m).Error; err != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
@ -215,202 +213,12 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
 | 
				
			|||||||
	c.Data(200, "application/json; charset=utf-8", respBody)
 | 
						c.Data(200, "application/json; charset=utf-8", respBody)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// PollNetMapHandler takes care of /machine/:id/map
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
 | 
					 | 
				
			||||||
// the clients when something in the network changes.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// The clients POST stuff like HostInfo and their Endpoints here, but
 | 
					 | 
				
			||||||
// only after their first request (marked with the ReadOnly field).
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
 | 
					 | 
				
			||||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Msg("PollNetMapHandler called")
 | 
					 | 
				
			||||||
	body, _ := io.ReadAll(c.Request.Body)
 | 
					 | 
				
			||||||
	mKeyStr := c.Param("id")
 | 
					 | 
				
			||||||
	mKey, err := wgkey.ParseHex(mKeyStr)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Err(err).
 | 
					 | 
				
			||||||
			Msg("Cannot parse client key")
 | 
					 | 
				
			||||||
		c.String(http.StatusBadRequest, "")
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	req := tailcfg.MapRequest{}
 | 
					 | 
				
			||||||
	err = decode(body, &req, &mKey, h.privateKey)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Err(err).
 | 
					 | 
				
			||||||
			Msg("Cannot decode message")
 | 
					 | 
				
			||||||
		c.String(http.StatusBadRequest, "")
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var m Machine
 | 
					 | 
				
			||||||
	if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
					 | 
				
			||||||
		log.Warn().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
 | 
					 | 
				
			||||||
		c.String(http.StatusUnauthorized, "")
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Found machine in database")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	hostinfo, _ := json.Marshal(req.Hostinfo)
 | 
					 | 
				
			||||||
	m.Name = req.Hostinfo.Hostname
 | 
					 | 
				
			||||||
	m.HostInfo = datatypes.JSON(hostinfo)
 | 
					 | 
				
			||||||
	m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
 | 
					 | 
				
			||||||
	now := time.Now().UTC()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// From Tailscale client:
 | 
					 | 
				
			||||||
	//
 | 
					 | 
				
			||||||
	// ReadOnly is whether the client just wants to fetch the MapResponse,
 | 
					 | 
				
			||||||
	// without updating their Endpoints. The Endpoints field will be ignored and
 | 
					 | 
				
			||||||
	// LastSeen will not be updated and peers will not be notified of changes.
 | 
					 | 
				
			||||||
	//
 | 
					 | 
				
			||||||
	// The intended use is for clients to discover the DERP map at start-up
 | 
					 | 
				
			||||||
	// before their first real endpoint update.
 | 
					 | 
				
			||||||
	if !req.ReadOnly {
 | 
					 | 
				
			||||||
		endpoints, _ := json.Marshal(req.Endpoints)
 | 
					 | 
				
			||||||
		m.Endpoints = datatypes.JSON(endpoints)
 | 
					 | 
				
			||||||
		m.LastSeen = &now
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	h.db.Save(&m)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	data, err := h.getMapResponse(mKey, req, m)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
			Str("machine", m.Name).
 | 
					 | 
				
			||||||
			Err(err).
 | 
					 | 
				
			||||||
			Msg("Failed to get Map response")
 | 
					 | 
				
			||||||
		c.String(http.StatusInternalServerError, ":(")
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// We update our peers if the client is not sending ReadOnly in the MapRequest
 | 
					 | 
				
			||||||
	// so we don't distribute its initial request (it comes with
 | 
					 | 
				
			||||||
	// empty endpoints to peers)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
 | 
					 | 
				
			||||||
	log.Debug().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Bool("readOnly", req.ReadOnly).
 | 
					 | 
				
			||||||
		Bool("omitPeers", req.OmitPeers).
 | 
					 | 
				
			||||||
		Bool("stream", req.Stream).
 | 
					 | 
				
			||||||
		Msg("Client map request processed")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if req.ReadOnly {
 | 
					 | 
				
			||||||
		log.Info().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Str("machine", m.Name).
 | 
					 | 
				
			||||||
			Msg("Client is starting up. Asking for DERP map")
 | 
					 | 
				
			||||||
		c.Data(200, "application/json; charset=utf-8", *data)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if req.OmitPeers && !req.Stream {
 | 
					 | 
				
			||||||
		log.Info().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Str("machine", m.Name).
 | 
					 | 
				
			||||||
			Msg("Client sent endpoint update and is ok with a response without peer list")
 | 
					 | 
				
			||||||
		c.Data(200, "application/json; charset=utf-8", *data)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	} else if req.OmitPeers && req.Stream {
 | 
					 | 
				
			||||||
		log.Warn().
 | 
					 | 
				
			||||||
			Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
			Str("machine", m.Name).
 | 
					 | 
				
			||||||
			Msg("Ignoring request, don't know how to handle it")
 | 
					 | 
				
			||||||
		c.String(http.StatusBadRequest, "")
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Only create update channel if it has not been created
 | 
					 | 
				
			||||||
	var update chan []byte
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Creating or loading update channel")
 | 
					 | 
				
			||||||
	if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
 | 
					 | 
				
			||||||
		update = result.(chan []byte)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pollData := make(chan []byte, 1)
 | 
					 | 
				
			||||||
	defer close(pollData)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	cancelKeepAlive := make(chan []byte, 1)
 | 
					 | 
				
			||||||
	defer close(cancelKeepAlive)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Info().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Client is ready to access the tailnet")
 | 
					 | 
				
			||||||
	log.Info().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Sending initial map")
 | 
					 | 
				
			||||||
	pollData <- *data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Info().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Notifying peers")
 | 
					 | 
				
			||||||
		// TODO: Why does this block?
 | 
					 | 
				
			||||||
	go h.notifyChangesToPeers(&m)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
 | 
					 | 
				
			||||||
	log.Trace().
 | 
					 | 
				
			||||||
		Str("handler", "PollNetMap").
 | 
					 | 
				
			||||||
		Str("id", c.Param("id")).
 | 
					 | 
				
			||||||
		Str("machine", m.Name).
 | 
					 | 
				
			||||||
		Msg("Finished stream, closing PollNetMap session")
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
 | 
					 | 
				
			||||||
	for {
 | 
					 | 
				
			||||||
		select {
 | 
					 | 
				
			||||||
		case <-cancel:
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		default:
 | 
					 | 
				
			||||||
			data, err := h.getMapKeepAliveResponse(mKey, req, m)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Str("func", "keepAlive").
 | 
					 | 
				
			||||||
					Err(err).
 | 
					 | 
				
			||||||
					Msg("Error generating the keep alive msg")
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			log.Debug().
 | 
					 | 
				
			||||||
				Str("func", "keepAlive").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Msg("Sending keepalive")
 | 
					 | 
				
			||||||
			pollData <- *data
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			time.Sleep(60 * time.Second)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
 | 
					func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
		Str("func", "getMapResponse").
 | 
							Str("func", "getMapResponse").
 | 
				
			||||||
		Str("machine", req.Hostinfo.Hostname).
 | 
							Str("machine", req.Hostinfo.Hostname).
 | 
				
			||||||
		Msg("Creating Map response")
 | 
							Msg("Creating Map response")
 | 
				
			||||||
	node, err := m.toNode()
 | 
						node, err := m.toNode(true)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().
 | 
							log.Error().
 | 
				
			||||||
			Str("func", "getMapResponse").
 | 
								Str("func", "getMapResponse").
 | 
				
			||||||
@ -434,10 +242,15 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp := tailcfg.MapResponse{
 | 
						resp := tailcfg.MapResponse{
 | 
				
			||||||
		KeepAlive:    false,
 | 
							KeepAlive: false,
 | 
				
			||||||
		Node:         node,
 | 
							Node:      node,
 | 
				
			||||||
		Peers:        *peers,
 | 
							Peers:     *peers,
 | 
				
			||||||
		DNS:          []netaddr.IP{},
 | 
							//TODO(kradalby): As per tailscale docs, if DNSConfig is nil,
 | 
				
			||||||
 | 
							// it means its not updated, maybe we can have some logic
 | 
				
			||||||
 | 
							// to check and only pass updates when its updates.
 | 
				
			||||||
 | 
							// This is probably more relevant if we try to implement
 | 
				
			||||||
 | 
							// "MagicDNS"
 | 
				
			||||||
 | 
							DNSConfig:    h.cfg.DNSConfig,
 | 
				
			||||||
		SearchPaths:  []string{},
 | 
							SearchPaths:  []string{},
 | 
				
			||||||
		Domain:       "headscale.net",
 | 
							Domain:       "headscale.net",
 | 
				
			||||||
		PacketFilter: *h.aclRules,
 | 
							PacketFilter: *h.aclRules,
 | 
				
			||||||
@ -465,7 +278,6 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
 | 
				
			|||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// spew.Dump(resp)
 | 
					 | 
				
			||||||
	// declare the incoming size on the first 4 bytes
 | 
						// declare the incoming size on the first 4 bytes
 | 
				
			||||||
	data := make([]byte, 4)
 | 
						data := make([]byte, 4)
 | 
				
			||||||
	binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
 | 
						binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
 | 
				
			||||||
@ -542,7 +354,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
 | 
				
			|||||||
		Str("func", "handleAuthKey").
 | 
							Str("func", "handleAuthKey").
 | 
				
			||||||
		Str("machine", m.Name).
 | 
							Str("machine", m.Name).
 | 
				
			||||||
		Str("ip", ip.String()).
 | 
							Str("ip", ip.String()).
 | 
				
			||||||
		Msgf("Assining %s to %s", ip, m.Name)
 | 
							Msgf("Assigning %s to %s", ip, m.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	m.AuthKeyID = uint(pak.ID)
 | 
						m.AuthKeyID = uint(pak.ID)
 | 
				
			||||||
	m.IPAddress = ip.String()
 | 
						m.IPAddress = ip.String()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										45
									
								
								app.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								app.go
									
									
									
									
									
								
							@ -43,6 +43,8 @@ type Config struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	TLSCertPath string
 | 
						TLSCertPath string
 | 
				
			||||||
	TLSKeyPath  string
 | 
						TLSKeyPath  string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DNSConfig *tailcfg.DNSConfig
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Headscale represents the base app of the service
 | 
					// Headscale represents the base app of the service
 | 
				
			||||||
@ -58,7 +60,10 @@ type Headscale struct {
 | 
				
			|||||||
	aclPolicy *ACLPolicy
 | 
						aclPolicy *ACLPolicy
 | 
				
			||||||
	aclRules  *[]tailcfg.FilterRule
 | 
						aclRules  *[]tailcfg.FilterRule
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clientsPolling sync.Map
 | 
						clientsUpdateChannels     sync.Map
 | 
				
			||||||
 | 
						clientsUpdateChannelMutex sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lastStateChange sync.Map
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewHeadscale returns the Headscale app
 | 
					// NewHeadscale returns the Headscale app
 | 
				
			||||||
@ -165,9 +170,18 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
	r.POST("/machine/:id", h.RegistrationHandler)
 | 
						r.POST("/machine/:id", h.RegistrationHandler)
 | 
				
			||||||
	var err error
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						timeout := 30 * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go h.watchForKVUpdates(5000)
 | 
						go h.watchForKVUpdates(5000)
 | 
				
			||||||
	go h.expireEphemeralNodes(5000)
 | 
						go h.expireEphemeralNodes(5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := &http.Server{
 | 
				
			||||||
 | 
							Addr:         h.cfg.Addr,
 | 
				
			||||||
 | 
							Handler:      r,
 | 
				
			||||||
 | 
							ReadTimeout:  timeout,
 | 
				
			||||||
 | 
							WriteTimeout: timeout,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if h.cfg.TLSLetsEncryptHostname != "" {
 | 
						if h.cfg.TLSLetsEncryptHostname != "" {
 | 
				
			||||||
		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
							if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
				
			||||||
			log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
 | 
								log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
 | 
				
			||||||
@ -179,9 +193,11 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
			Cache:      autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
 | 
								Cache:      autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		s := &http.Server{
 | 
							s := &http.Server{
 | 
				
			||||||
			Addr:      h.cfg.Addr,
 | 
								Addr:         h.cfg.Addr,
 | 
				
			||||||
			TLSConfig: m.TLSConfig(),
 | 
								TLSConfig:    m.TLSConfig(),
 | 
				
			||||||
			Handler:   r,
 | 
								Handler:      r,
 | 
				
			||||||
 | 
								ReadTimeout:  timeout,
 | 
				
			||||||
 | 
								WriteTimeout: timeout,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
 | 
							if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
 | 
				
			||||||
			// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
 | 
								// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
 | 
				
			||||||
@ -206,12 +222,29 @@ func (h *Headscale) Serve() error {
 | 
				
			|||||||
		if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
 | 
							if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
 | 
				
			||||||
			log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
 | 
								log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = r.Run(h.cfg.Addr)
 | 
							err = s.ListenAndServe()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
							if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 | 
				
			||||||
			log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
 | 
								log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
 | 
							err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) setLastStateChangeToNow(namespace string) {
 | 
				
			||||||
 | 
						now := time.Now().UTC()
 | 
				
			||||||
 | 
						h.lastStateChange.Store(namespace, now)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) getLastStateChange(namespace string) time.Time {
 | 
				
			||||||
 | 
						if wrapped, ok := h.lastStateChange.Load(namespace); ok {
 | 
				
			||||||
 | 
							lastChange, _ := wrapped.(time.Time)
 | 
				
			||||||
 | 
							return lastChange
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now().UTC()
 | 
				
			||||||
 | 
						h.lastStateChange.Store(namespace, now)
 | 
				
			||||||
 | 
						return now
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -25,6 +25,7 @@ func init() {
 | 
				
			|||||||
	nodeCmd.AddCommand(listNodesCmd)
 | 
						nodeCmd.AddCommand(listNodesCmd)
 | 
				
			||||||
	nodeCmd.AddCommand(registerNodeCmd)
 | 
						nodeCmd.AddCommand(registerNodeCmd)
 | 
				
			||||||
	nodeCmd.AddCommand(deleteNodeCmd)
 | 
						nodeCmd.AddCommand(deleteNodeCmd)
 | 
				
			||||||
 | 
						nodeCmd.AddCommand(shareMachineCmd)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var nodeCmd = &cobra.Command{
 | 
					var nodeCmd = &cobra.Command{
 | 
				
			||||||
@ -79,9 +80,26 @@ var listNodesCmd = &cobra.Command{
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error initializing: %s", err)
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							namespace, err := h.GetNamespace(n)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error fetching namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		machines, err := h.ListMachinesInNamespace(n)
 | 
							machines, err := h.ListMachinesInNamespace(n)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error fetching machines: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							sharedMachines, err := h.ListSharedMachinesInNamespace(n)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error fetching shared machines: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							allMachines := append(*machines, *sharedMachines...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if strings.HasPrefix(o, "json") {
 | 
							if strings.HasPrefix(o, "json") {
 | 
				
			||||||
			JsonOutput(machines, err, o)
 | 
								JsonOutput(allMachines, err, o)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -89,7 +107,7 @@ var listNodesCmd = &cobra.Command{
 | 
				
			|||||||
			log.Fatalf("Error getting nodes: %s", err)
 | 
								log.Fatalf("Error getting nodes: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		d, err := nodesToPtables(*machines)
 | 
							d, err := nodesToPtables(*namespace, allMachines)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error converting to table: %s", err)
 | 
								log.Fatalf("Error converting to table: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -145,31 +163,94 @@ var deleteNodeCmd = &cobra.Command{
 | 
				
			|||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func nodesToPtables(m []headscale.Machine) (pterm.TableData, error) {
 | 
					var shareMachineCmd = &cobra.Command{
 | 
				
			||||||
	d := pterm.TableData{{"ID", "Name", "NodeKey", "IP address", "Ephemeral", "Last seen", "Online"}}
 | 
						Use:   "share ID namespace",
 | 
				
			||||||
 | 
						Short: "Shares a node from the current namespace to the specified one",
 | 
				
			||||||
 | 
						Args: func(cmd *cobra.Command, args []string) error {
 | 
				
			||||||
 | 
							if len(args) < 2 {
 | 
				
			||||||
 | 
								return fmt.Errorf("missing parameters")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						Run: func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
 | 
							namespace, err := cmd.Flags().GetString("namespace")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error getting namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							output, _ := cmd.Flags().GetString("output")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, m := range m {
 | 
							h, err := getHeadscaleApp()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							_, err = h.GetNamespace(namespace)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error fetching origin namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							destinationNamespace, err := h.GetNamespace(args[1])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error fetching destination namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							id, err := strconv.Atoi(args[0])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error converting ID to integer: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							machine, err := h.GetMachineByID(uint64(id))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error getting node: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = h.AddSharedMachineToNamespace(machine, destinationNamespace)
 | 
				
			||||||
 | 
							if strings.HasPrefix(output, "json") {
 | 
				
			||||||
 | 
								JsonOutput(map[string]string{"Result": "Node shared"}, err, output)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								fmt.Printf("Error sharing node: %s\n", err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							fmt.Println("Node shared!")
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func nodesToPtables(currentNamespace headscale.Namespace, machines []headscale.Machine) (pterm.TableData, error) {
 | 
				
			||||||
 | 
						d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, machine := range machines {
 | 
				
			||||||
		var ephemeral bool
 | 
							var ephemeral bool
 | 
				
			||||||
		if m.AuthKey != nil && m.AuthKey.Ephemeral {
 | 
							if machine.AuthKey != nil && machine.AuthKey.Ephemeral {
 | 
				
			||||||
			ephemeral = true
 | 
								ephemeral = true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		var lastSeen time.Time
 | 
							var lastSeen time.Time
 | 
				
			||||||
		if m.LastSeen != nil {
 | 
							var lastSeenTime string
 | 
				
			||||||
			lastSeen = *m.LastSeen
 | 
							if machine.LastSeen != nil {
 | 
				
			||||||
 | 
								lastSeen = *machine.LastSeen
 | 
				
			||||||
 | 
								lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		nKey, err := wgkey.ParseHex(m.NodeKey)
 | 
							nKey, err := wgkey.ParseHex(machine.NodeKey)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		nodeKey := tailcfg.NodeKey(nKey)
 | 
							nodeKey := tailcfg.NodeKey(nKey)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var online string
 | 
							var online string
 | 
				
			||||||
		if m.LastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
 | 
							if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
 | 
				
			||||||
			online = pterm.LightGreen("true")
 | 
								online = pterm.LightGreen("true")
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			online = pterm.LightRed("false")
 | 
								online = pterm.LightRed("false")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		d = append(d, []string{strconv.FormatUint(m.ID, 10), m.Name, nodeKey.ShortString(), m.IPAddress, strconv.FormatBool(ephemeral), lastSeen.Format("2006-01-02 15:04:05"), online})
 | 
					
 | 
				
			||||||
 | 
							var namespace string
 | 
				
			||||||
 | 
							if currentNamespace.ID == machine.NamespaceID {
 | 
				
			||||||
 | 
								namespace = pterm.LightMagenta(machine.Namespace.Name)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								namespace = pterm.LightYellow(machine.Namespace.Name)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							d = append(d, []string{strconv.FormatUint(machine.ID, 10), machine.Name, nodeKey.ShortString(), namespace, machine.IPAddress, strconv.FormatBool(ephemeral), lastSeenTime, online})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return d, nil
 | 
						return d, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -5,6 +5,7 @@ import (
 | 
				
			|||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/pterm/pterm"
 | 
				
			||||||
	"github.com/spf13/cobra"
 | 
						"github.com/spf13/cobra"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,6 +16,9 @@ func init() {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatalf(err.Error())
 | 
							log.Fatalf(err.Error())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	routesCmd.AddCommand(listRoutesCmd)
 | 
						routesCmd.AddCommand(listRoutesCmd)
 | 
				
			||||||
	routesCmd.AddCommand(enableRouteCmd)
 | 
						routesCmd.AddCommand(enableRouteCmd)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -44,19 +48,25 @@ var listRoutesCmd = &cobra.Command{
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error initializing: %s", err)
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		routes, err := h.GetNodeRoutes(n, args[0])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if strings.HasPrefix(o, "json") {
 | 
					 | 
				
			||||||
			JsonOutput(routes, err, o)
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			fmt.Println(err)
 | 
								fmt.Println(err)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		fmt.Println(routes)
 | 
							if strings.HasPrefix(o, "json") {
 | 
				
			||||||
 | 
								// TODO: Add enable/disabled information to this interface
 | 
				
			||||||
 | 
								JsonOutput(availableRoutes, err, o)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							d := h.RoutesToPtables(n, args[0], *availableRoutes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -64,32 +74,74 @@ var enableRouteCmd = &cobra.Command{
 | 
				
			|||||||
	Use:   "enable node-name route",
 | 
						Use:   "enable node-name route",
 | 
				
			||||||
	Short: "Allows exposing a route declared by this node to the rest of the nodes",
 | 
						Short: "Allows exposing a route declared by this node to the rest of the nodes",
 | 
				
			||||||
	Args: func(cmd *cobra.Command, args []string) error {
 | 
						Args: func(cmd *cobra.Command, args []string) error {
 | 
				
			||||||
		if len(args) < 2 {
 | 
							all, err := cmd.Flags().GetBool("all")
 | 
				
			||||||
			return fmt.Errorf("Missing parameters")
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error getting namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if all {
 | 
				
			||||||
 | 
								if len(args) < 1 {
 | 
				
			||||||
 | 
									return fmt.Errorf("Missing parameters")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								if len(args) < 2 {
 | 
				
			||||||
 | 
									return fmt.Errorf("Missing parameters")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	},
 | 
						},
 | 
				
			||||||
	Run: func(cmd *cobra.Command, args []string) {
 | 
						Run: func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
		n, err := cmd.Flags().GetString("namespace")
 | 
							n, err := cmd.Flags().GetString("namespace")
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error getting namespace: %s", err)
 | 
								log.Fatalf("Error getting namespace: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		o, _ := cmd.Flags().GetString("output")
 | 
							o, _ := cmd.Flags().GetString("output")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							all, err := cmd.Flags().GetBool("all")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("Error getting namespace: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		h, err := getHeadscaleApp()
 | 
							h, err := getHeadscaleApp()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatalf("Error initializing: %s", err)
 | 
								log.Fatalf("Error initializing: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		route, err := h.EnableNodeRoute(n, args[0], args[1])
 | 
					 | 
				
			||||||
		if strings.HasPrefix(o, "json") {
 | 
					 | 
				
			||||||
			JsonOutput(route, err, o)
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err != nil {
 | 
							if all {
 | 
				
			||||||
			fmt.Println(err)
 | 
								availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
 | 
				
			||||||
			return
 | 
								if err != nil {
 | 
				
			||||||
 | 
									fmt.Println(err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for _, availableRoute := range *availableRoutes {
 | 
				
			||||||
 | 
									err = h.EnableNodeRoute(n, args[0], availableRoute.String())
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										fmt.Println(err)
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if strings.HasPrefix(o, "json") {
 | 
				
			||||||
 | 
										JsonOutput(availableRoute, err, o)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										fmt.Printf("Enabled route %s\n", availableRoute)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								err = h.EnableNodeRoute(n, args[0], args[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if strings.HasPrefix(o, "json") {
 | 
				
			||||||
 | 
									JsonOutput(args[1], err, o)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									fmt.Println(err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								fmt.Printf("Enabled route %s\n", args[1])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		fmt.Printf("Enabled route %s\n", route)
 | 
					 | 
				
			||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -39,7 +39,9 @@ func LoadConfig(path string) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("ip_prefix", "100.64.0.0/10")
 | 
						viper.SetDefault("ip_prefix", "100.64.0.0/10")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	viper.SetDefault("log_level", "debug")
 | 
						viper.SetDefault("log_level", "info")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						viper.SetDefault("dns_config", nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := viper.ReadInConfig()
 | 
						err := viper.ReadInConfig()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -70,6 +72,45 @@ func LoadConfig(path string) error {
 | 
				
			|||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetDNSConfig() *tailcfg.DNSConfig {
 | 
				
			||||||
 | 
						if viper.IsSet("dns_config") {
 | 
				
			||||||
 | 
							dnsConfig := &tailcfg.DNSConfig{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if viper.IsSet("dns_config.nameservers") {
 | 
				
			||||||
 | 
								nameserversStr := viper.GetStringSlice("dns_config.nameservers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								nameservers := make([]netaddr.IP, len(nameserversStr))
 | 
				
			||||||
 | 
								resolvers := make([]tailcfg.DNSResolver, len(nameserversStr))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for index, nameserverStr := range nameserversStr {
 | 
				
			||||||
 | 
									nameserver, err := netaddr.ParseIP(nameserverStr)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										log.Error().
 | 
				
			||||||
 | 
											Str("func", "getDNSConfig").
 | 
				
			||||||
 | 
											Err(err).
 | 
				
			||||||
 | 
											Msgf("Could not parse nameserver IP: %s", nameserverStr)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									nameservers[index] = nameserver
 | 
				
			||||||
 | 
									resolvers[index] = tailcfg.DNSResolver{
 | 
				
			||||||
 | 
										Addr: nameserver.String(),
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								dnsConfig.Nameservers = nameservers
 | 
				
			||||||
 | 
								dnsConfig.Resolvers = resolvers
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if viper.IsSet("dns_config.domains") {
 | 
				
			||||||
 | 
								dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return dnsConfig
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func absPath(path string) string {
 | 
					func absPath(path string) string {
 | 
				
			||||||
@ -126,6 +167,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		TLSCertPath: absPath(viper.GetString("tls_cert_path")),
 | 
							TLSCertPath: absPath(viper.GetString("tls_cert_path")),
 | 
				
			||||||
		TLSKeyPath:  absPath(viper.GetString("tls_key_path")),
 | 
							TLSKeyPath:  absPath(viper.GetString("tls_key_path")),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							DNSConfig: GetDNSConfig(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h, err := headscale.NewHeadscale(cfg)
 | 
						h, err := headscale.NewHeadscale(cfg)
 | 
				
			||||||
 | 
				
			|||||||
@ -58,7 +58,7 @@ func (*Suite) TestPostgresConfigLoading(c *check.C) {
 | 
				
			|||||||
	c.Assert(viper.GetString("db_port"), check.Equals, "5432")
 | 
						c.Assert(viper.GetString("db_port"), check.Equals, "5432")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
 | 
						c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (*Suite) TestSqliteConfigLoading(c *check.C) {
 | 
					func (*Suite) TestSqliteConfigLoading(c *check.C) {
 | 
				
			||||||
@ -92,6 +92,37 @@ func (*Suite) TestSqliteConfigLoading(c *check.C) {
 | 
				
			|||||||
	c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
 | 
				
			||||||
	c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
 | 
						c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
 | 
				
			||||||
 | 
						c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*Suite) TestDNSConfigLoading(c *check.C) {
 | 
				
			||||||
 | 
						tmpDir, err := ioutil.TempDir("", "headscale")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer os.RemoveAll(tmpDir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						path, err := os.Getwd()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Symlink the example config file
 | 
				
			||||||
 | 
						err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Load example config, it should load without validation errors
 | 
				
			||||||
 | 
						err = cli.LoadConfig(tmpDir)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dnsConfig := cli.GetDNSConfig()
 | 
				
			||||||
 | 
						fmt.Println(dnsConfig)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
 | 
					func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
 | 
				
			||||||
 | 
				
			|||||||
@ -16,5 +16,10 @@
 | 
				
			|||||||
    "tls_letsencrypt_challenge_type": "HTTP-01",
 | 
					    "tls_letsencrypt_challenge_type": "HTTP-01",
 | 
				
			||||||
    "tls_cert_path": "",
 | 
					    "tls_cert_path": "",
 | 
				
			||||||
    "tls_key_path": "",
 | 
					    "tls_key_path": "",
 | 
				
			||||||
    "acl_policy_path": ""
 | 
					    "acl_policy_path": "",
 | 
				
			||||||
 | 
					    "dns_config": {
 | 
				
			||||||
 | 
					        "nameservers": [
 | 
				
			||||||
 | 
					            "1.1.1.1"
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -12,5 +12,10 @@
 | 
				
			|||||||
    "tls_letsencrypt_challenge_type": "HTTP-01",
 | 
					    "tls_letsencrypt_challenge_type": "HTTP-01",
 | 
				
			||||||
    "tls_cert_path": "",
 | 
					    "tls_cert_path": "",
 | 
				
			||||||
    "tls_key_path": "",
 | 
					    "tls_key_path": "",
 | 
				
			||||||
    "acl_policy_path": ""
 | 
					    "acl_policy_path": "",
 | 
				
			||||||
 | 
					    "dns_config": {
 | 
				
			||||||
 | 
					        "nameservers": [
 | 
				
			||||||
 | 
					            "1.1.1.1"
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										5
									
								
								db.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								db.go
									
									
									
									
									
								
							@ -44,6 +44,11 @@ func (h *Headscale) initDB() error {
 | 
				
			|||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = db.AutoMigrate(&SharedMachine{})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = h.setValue("db_version", dbVersion)
 | 
						err = h.setValue("db_version", dbVersion)
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,7 @@
 | 
				
			|||||||
# This file contains some of the official Tailscale DERP servers, 
 | 
					# This file contains some of the official Tailscale DERP servers, 
 | 
				
			||||||
# shamelessly taken from https://github.com/tailscale/tailscale/blob/main/net/dnsfallback/dns-fallback-servers.json
 | 
					# shamelessly taken from https://github.com/tailscale/tailscale/blob/main/net/dnsfallback/dns-fallback-servers.json
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# If you plan to somehow use headscale, please deploy your own DERP infra
 | 
					# If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/
 | 
				
			||||||
regions: 
 | 
					regions: 
 | 
				
			||||||
  1:
 | 
					  1:
 | 
				
			||||||
    regionid: 1
 | 
					    regionid: 1
 | 
				
			||||||
 | 
				
			|||||||
@ -4,10 +4,13 @@ package headscale
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@ -20,23 +23,48 @@ import (
 | 
				
			|||||||
	"inet.af/netaddr"
 | 
						"inet.af/netaddr"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type IntegrationTestSuite struct {
 | 
					 | 
				
			||||||
	suite.Suite
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestIntegrationTestSuite(t *testing.T) {
 | 
					 | 
				
			||||||
	suite.Run(t, new(IntegrationTestSuite))
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var integrationTmpDir string
 | 
					var integrationTmpDir string
 | 
				
			||||||
var ih Headscale
 | 
					var ih Headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var pool dockertest.Pool
 | 
					var pool dockertest.Pool
 | 
				
			||||||
var network dockertest.Network
 | 
					var network dockertest.Network
 | 
				
			||||||
var headscale dockertest.Resource
 | 
					var headscale dockertest.Resource
 | 
				
			||||||
var tailscaleCount int = 5
 | 
					var tailscaleCount int = 25
 | 
				
			||||||
var tailscales map[string]dockertest.Resource
 | 
					var tailscales map[string]dockertest.Resource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type IntegrationTestSuite struct {
 | 
				
			||||||
 | 
						suite.Suite
 | 
				
			||||||
 | 
						stats *suite.SuiteInformation
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIntegrationTestSuite(t *testing.T) {
 | 
				
			||||||
 | 
						s := new(IntegrationTestSuite)
 | 
				
			||||||
 | 
						suite.Run(t, s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// HandleStats, which allows us to check if we passed and save logs
 | 
				
			||||||
 | 
						// is called after TearDown, so we cannot tear down containers before
 | 
				
			||||||
 | 
						// we have potentially saved the logs.
 | 
				
			||||||
 | 
						for _, tailscale := range tailscales {
 | 
				
			||||||
 | 
							if err := pool.Purge(&tailscale); err != nil {
 | 
				
			||||||
 | 
								log.Printf("Could not purge resource: %s\n", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !s.stats.Passed() {
 | 
				
			||||||
 | 
							err := saveLog(&headscale, "test_output")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Printf("Could not save log: %s\n", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := pool.Purge(&headscale); err != nil {
 | 
				
			||||||
 | 
							log.Printf("Could not purge resource: %s\n", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := network.Close(); err != nil {
 | 
				
			||||||
 | 
							log.Printf("Could not close network: %s\n", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
 | 
					func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
 | 
				
			||||||
	var stdout bytes.Buffer
 | 
						var stdout bytes.Buffer
 | 
				
			||||||
	var stderr bytes.Buffer
 | 
						var stderr bytes.Buffer
 | 
				
			||||||
@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
 | 
				
			|||||||
	return stdout.String(), nil
 | 
						return stdout.String(), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func saveLog(resource *dockertest.Resource, basePath string) error {
 | 
				
			||||||
 | 
						err := os.MkdirAll(basePath, os.ModePerm)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var stdout bytes.Buffer
 | 
				
			||||||
 | 
						var stderr bytes.Buffer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = pool.Client.Logs(
 | 
				
			||||||
 | 
							docker.LogsOptions{
 | 
				
			||||||
 | 
								Context:      context.TODO(),
 | 
				
			||||||
 | 
								Container:    resource.Container.ID,
 | 
				
			||||||
 | 
								OutputStream: &stdout,
 | 
				
			||||||
 | 
								ErrorStream:  &stderr,
 | 
				
			||||||
 | 
								Tail:         "all",
 | 
				
			||||||
 | 
								RawTerminal:  false,
 | 
				
			||||||
 | 
								Stdout:       true,
 | 
				
			||||||
 | 
								Stderr:       true,
 | 
				
			||||||
 | 
								Follow:       false,
 | 
				
			||||||
 | 
								Timestamps:   false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func dockerRestartPolicy(config *docker.HostConfig) {
 | 
					func dockerRestartPolicy(config *docker.HostConfig) {
 | 
				
			||||||
	// set AutoRemove to true so that stopped container goes away by itself
 | 
						// set AutoRemove to true so that stopped container goes away by itself
 | 
				
			||||||
	config.AutoRemove = true
 | 
						config.AutoRemove = true
 | 
				
			||||||
@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
 | 
				
			|||||||
		PortBindings: map[docker.Port][]docker.PortBinding{
 | 
							PortBindings: map[docker.Port][]docker.PortBinding{
 | 
				
			||||||
			"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
 | 
								"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		Env: []string{},
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Println("Creating headscale container")
 | 
						fmt.Println("Creating headscale container")
 | 
				
			||||||
@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
 | 
				
			|||||||
			Name:     hostname,
 | 
								Name:     hostname,
 | 
				
			||||||
			Networks: []*dockertest.Network{&network},
 | 
								Networks: []*dockertest.Network{&network},
 | 
				
			||||||
			Cmd:      []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
 | 
								Cmd:      []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
 | 
				
			||||||
			Env:      []string{},
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
 | 
							if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
 | 
				
			||||||
@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
 | 
				
			|||||||
		fmt.Printf("Created %s container\n", hostname)
 | 
							fmt.Printf("Created %s container\n", hostname)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO: Replace this logic with something that can be detected on Github Actions
 | 
					 | 
				
			||||||
	fmt.Println("Waiting for headscale to be ready")
 | 
						fmt.Println("Waiting for headscale to be ready")
 | 
				
			||||||
	hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
 | 
						hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// The nodes need a bit of time to get their updated maps from headscale
 | 
						// The nodes need a bit of time to get their updated maps from headscale
 | 
				
			||||||
	// TODO: See if we can have a more deterministic wait here.
 | 
						// TODO: See if we can have a more deterministic wait here.
 | 
				
			||||||
	time.Sleep(20 * time.Second)
 | 
						time.Sleep(60 * time.Second)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationTestSuite) TearDownSuite() {
 | 
					func (s *IntegrationTestSuite) TearDownSuite() {
 | 
				
			||||||
	if err := pool.Purge(&headscale); err != nil {
 | 
					}
 | 
				
			||||||
		log.Printf("Could not purge resource: %s\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tailscale := range tailscales {
 | 
					func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
 | 
				
			||||||
		if err := pool.Purge(&tailscale); err != nil {
 | 
						s.stats = stats
 | 
				
			||||||
			log.Printf("Could not purge resource: %s\n", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := network.Close(); err != nil {
 | 
					 | 
				
			||||||
		log.Printf("Could not close network: %s\n", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *IntegrationTestSuite) TestListNodes() {
 | 
					func (s *IntegrationTestSuite) TestListNodes() {
 | 
				
			||||||
@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
 | 
				
			|||||||
			s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
								s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
 | 
				
			||||||
				// We currently cant ping ourselves, so skip that.
 | 
									// We currently cant ping ourselves, so skip that.
 | 
				
			||||||
				if peername != hostname {
 | 
									if peername != hostname {
 | 
				
			||||||
					command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
 | 
										// We are only interested in "direct ping" which means what we
 | 
				
			||||||
 | 
										// might need a couple of more attempts before reaching the node.
 | 
				
			||||||
 | 
										command := []string{
 | 
				
			||||||
 | 
											"tailscale", "ping",
 | 
				
			||||||
 | 
											"--timeout=1s",
 | 
				
			||||||
 | 
											"--c=20",
 | 
				
			||||||
 | 
											"--until-direct=true",
 | 
				
			||||||
 | 
											ip.String(),
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
 | 
										fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
 | 
				
			||||||
					result, err := executeCommand(
 | 
										result, err := executeCommand(
 | 
				
			||||||
 | 
				
			|||||||
@ -7,5 +7,5 @@
 | 
				
			|||||||
  "db_type": "sqlite3",
 | 
					  "db_type": "sqlite3",
 | 
				
			||||||
  "db_path": "/tmp/integration_test_db.sqlite3",
 | 
					  "db_path": "/tmp/integration_test_db.sqlite3",
 | 
				
			||||||
  "acl_policy_path": "",
 | 
					  "acl_policy_path": "",
 | 
				
			||||||
  "log_level": "trace"
 | 
					  "log_level": "debug"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										201
									
								
								machine.go
									
									
									
									
									
								
							
							
						
						
									
										201
									
								
								machine.go
									
									
									
									
									
								
							@ -2,6 +2,7 @@ package headscale
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
@ -31,8 +32,9 @@ type Machine struct {
 | 
				
			|||||||
	AuthKeyID      uint
 | 
						AuthKeyID      uint
 | 
				
			||||||
	AuthKey        *PreAuthKey
 | 
						AuthKey        *PreAuthKey
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	LastSeen *time.Time
 | 
						LastSeen             *time.Time
 | 
				
			||||||
	Expiry   *time.Time
 | 
						LastSuccessfulUpdate *time.Time
 | 
				
			||||||
 | 
						Expiry               *time.Time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	HostInfo      datatypes.JSON
 | 
						HostInfo      datatypes.JSON
 | 
				
			||||||
	Endpoints     datatypes.JSON
 | 
						Endpoints     datatypes.JSON
 | 
				
			||||||
@ -48,7 +50,9 @@ func (m Machine) isAlreadyRegistered() bool {
 | 
				
			|||||||
	return m.Registered
 | 
						return m.Registered
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Machine) toNode() (*tailcfg.Node, error) {
 | 
					// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
 | 
				
			||||||
 | 
					// as per the expected behaviour in the official SaaS
 | 
				
			||||||
 | 
					func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) {
 | 
				
			||||||
	nKey, err := wgkey.ParseHex(m.NodeKey)
 | 
						nKey, err := wgkey.ParseHex(m.NodeKey)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -83,24 +87,26 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
 | 
				
			|||||||
	allowedIPs := []netaddr.IPPrefix{}
 | 
						allowedIPs := []netaddr.IPPrefix{}
 | 
				
			||||||
	allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
 | 
						allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	routesStr := []string{}
 | 
						if includeRoutes {
 | 
				
			||||||
	if len(m.EnabledRoutes) != 0 {
 | 
							routesStr := []string{}
 | 
				
			||||||
		allwIps, err := m.EnabledRoutes.MarshalJSON()
 | 
							if len(m.EnabledRoutes) != 0 {
 | 
				
			||||||
		if err != nil {
 | 
								allwIps, err := m.EnabledRoutes.MarshalJSON()
 | 
				
			||||||
			return nil, err
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								err = json.Unmarshal(allwIps, &routesStr)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = json.Unmarshal(allwIps, &routesStr)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, aip := range routesStr {
 | 
							for _, routeStr := range routesStr {
 | 
				
			||||||
		ip, err := netaddr.ParseIPPrefix(aip)
 | 
								ip, err := netaddr.ParseIPPrefix(routeStr)
 | 
				
			||||||
		if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
			return nil, err
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								allowedIPs = append(allowedIPs, ip)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		allowedIPs = append(allowedIPs, ip)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	endpoints := []string{}
 | 
						endpoints := []string{}
 | 
				
			||||||
@ -134,13 +140,20 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
 | 
				
			|||||||
		derp = "127.3.3.40:0" // Zero means disconnected or unknown.
 | 
							derp = "127.3.3.40:0" // Zero means disconnected or unknown.
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var keyExpiry time.Time
 | 
				
			||||||
 | 
						if m.Expiry != nil {
 | 
				
			||||||
 | 
							keyExpiry = *m.Expiry
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							keyExpiry = time.Time{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	n := tailcfg.Node{
 | 
						n := tailcfg.Node{
 | 
				
			||||||
		ID:         tailcfg.NodeID(m.ID),                               // this is the actual ID
 | 
							ID:         tailcfg.NodeID(m.ID),                               // this is the actual ID
 | 
				
			||||||
		StableID:   tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
 | 
							StableID:   tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
 | 
				
			||||||
		Name:       hostinfo.Hostname,
 | 
							Name:       hostinfo.Hostname,
 | 
				
			||||||
		User:       tailcfg.UserID(m.NamespaceID),
 | 
							User:       tailcfg.UserID(m.NamespaceID),
 | 
				
			||||||
		Key:        tailcfg.NodeKey(nKey),
 | 
							Key:        tailcfg.NodeKey(nKey),
 | 
				
			||||||
		KeyExpiry:  *m.Expiry,
 | 
							KeyExpiry:  keyExpiry,
 | 
				
			||||||
		Machine:    tailcfg.MachineKey(mKey),
 | 
							Machine:    tailcfg.MachineKey(mKey),
 | 
				
			||||||
		DiscoKey:   discoKey,
 | 
							DiscoKey:   discoKey,
 | 
				
			||||||
		Addresses:  addrs,
 | 
							Addresses:  addrs,
 | 
				
			||||||
@ -163,6 +176,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
 | 
				
			|||||||
		Str("func", "getPeers").
 | 
							Str("func", "getPeers").
 | 
				
			||||||
		Str("machine", m.Name).
 | 
							Str("machine", m.Name).
 | 
				
			||||||
		Msg("Finding peers")
 | 
							Msg("Finding peers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	machines := []Machine{}
 | 
						machines := []Machine{}
 | 
				
			||||||
	if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
 | 
						if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
 | 
				
			||||||
		m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
 | 
							m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
 | 
				
			||||||
@ -170,9 +184,23 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
 | 
				
			||||||
 | 
						sharedMachines := []SharedMachine{}
 | 
				
			||||||
 | 
						if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
 | 
				
			||||||
 | 
							m.NamespaceID).Find(&sharedMachines).Error; err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	peers := []*tailcfg.Node{}
 | 
						peers := []*tailcfg.Node{}
 | 
				
			||||||
	for _, mn := range machines {
 | 
						for _, mn := range machines {
 | 
				
			||||||
		peer, err := mn.toNode()
 | 
							peer, err := mn.toNode(true)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							peers = append(peers, peer)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, sharedMachine := range sharedMachines {
 | 
				
			||||||
 | 
							peer, err := sharedMachine.Machine.toNode(false) // shared nodes do not expose their routes
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@ -199,18 +227,27 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
 | 
				
			|||||||
			return &m, nil
 | 
								return &m, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, fmt.Errorf("not found")
 | 
						return nil, fmt.Errorf("machine not found")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetMachineByID finds a Machine by ID and returns the Machine struct
 | 
					// GetMachineByID finds a Machine by ID and returns the Machine struct
 | 
				
			||||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
 | 
					func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
 | 
				
			||||||
	m := Machine{}
 | 
						m := Machine{}
 | 
				
			||||||
	if result := h.db.Find(&Machine{ID: id}).First(&m); result.Error != nil {
 | 
						if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
 | 
				
			||||||
		return nil, result.Error
 | 
							return nil, result.Error
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &m, nil
 | 
						return &m, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UpdateMachine takes a Machine struct pointer (typically already loaded from database
 | 
				
			||||||
 | 
					// and updates it with the latest data from the database.
 | 
				
			||||||
 | 
					func (h *Headscale) UpdateMachine(m *Machine) error {
 | 
				
			||||||
 | 
						if result := h.db.Find(m).First(&m); result.Error != nil {
 | 
				
			||||||
 | 
							return result.Error
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DeleteMachine softs deletes a Machine from the database
 | 
					// DeleteMachine softs deletes a Machine from the database
 | 
				
			||||||
func (h *Headscale) DeleteMachine(m *Machine) error {
 | 
					func (h *Headscale) DeleteMachine(m *Machine) error {
 | 
				
			||||||
	m.Registered = false
 | 
						m.Registered = false
 | 
				
			||||||
@ -249,23 +286,119 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) notifyChangesToPeers(m *Machine) {
 | 
					func (h *Headscale) notifyChangesToPeers(m *Machine) {
 | 
				
			||||||
	peers, _ := h.getPeers(*m)
 | 
						peers, err := h.getPeers(*m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Str("func", "notifyChangesToPeers").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msgf("Error getting peers: %s", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	for _, p := range *peers {
 | 
						for _, p := range *peers {
 | 
				
			||||||
		pUp, ok := h.clientsPolling.Load(uint64(p.ID))
 | 
							log.Info().
 | 
				
			||||||
		if ok {
 | 
								Str("func", "notifyChangesToPeers").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Str("peer", p.Name).
 | 
				
			||||||
 | 
								Str("address", p.Addresses[0].String()).
 | 
				
			||||||
 | 
								Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
 | 
				
			||||||
 | 
							err := h.sendRequestOnUpdateChannel(p)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
			log.Info().
 | 
								log.Info().
 | 
				
			||||||
				Str("func", "notifyChangesToPeers").
 | 
									Str("func", "notifyChangesToPeers").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
				Str("peer", m.Name).
 | 
									Str("peer", p.Name).
 | 
				
			||||||
				Str("address", p.Addresses[0].String()).
 | 
					 | 
				
			||||||
				Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
 | 
					 | 
				
			||||||
			pUp.(chan []byte) <- []byte{}
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			log.Info().
 | 
					 | 
				
			||||||
				Str("func", "notifyChangesToPeers").
 | 
					 | 
				
			||||||
				Str("machine", m.Name).
 | 
					 | 
				
			||||||
				Str("peer", m.Name).
 | 
					 | 
				
			||||||
				Msgf("Peer %s does not appear to be polling", p.Name)
 | 
									Msgf("Peer %s does not appear to be polling", p.Name)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							log.Trace().
 | 
				
			||||||
 | 
								Str("func", "notifyChangesToPeers").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Str("peer", p.Name).
 | 
				
			||||||
 | 
								Str("address", p.Addresses[0].String()).
 | 
				
			||||||
 | 
								Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
 | 
				
			||||||
 | 
						var updateChan chan struct{}
 | 
				
			||||||
 | 
						if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
 | 
				
			||||||
 | 
							if unwrapped, ok := storedChan.(chan struct{}); ok {
 | 
				
			||||||
 | 
								updateChan = unwrapped
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								log.Error().
 | 
				
			||||||
 | 
									Str("handler", "openUpdateChannel").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msg("Failed to convert update channel to struct{}")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							log.Debug().
 | 
				
			||||||
 | 
								Str("handler", "openUpdateChannel").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msg("Update channel not found, creating")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							updateChan = make(chan struct{})
 | 
				
			||||||
 | 
							h.clientsUpdateChannels.Store(m.ID, updateChan)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return updateChan
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) closeUpdateChannel(m *Machine) {
 | 
				
			||||||
 | 
						h.clientsUpdateChannelMutex.Lock()
 | 
				
			||||||
 | 
						defer h.clientsUpdateChannelMutex.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
 | 
				
			||||||
 | 
							if unwrapped, ok := storedChan.(chan struct{}); ok {
 | 
				
			||||||
 | 
								close(unwrapped)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.clientsUpdateChannels.Delete(m.ID)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
 | 
				
			||||||
 | 
						h.clientsUpdateChannelMutex.Lock()
 | 
				
			||||||
 | 
						defer h.clientsUpdateChannelMutex.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							log.Info().
 | 
				
			||||||
 | 
								Str("func", "requestUpdate").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msgf("Notifying peer %s", m.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if update, ok := pUp.(chan struct{}); ok {
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("func", "requestUpdate").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msgf("Update channel is %#v", update)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								update <- struct{}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("func", "requestUpdate").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msgf("Notified machine %s", m.Name)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							log.Info().
 | 
				
			||||||
 | 
								Str("func", "requestUpdate").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msgf("Machine %s does not appear to be polling", m.Name)
 | 
				
			||||||
 | 
							return errors.New("machine does not seem to be polling")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) isOutdated(m *Machine) bool {
 | 
				
			||||||
 | 
						err := h.UpdateMachine(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lastChange := h.getLastStateChange(m.Namespace.Name)
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("func", "keepAlive").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Time("last_successful_update", *m.LastSuccessfulUpdate).
 | 
				
			||||||
 | 
							Time("last_state_change", lastChange).
 | 
				
			||||||
 | 
							Msgf("Checking if %s is missing updates", m.Name)
 | 
				
			||||||
 | 
						return m.LastSuccessfulUpdate.Before(lastChange)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -91,12 +91,34 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	machines := []Machine{}
 | 
						machines := []Machine{}
 | 
				
			||||||
	if err := h.db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
 | 
						if err := h.db.Preload("AuthKey").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &machines, nil
 | 
						return &machines, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
 | 
				
			||||||
 | 
					func (h *Headscale) ListSharedMachinesInNamespace(name string) (*[]Machine, error) {
 | 
				
			||||||
 | 
						namespace, err := h.GetNamespace(name)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sharedMachines := []SharedMachine{}
 | 
				
			||||||
 | 
						if err := h.db.Preload("Namespace").Where(&SharedMachine{NamespaceID: namespace.ID}).Find(&sharedMachines).Error; err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						machines := []Machine{}
 | 
				
			||||||
 | 
						for _, sharedMachine := range sharedMachines {
 | 
				
			||||||
 | 
							machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							machines = append(machines, *machine)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &machines, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetMachineNamespace assigns a Machine to a namespace
 | 
					// SetMachineNamespace assigns a Machine to a namespace
 | 
				
			||||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
 | 
					func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
 | 
				
			||||||
	n, err := h.GetNamespace(namespaceName)
 | 
						n, err := h.GetNamespace(namespaceName)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										404
									
								
								poll.go
									
									
									
									
									
								
							
							
						
						
									
										404
									
								
								poll.go
									
									
									
									
									
								
							@ -1,38 +1,225 @@
 | 
				
			|||||||
package headscale
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
 | 
						"gorm.io/datatypes"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"tailscale.com/tailcfg"
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
	"tailscale.com/types/wgkey"
 | 
						"tailscale.com/types/wgkey"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PollNetMapHandler takes care of /machine/:id/map
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This is the busiest endpoint, as it keeps the HTTP long poll that updates
 | 
				
			||||||
 | 
					// the clients when something in the network changes.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// The clients POST stuff like HostInfo and their Endpoints here, but
 | 
				
			||||||
 | 
					// only after their first request (marked with the ReadOnly field).
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// At this moment the updates are sent in a quite horrendous way, but they kinda work.
 | 
				
			||||||
 | 
					func (h *Headscale) PollNetMapHandler(c *gin.Context) {
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Msg("PollNetMapHandler called")
 | 
				
			||||||
 | 
						body, _ := io.ReadAll(c.Request.Body)
 | 
				
			||||||
 | 
						mKeyStr := c.Param("id")
 | 
				
			||||||
 | 
						mKey, err := wgkey.ParseHex(mKeyStr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Err(err).
 | 
				
			||||||
 | 
								Msg("Cannot parse client key")
 | 
				
			||||||
 | 
							c.String(http.StatusBadRequest, "")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						req := tailcfg.MapRequest{}
 | 
				
			||||||
 | 
						err = decode(body, &req, &mKey, h.privateKey)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Err(err).
 | 
				
			||||||
 | 
								Msg("Cannot decode message")
 | 
				
			||||||
 | 
							c.String(http.StatusBadRequest, "")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var m Machine
 | 
				
			||||||
 | 
						if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
 | 
				
			||||||
 | 
							log.Warn().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
 | 
				
			||||||
 | 
							c.String(http.StatusUnauthorized, "")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Found machine in database")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						hostinfo, _ := json.Marshal(req.Hostinfo)
 | 
				
			||||||
 | 
						m.Name = req.Hostinfo.Hostname
 | 
				
			||||||
 | 
						m.HostInfo = datatypes.JSON(hostinfo)
 | 
				
			||||||
 | 
						m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
 | 
				
			||||||
 | 
						now := time.Now().UTC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// From Tailscale client:
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// ReadOnly is whether the client just wants to fetch the MapResponse,
 | 
				
			||||||
 | 
						// without updating their Endpoints. The Endpoints field will be ignored and
 | 
				
			||||||
 | 
						// LastSeen will not be updated and peers will not be notified of changes.
 | 
				
			||||||
 | 
						//
 | 
				
			||||||
 | 
						// The intended use is for clients to discover the DERP map at start-up
 | 
				
			||||||
 | 
						// before their first real endpoint update.
 | 
				
			||||||
 | 
						if !req.ReadOnly {
 | 
				
			||||||
 | 
							endpoints, _ := json.Marshal(req.Endpoints)
 | 
				
			||||||
 | 
							m.Endpoints = datatypes.JSON(endpoints)
 | 
				
			||||||
 | 
							m.LastSeen = &now
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data, err := h.getMapResponse(mKey, req, m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Str("id", c.Param("id")).
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Err(err).
 | 
				
			||||||
 | 
								Msg("Failed to get Map response")
 | 
				
			||||||
 | 
							c.String(http.StatusInternalServerError, ":(")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// We update our peers if the client is not sending ReadOnly in the MapRequest
 | 
				
			||||||
 | 
						// so we don't distribute its initial request (it comes with
 | 
				
			||||||
 | 
						// empty endpoints to peers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
 | 
				
			||||||
 | 
						log.Debug().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Bool("readOnly", req.ReadOnly).
 | 
				
			||||||
 | 
							Bool("omitPeers", req.OmitPeers).
 | 
				
			||||||
 | 
							Bool("stream", req.Stream).
 | 
				
			||||||
 | 
							Msg("Client map request processed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.ReadOnly {
 | 
				
			||||||
 | 
							log.Info().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msg("Client is starting up. Probably interested in a DERP map")
 | 
				
			||||||
 | 
							c.Data(200, "application/json; charset=utf-8", *data)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// There has been an update to _any_ of the nodes that the other nodes would
 | 
				
			||||||
 | 
						// need to know about
 | 
				
			||||||
 | 
						h.setLastStateChangeToNow(m.Namespace.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The request is not ReadOnly, so we need to set up channels for updating
 | 
				
			||||||
 | 
						// peers via longpoll
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Only create update channel if it has not been created
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Loading or creating update channel")
 | 
				
			||||||
 | 
						updateChan := h.getOrOpenUpdateChannel(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pollDataChan := make(chan []byte)
 | 
				
			||||||
 | 
						// defer close(pollData)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						keepAliveChan := make(chan []byte)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cancelKeepAlive := make(chan struct{})
 | 
				
			||||||
 | 
						defer close(cancelKeepAlive)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.OmitPeers && !req.Stream {
 | 
				
			||||||
 | 
							log.Info().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msg("Client sent endpoint update and is ok with a response without peer list")
 | 
				
			||||||
 | 
							c.Data(200, "application/json; charset=utf-8", *data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// It sounds like we should update the nodes when we have received a endpoint update
 | 
				
			||||||
 | 
							// even tho the comments in the tailscale code dont explicitly say so.
 | 
				
			||||||
 | 
							go h.notifyChangesToPeers(&m)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						} else if req.OmitPeers && req.Stream {
 | 
				
			||||||
 | 
							log.Warn().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msg("Ignoring request, don't know how to handle it")
 | 
				
			||||||
 | 
							c.String(http.StatusBadRequest, "")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Info().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Client is ready to access the tailnet")
 | 
				
			||||||
 | 
						log.Info().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Sending initial map")
 | 
				
			||||||
 | 
						go func() { pollDataChan <- *data }()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Info().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Notifying peers")
 | 
				
			||||||
 | 
						go h.notifyChangesToPeers(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
 | 
				
			||||||
 | 
						log.Trace().
 | 
				
			||||||
 | 
							Str("handler", "PollNetMap").
 | 
				
			||||||
 | 
							Str("id", c.Param("id")).
 | 
				
			||||||
 | 
							Str("machine", m.Name).
 | 
				
			||||||
 | 
							Msg("Finished stream, closing PollNetMap session")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PollNetMapStream takes care of /machine/:id/map
 | 
				
			||||||
 | 
					// stream logic, ensuring we communicate updates and data
 | 
				
			||||||
 | 
					// to the connected clients.
 | 
				
			||||||
func (h *Headscale) PollNetMapStream(
 | 
					func (h *Headscale) PollNetMapStream(
 | 
				
			||||||
	c *gin.Context,
 | 
						c *gin.Context,
 | 
				
			||||||
	m Machine,
 | 
						m Machine,
 | 
				
			||||||
	req tailcfg.MapRequest,
 | 
						req tailcfg.MapRequest,
 | 
				
			||||||
	mKey wgkey.Key,
 | 
						mKey wgkey.Key,
 | 
				
			||||||
	pollData chan []byte,
 | 
						pollDataChan chan []byte,
 | 
				
			||||||
	update chan []byte,
 | 
						keepAliveChan chan []byte,
 | 
				
			||||||
	cancelKeepAlive chan []byte,
 | 
						updateChan <-chan struct{},
 | 
				
			||||||
 | 
						cancelKeepAlive chan struct{},
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
 | 
						go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
 | 
				
			||||||
	go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	c.Stream(func(w io.Writer) bool {
 | 
						c.Stream(func(w io.Writer) bool {
 | 
				
			||||||
		log.Trace().
 | 
							log.Trace().
 | 
				
			||||||
			Str("handler", "PollNetMapStream").
 | 
								Str("handler", "PollNetMapStream").
 | 
				
			||||||
			Str("machine", m.Name).
 | 
								Str("machine", m.Name).
 | 
				
			||||||
			Msg("Waiting for data to stream...")
 | 
								Msg("Waiting for data to stream...")
 | 
				
			||||||
		select {
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		case data := <-pollData:
 | 
							log.Trace().
 | 
				
			||||||
 | 
								Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
								Str("machine", m.Name).
 | 
				
			||||||
 | 
								Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case data := <-pollDataChan:
 | 
				
			||||||
			log.Trace().
 | 
								log.Trace().
 | 
				
			||||||
				Str("handler", "PollNetMapStream").
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "pollData").
 | 
				
			||||||
				Int("bytes", len(data)).
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
				Msg("Sending data received via pollData channel")
 | 
									Msg("Sending data received via pollData channel")
 | 
				
			||||||
			_, err := w.Write(data)
 | 
								_, err := w.Write(data)
 | 
				
			||||||
@ -40,44 +227,148 @@ func (h *Headscale) PollNetMapStream(
 | 
				
			|||||||
				log.Error().
 | 
									log.Error().
 | 
				
			||||||
					Str("handler", "PollNetMapStream").
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
					Str("machine", m.Name).
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Str("channel", "pollData").
 | 
				
			||||||
					Err(err).
 | 
										Err(err).
 | 
				
			||||||
					Msg("Cannot write data")
 | 
										Msg("Cannot write data")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			log.Trace().
 | 
								log.Trace().
 | 
				
			||||||
				Str("handler", "PollNetMapStream").
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "pollData").
 | 
				
			||||||
				Int("bytes", len(data)).
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
				Msg("Data from pollData channel written successfully")
 | 
									Msg("Data from pollData channel written successfully")
 | 
				
			||||||
 | 
									// TODO: Abstract away all the database calls, this can cause race conditions
 | 
				
			||||||
 | 
									// when an outdated machine object is kept alive, e.g. db is update from
 | 
				
			||||||
 | 
									// command line, but then overwritten.
 | 
				
			||||||
 | 
								err = h.UpdateMachine(&m)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Str("channel", "pollData").
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Cannot update machine from database")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								now := time.Now().UTC()
 | 
				
			||||||
 | 
								m.LastSeen = &now
 | 
				
			||||||
 | 
								m.LastSuccessfulUpdate = &now
 | 
				
			||||||
 | 
								h.db.Save(&m)
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "pollData").
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Machine updated successfully after sending pollData")
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case data := <-keepAliveChan:
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "keepAlive").
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Sending keep alive message")
 | 
				
			||||||
 | 
								_, err := w.Write(data)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Str("channel", "keepAlive").
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Cannot write keep alive message")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								log.Trace().
 | 
				
			||||||
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "keepAlive").
 | 
				
			||||||
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
 | 
									Msg("Keep alive sent successfully")
 | 
				
			||||||
 | 
									// TODO: Abstract away all the database calls, this can cause race conditions
 | 
				
			||||||
 | 
									// when an outdated machine object is kept alive, e.g. db is update from
 | 
				
			||||||
 | 
									// command line, but then overwritten.
 | 
				
			||||||
 | 
								err = h.UpdateMachine(&m)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Str("channel", "keepAlive").
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Cannot update machine from database")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			now := time.Now().UTC()
 | 
								now := time.Now().UTC()
 | 
				
			||||||
			m.LastSeen = &now
 | 
								m.LastSeen = &now
 | 
				
			||||||
			h.db.Save(&m)
 | 
								h.db.Save(&m)
 | 
				
			||||||
			log.Trace().
 | 
								log.Trace().
 | 
				
			||||||
				Str("handler", "PollNetMapStream").
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "keepAlive").
 | 
				
			||||||
				Int("bytes", len(data)).
 | 
									Int("bytes", len(data)).
 | 
				
			||||||
				Msg("Machine updated successfully after sending pollData")
 | 
									Msg("Machine updated successfully after sending keep alive")
 | 
				
			||||||
			return true
 | 
								return true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		case <-update:
 | 
							case <-updateChan:
 | 
				
			||||||
			log.Debug().
 | 
								log.Trace().
 | 
				
			||||||
				Str("handler", "PollNetMapStream").
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Str("channel", "update").
 | 
				
			||||||
				Msg("Received a request for update")
 | 
									Msg("Received a request for update")
 | 
				
			||||||
			data, err := h.getMapResponse(mKey, req, m)
 | 
								if h.isOutdated(&m) {
 | 
				
			||||||
			if err != nil {
 | 
									log.Debug().
 | 
				
			||||||
				log.Error().
 | 
					 | 
				
			||||||
					Str("handler", "PollNetMapStream").
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
					Str("machine", m.Name).
 | 
										Str("machine", m.Name).
 | 
				
			||||||
					Err(err).
 | 
										Time("last_successful_update", *m.LastSuccessfulUpdate).
 | 
				
			||||||
					Msg("Could not get the map update")
 | 
										Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
 | 
				
			||||||
			}
 | 
										Msgf("There has been updates since the last successful update to %s", m.Name)
 | 
				
			||||||
			_, err = w.Write(*data)
 | 
									data, err := h.getMapResponse(mKey, req, m)
 | 
				
			||||||
			if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
				log.Error().
 | 
										log.Error().
 | 
				
			||||||
 | 
											Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
											Str("machine", m.Name).
 | 
				
			||||||
 | 
											Str("channel", "update").
 | 
				
			||||||
 | 
											Err(err).
 | 
				
			||||||
 | 
											Msg("Could not get the map update")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									_, err = w.Write(*data)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										log.Error().
 | 
				
			||||||
 | 
											Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
											Str("machine", m.Name).
 | 
				
			||||||
 | 
											Str("channel", "update").
 | 
				
			||||||
 | 
											Err(err).
 | 
				
			||||||
 | 
											Msg("Could not write the map response")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									log.Trace().
 | 
				
			||||||
					Str("handler", "PollNetMapStream").
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
					Str("machine", m.Name).
 | 
										Str("machine", m.Name).
 | 
				
			||||||
					Err(err).
 | 
										Str("channel", "update").
 | 
				
			||||||
					Msg("Could not write the map response")
 | 
										Msg("Updated Map has been sent")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										// Keep track of the last successful update,
 | 
				
			||||||
 | 
										// we sometimes end in a state were the update
 | 
				
			||||||
 | 
										// is not picked up by a client and we use this
 | 
				
			||||||
 | 
										// to determine if we should "force" an update.
 | 
				
			||||||
 | 
										// TODO: Abstract away all the database calls, this can cause race conditions
 | 
				
			||||||
 | 
										// when an outdated machine object is kept alive, e.g. db is update from
 | 
				
			||||||
 | 
										// command line, but then overwritten.
 | 
				
			||||||
 | 
									err = h.UpdateMachine(&m)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										log.Error().
 | 
				
			||||||
 | 
											Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
											Str("machine", m.Name).
 | 
				
			||||||
 | 
											Str("channel", "update").
 | 
				
			||||||
 | 
											Err(err).
 | 
				
			||||||
 | 
											Msg("Cannot update machine from database")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									now := time.Now().UTC()
 | 
				
			||||||
 | 
									m.LastSuccessfulUpdate = &now
 | 
				
			||||||
 | 
									h.db.Save(&m)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									log.Trace().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Time("last_successful_update", *m.LastSuccessfulUpdate).
 | 
				
			||||||
 | 
										Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
 | 
				
			||||||
 | 
										Msgf("%s is up to date", m.Name)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return true
 | 
								return true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
 | 
				
			|||||||
				Str("handler", "PollNetMapStream").
 | 
									Str("handler", "PollNetMapStream").
 | 
				
			||||||
				Str("machine", m.Name).
 | 
									Str("machine", m.Name).
 | 
				
			||||||
				Msg("The client has closed the connection")
 | 
									Msg("The client has closed the connection")
 | 
				
			||||||
 | 
									// TODO: Abstract away all the database calls, this can cause race conditions
 | 
				
			||||||
 | 
									// when an outdated machine object is kept alive, e.g. db is update from
 | 
				
			||||||
 | 
									// command line, but then overwritten.
 | 
				
			||||||
 | 
								err := h.UpdateMachine(&m)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("handler", "PollNetMapStream").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Str("channel", "Done").
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Cannot update machine from database")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			now := time.Now().UTC()
 | 
								now := time.Now().UTC()
 | 
				
			||||||
			m.LastSeen = &now
 | 
								m.LastSeen = &now
 | 
				
			||||||
			h.db.Save(&m)
 | 
								h.db.Save(&m)
 | 
				
			||||||
			cancelKeepAlive <- []byte{}
 | 
					
 | 
				
			||||||
			h.clientsPolling.Delete(m.ID)
 | 
								cancelKeepAlive <- struct{}{}
 | 
				
			||||||
			close(update)
 | 
					
 | 
				
			||||||
 | 
								h.closeUpdateChannel(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								close(pollDataChan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								close(keepAliveChan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return false
 | 
								return false
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Headscale) scheduledPollWorker(
 | 
				
			||||||
 | 
						cancelChan <-chan struct{},
 | 
				
			||||||
 | 
						keepAliveChan chan<- []byte,
 | 
				
			||||||
 | 
						mKey wgkey.Key,
 | 
				
			||||||
 | 
						req tailcfg.MapRequest,
 | 
				
			||||||
 | 
						m Machine,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
						keepAliveTicker := time.NewTicker(60 * time.Second)
 | 
				
			||||||
 | 
						updateCheckerTicker := time.NewTicker(30 * time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-cancelChan:
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case <-keepAliveTicker.C:
 | 
				
			||||||
 | 
								data, err := h.getMapKeepAliveResponse(mKey, req, m)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("func", "keepAlive").
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msg("Error generating the keep alive msg")
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								log.Debug().
 | 
				
			||||||
 | 
									Str("func", "keepAlive").
 | 
				
			||||||
 | 
									Str("machine", m.Name).
 | 
				
			||||||
 | 
									Msg("Sending keepalive")
 | 
				
			||||||
 | 
								keepAliveChan <- *data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							case <-updateCheckerTicker.C:
 | 
				
			||||||
 | 
								// Send an update request regardless of outdated or not, if data is sent
 | 
				
			||||||
 | 
								// to the node is determined in the updateChan consumer block
 | 
				
			||||||
 | 
								n, _ := m.toNode(true)
 | 
				
			||||||
 | 
								err := h.sendRequestOnUpdateChannel(n)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().
 | 
				
			||||||
 | 
										Str("func", "keepAlive").
 | 
				
			||||||
 | 
										Str("machine", m.Name).
 | 
				
			||||||
 | 
										Err(err).
 | 
				
			||||||
 | 
										Msgf("Failed to send update request to %s", m.Name)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										131
									
								
								routes.go
									
									
									
									
									
								
							
							
						
						
									
										131
									
								
								routes.go
									
									
									
									
									
								
							@ -2,55 +2,142 @@ package headscale
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/pterm/pterm"
 | 
				
			||||||
	"gorm.io/datatypes"
 | 
						"gorm.io/datatypes"
 | 
				
			||||||
	"inet.af/netaddr"
 | 
						"inet.af/netaddr"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
 | 
					// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
 | 
				
			||||||
// namespace and node name)
 | 
					// namespace and node name)
 | 
				
			||||||
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
 | 
					func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
 | 
				
			||||||
	m, err := h.GetMachine(namespace, nodeName)
 | 
						m, err := h.GetMachine(namespace, nodeName)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hi, err := m.GetHostInfo()
 | 
						hostInfo, err := m.GetHostInfo()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &hi.RoutableIPs, nil
 | 
						return &hostInfo.RoutableIPs, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
 | 
				
			||||||
 | 
					// namespace and node name)
 | 
				
			||||||
 | 
					func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
 | 
				
			||||||
 | 
						m, err := h.GetMachine(namespace, nodeName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data, err := m.EnabledRoutes.MarshalJSON()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						routesStr := []string{}
 | 
				
			||||||
 | 
						err = json.Unmarshal(data, &routesStr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						routes := make([]netaddr.IPPrefix, len(routesStr))
 | 
				
			||||||
 | 
						for index, routeStr := range routesStr {
 | 
				
			||||||
 | 
							route, err := netaddr.ParseIPPrefix(routeStr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							routes[index] = route
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return routes, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsNodeRouteEnabled checks if a certain route has been enabled
 | 
				
			||||||
 | 
					func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
 | 
				
			||||||
 | 
						route, err := netaddr.ParseIPPrefix(routeStr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, enabledRoute := range enabledRoutes {
 | 
				
			||||||
 | 
							if route == enabledRoute {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
 | 
					// EnableNodeRoute enables a subnet route advertised by a node (identified by
 | 
				
			||||||
// namespace and node name)
 | 
					// namespace and node name)
 | 
				
			||||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
 | 
					func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
 | 
				
			||||||
	m, err := h.GetMachine(namespace, nodeName)
 | 
						m, err := h.GetMachine(namespace, nodeName)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return err
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	hi, err := m.GetHostInfo()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	route, err := netaddr.ParseIPPrefix(routeStr)
 | 
						route, err := netaddr.ParseIPPrefix(routeStr)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, rIP := range hi.RoutableIPs {
 | 
						availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
 | 
				
			||||||
		if rIP == route {
 | 
						if err != nil {
 | 
				
			||||||
			routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
 | 
							return err
 | 
				
			||||||
			m.EnabledRoutes = datatypes.JSON(routes)
 | 
						}
 | 
				
			||||||
			h.db.Save(&m)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			err = h.RequestMapUpdates(m.NamespaceID)
 | 
						enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
 | 
				
			||||||
			if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
				return nil, err
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						available := false
 | 
				
			||||||
 | 
						for _, availableRoute := range *availableRoutes {
 | 
				
			||||||
 | 
							// If the route is available, and not yet enabled, add it to the new routing table
 | 
				
			||||||
 | 
							if route == availableRoute {
 | 
				
			||||||
 | 
								available = true
 | 
				
			||||||
 | 
								if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
 | 
				
			||||||
 | 
									enabledRoutes = append(enabledRoutes, route)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return &rIP, nil
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil, errors.New("could not find routable range")
 | 
					
 | 
				
			||||||
 | 
						if !available {
 | 
				
			||||||
 | 
							return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						routes, err := json.Marshal(enabledRoutes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.EnabledRoutes = datatypes.JSON(routes)
 | 
				
			||||||
 | 
						h.db.Save(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.RequestMapUpdates(m.NamespaceID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RoutesToPtables converts the list of routes to a nice table
 | 
				
			||||||
 | 
					func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
 | 
				
			||||||
 | 
						d := pterm.TableData{{"Route", "Enabled"}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, route := range availableRoutes {
 | 
				
			||||||
 | 
							enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return d
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
 | 
				
			|||||||
	pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
 | 
						pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = h.GetMachine("test", "testmachine")
 | 
						_, err = h.GetMachine("test", "test_get_route_machine")
 | 
				
			||||||
	c.Assert(err, check.NotNil)
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
 | 
						route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
 | 
				
			||||||
@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
 | 
				
			|||||||
		MachineKey:     "foo",
 | 
							MachineKey:     "foo",
 | 
				
			||||||
		NodeKey:        "bar",
 | 
							NodeKey:        "bar",
 | 
				
			||||||
		DiscoKey:       "faa",
 | 
							DiscoKey:       "faa",
 | 
				
			||||||
		Name:           "testmachine",
 | 
							Name:           "test_get_route_machine",
 | 
				
			||||||
		NamespaceID:    n.ID,
 | 
							NamespaceID:    n.ID,
 | 
				
			||||||
		Registered:     true,
 | 
							Registered:     true,
 | 
				
			||||||
		RegisterMethod: "authKey",
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	h.db.Save(&m)
 | 
						h.db.Save(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	r, err := h.GetNodeRoutes("test", "testmachine")
 | 
						r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
	c.Assert(len(*r), check.Equals, 1)
 | 
						c.Assert(len(*r), check.Equals, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
 | 
						err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
 | 
				
			||||||
	c.Assert(err, check.NotNil)
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
 | 
						err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestGetEnableRoutes(c *check.C) {
 | 
				
			||||||
 | 
						n, err := h.CreateNamespace("test")
 | 
				
			||||||
	c.Assert(err, check.IsNil)
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						route, err := netaddr.ParseIPPrefix(
 | 
				
			||||||
 | 
							"10.0.0.0/24",
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						route2, err := netaddr.ParseIPPrefix(
 | 
				
			||||||
 | 
							"150.0.10.0/25",
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						hi := tailcfg.Hostinfo{
 | 
				
			||||||
 | 
							RoutableIPs: []netaddr.IPPrefix{route, route2},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hostinfo, err := json.Marshal(hi)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "foo",
 | 
				
			||||||
 | 
							NodeKey:        "bar",
 | 
				
			||||||
 | 
							DiscoKey:       "faa",
 | 
				
			||||||
 | 
							Name:           "test_enable_route_machine",
 | 
				
			||||||
 | 
							NamespaceID:    n.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak.ID),
 | 
				
			||||||
 | 
							HostInfo:       datatypes.JSON(hostinfo),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*availableRoutes), check.Equals, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(enabledRoutes), check.Equals, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(enabledRoutes1), check.Equals, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Adding it twice will just let it pass through
 | 
				
			||||||
 | 
						err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(enabledRoutes2), check.Equals, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(enabledRoutes3), check.Equals, 2)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,4 +1,4 @@
 | 
				
			|||||||
#!/bin/bash
 | 
					#!/usr/bin/env bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
set -e -o pipefail
 | 
					set -e -o pipefail
 | 
				
			||||||
commit="$1"
 | 
					commit="$1"
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										37
									
								
								sharing.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								sharing.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
				
			|||||||
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "gorm.io/gorm"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const errorSameNamespace = Error("Destination namespace same as origin")
 | 
				
			||||||
 | 
					const errorMachineAlreadyShared = Error("Node already shared to this namespace")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SharedMachine is a join table to support sharing nodes between namespaces
 | 
				
			||||||
 | 
					type SharedMachine struct {
 | 
				
			||||||
 | 
						gorm.Model
 | 
				
			||||||
 | 
						MachineID   uint64
 | 
				
			||||||
 | 
						Machine     Machine
 | 
				
			||||||
 | 
						NamespaceID uint
 | 
				
			||||||
 | 
						Namespace   Namespace
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
 | 
				
			||||||
 | 
					func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
 | 
				
			||||||
 | 
						if m.NamespaceID == ns.ID {
 | 
				
			||||||
 | 
							return errorSameNamespace
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sharedMachine := SharedMachine{}
 | 
				
			||||||
 | 
						if err := h.db.Where("machine_id = ? AND namespace_id", m.ID, ns.ID).First(&sharedMachine).Error; err == nil {
 | 
				
			||||||
 | 
							return errorMachineAlreadyShared
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sharedMachine = SharedMachine{
 | 
				
			||||||
 | 
							MachineID:   m.ID,
 | 
				
			||||||
 | 
							Machine:     *m,
 | 
				
			||||||
 | 
							NamespaceID: ns.ID,
 | 
				
			||||||
 | 
							Namespace:   *ns,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&sharedMachine)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										359
									
								
								sharing_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										359
									
								
								sharing_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,359 @@
 | 
				
			|||||||
 | 
					package headscale
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"gopkg.in/check.v1"
 | 
				
			||||||
 | 
						"tailscale.com/tailcfg"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
 | 
				
			||||||
 | 
						n1, err := h.CreateNamespace("shared1")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n2, err := h.CreateNamespace("shared2")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m1 := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_1",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.1",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak1.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m1.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m2 := Machine{
 | 
				
			||||||
 | 
							ID:             1,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_2",
 | 
				
			||||||
 | 
							NamespaceID:    n2.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.2",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak2.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n2.Name, m2.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1s, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1s), check.Equals, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m2, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1sAfter, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1sAfter), check.Equals, 1)
 | 
				
			||||||
 | 
						c.Assert((*p1sAfter)[0].ID, check.Equals, tailcfg.NodeID(m2.ID))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestSameNamespace(c *check.C) {
 | 
				
			||||||
 | 
						n1, err := h.CreateNamespace("shared1")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n2, err := h.CreateNamespace("shared2")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m1 := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_1",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.1",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak1.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m1.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m2 := Machine{
 | 
				
			||||||
 | 
							ID:             1,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_2",
 | 
				
			||||||
 | 
							NamespaceID:    n2.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.2",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak2.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n2.Name, m2.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1s, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1s), check.Equals, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m1, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.Equals, errorSameNamespace)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestAlreadyShared(c *check.C) {
 | 
				
			||||||
 | 
						n1, err := h.CreateNamespace("shared1")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n2, err := h.CreateNamespace("shared2")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m1 := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_1",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.1",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak1.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m1.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m2 := Machine{
 | 
				
			||||||
 | 
							ID:             1,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_2",
 | 
				
			||||||
 | 
							NamespaceID:    n2.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.2",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak2.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n2.Name, m2.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1s, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1s), check.Equals, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m2, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m2, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.Equals, errorMachineAlreadyShared)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
 | 
				
			||||||
 | 
						n1, err := h.CreateNamespace("shared1")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n2, err := h.CreateNamespace("shared2")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m1 := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_1",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.1",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak1.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m1.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m2 := Machine{
 | 
				
			||||||
 | 
							ID:             1,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_2",
 | 
				
			||||||
 | 
							NamespaceID:    n2.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.2",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak2.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n2.Name, m2.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1s, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1s), check.Equals, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m2, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1sAfter, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1sAfter), check.Equals, 1)
 | 
				
			||||||
 | 
						c.Assert(len((*p1sAfter)[0].AllowedIPs), check.Equals, 1)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
 | 
				
			||||||
 | 
						n1, err := h.CreateNamespace("shared1")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n2, err := h.CreateNamespace("shared2")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						n3, err := h.CreateNamespace("shared3")
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
 | 
				
			||||||
 | 
						c.Assert(err, check.NotNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m1 := Machine{
 | 
				
			||||||
 | 
							ID:             0,
 | 
				
			||||||
 | 
							MachineKey:     "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							NodeKey:        "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							DiscoKey:       "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_1",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.1",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak1.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m1.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m2 := Machine{
 | 
				
			||||||
 | 
							ID:             1,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_2",
 | 
				
			||||||
 | 
							NamespaceID:    n2.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.2",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak2.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n2.Name, m2.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m3 := Machine{
 | 
				
			||||||
 | 
							ID:             2,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_3",
 | 
				
			||||||
 | 
							NamespaceID:    n3.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.3",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak3.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n3.Name, m3.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m4 := Machine{
 | 
				
			||||||
 | 
							ID:             3,
 | 
				
			||||||
 | 
							MachineKey:     "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							NodeKey:        "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							DiscoKey:       "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
 | 
				
			||||||
 | 
							Name:           "test_get_shared_nodes_4",
 | 
				
			||||||
 | 
							NamespaceID:    n1.ID,
 | 
				
			||||||
 | 
							Registered:     true,
 | 
				
			||||||
 | 
							RegisterMethod: "authKey",
 | 
				
			||||||
 | 
							IPAddress:      "100.64.0.4",
 | 
				
			||||||
 | 
							AuthKeyID:      uint(pak4.ID),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						h.db.Save(&m4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = h.GetMachine(n1.Name, m4.Name)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1s, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1s), check.Equals, 1) // nodes 1 and 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = h.AddSharedMachineToNamespace(&m2, n1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						p1sAfter, err := h.getPeers(m1)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*p1sAfter), check.Equals, 2) // nodes 1, 2, 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pAlone, err := h.getPeers(m3)
 | 
				
			||||||
 | 
						c.Assert(err, check.IsNil)
 | 
				
			||||||
 | 
						c.Assert(len(*pAlone), check.Equals, 0) // node 3 is alone
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user