1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-30 00:09:42 +01:00

handle route updates correctly

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-09-11 11:45:46 -05:00 committed by Kristoffer Dalby
parent c957f893bd
commit 096ac31bb3
3 changed files with 29 additions and 8 deletions

View File

@ -278,6 +278,12 @@ func (hsdb *HSDatabase) saveMachineRoutes(machine *types.Machine) error {
advertisedRoutes[prefix] = false advertisedRoutes[prefix] = false
} }
log.Trace().
Str("machine", machine.Hostname).
Interface("advertisedRoutes", advertisedRoutes).
Interface("currentRoutes", currentRoutes).
Msg("updating routes")
for pos, route := range currentRoutes { for pos, route := range currentRoutes {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised { if !route.Advertised {

View File

@ -66,6 +66,9 @@ func (h *Headscale) handlePoll(
) { ) {
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
// This is the mechanism where the node gives us inforamtion about its
// current configuration.
//
// If OmitPeers is true, Stream is false, and ReadOnly is false, // If OmitPeers is true, Stream is false, and ReadOnly is false,
// then te server will let clients update their endpoints without // then te server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections. // breaking existing long-polling (Stream == true) connections.
@ -84,8 +87,11 @@ func (h *Headscale) handlePoll(
Msg("Received endpoint update") Msg("Received endpoint update")
now := time.Now().UTC() now := time.Now().UTC()
machine.Endpoints = mapRequest.Endpoints
machine.LastSeen = &now machine.LastSeen = &now
machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
machine.Endpoints = mapRequest.Endpoints
if err := h.db.MachineSave(machine); err != nil { if err := h.db.MachineSave(machine); err != nil {
logErr(err, "Failed to persist/update machine in the database") logErr(err, "Failed to persist/update machine in the database")
@ -94,6 +100,14 @@ func (h *Headscale) handlePoll(
return return
} }
err := h.db.SaveMachineRoutes(machine)
if err != nil {
logErr(err, "Error processing machine routes")
http.Error(writer, "", http.StatusInternalServerError)
return
}
h.nodeNotifier.NotifyWithIgnore( h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{ types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
@ -134,6 +148,8 @@ func (h *Headscale) handlePoll(
return return
} }
now := time.Now().UTC()
machine.LastSeen = &now
machine.Hostname = mapRequest.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)

View File

@ -413,14 +413,12 @@ func TestEnablingRoutes(t *testing.T) {
// advertise routes using the up command // advertise routes using the up command
for i, client := range allClients { for i, client := range allClients {
routeStr := fmt.Sprintf("10.0.%d.0/24", i) routeStr := fmt.Sprintf("10.0.%d.0/24", i)
hostname, _ := client.FQDN() command := []string{
_, _, err = client.Execute([]string{
"tailscale", "tailscale",
"up", "set",
fmt.Sprintf("--advertise-routes=%s", routeStr), "--advertise-routes=" + routeStr,
"-login-server", headscale.GetEndpoint(), }
"--hostname", hostname, _, _, err := client.Execute(command)
})
assertNoErrf(t, "failed to advertise route: %s", err) assertNoErrf(t, "failed to advertise route: %s", err)
} }
@ -474,6 +472,7 @@ func TestEnablingRoutes(t *testing.T) {
&enablingRoutes, &enablingRoutes,
) )
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes { for _, route := range enablingRoutes {
assert.Equal(t, route.Advertised, true) assert.Equal(t, route.Advertised, true)