@ -10,11 +10,22 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
var smap = func ( m map [ types . NodeID ] bool ) * xsync . MapOf [ types . NodeID , bool ] {
s := xsync . NewMapOf [ types . NodeID , bool ] ( )
for k , v := range m {
s . Store ( k , v )
}
return s
}
func ( s * Suite ) TestGetRoutes ( c * check . C ) {
user , err := db . CreateUser ( "test" )
c . Assert ( err , check . IsNil )
@ -331,7 +342,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
name string
nodes types . Nodes
routes types . Routes
isConnected [ ] types . NodeConnectedMap
isConnected [ ] map [ types . NodeID ] bool
want [ ] * types . StateUpdate
wantErr bool
} {
@ -346,7 +357,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 1 , 1 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -384,7 +395,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 1 , 1 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 up recon = noop
{
1 : true ,
@ -428,7 +439,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
r ( 3 , 3 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -486,7 +497,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , false , false ) ,
r ( 3 , 3 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -516,7 +527,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
r ( 3 , 3 , ipp ( "10.1.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -539,7 +550,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
r ( 3 , 3 , ipp ( "10.1.0.0/24" ) , false , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -562,7 +573,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
r ( 3 , 3 , ipp ( "10.1.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : false ,
@ -585,7 +596,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 3 , 3 , ipp ( "10.1.0.0/24" ) , true , false ) ,
} ,
isConnected : [ ] types . NodeConnectedMap {
isConnected : [ ] map [ types . NodeID ] bool {
// n1 goes down
{
1 : true ,
@ -618,7 +629,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
want := tt . want [ step ]
got , err := Write ( db . DB , func ( tx * gorm . DB ) ( * types . StateUpdate , error ) {
return FailoverNodeRoutesIfNeccessary ( tx , isConnected , node )
return FailoverNodeRoutesIfNeccessary ( tx , smap ( isConnected ) , node )
} )
if ( err != nil ) != tt . wantErr {
@ -640,7 +651,7 @@ func TestFailoverRouteTx(t *testing.T) {
name string
failingRoute types . Route
routes types . Routes
isConnected types . NodeConnectedMap
isConnected map [ types . NodeID ] bool
want [ ] types . NodeID
wantErr bool
} {
@ -743,7 +754,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled : true ,
} ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : false ,
2 : true ,
} ,
@ -841,7 +852,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled : true ,
} ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : true ,
2 : true ,
3 : true ,
@ -889,7 +900,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled : true ,
} ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : true ,
4 : false ,
} ,
@ -945,7 +956,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled : true ,
} ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : false ,
2 : true ,
4 : false ,
@ -1010,7 +1021,7 @@ func TestFailoverRouteTx(t *testing.T) {
}
got , err := Write ( db . DB , func ( tx * gorm . DB ) ( [ ] types . NodeID , error ) {
return failoverRouteTx ( tx , tt . isConnected , & tt . failingRoute )
return failoverRouteTx ( tx , smap ( tt . isConnected ) , & tt . failingRoute )
} )
if ( err != nil ) != tt . wantErr {
@ -1048,7 +1059,7 @@ func TestFailoverRoute(t *testing.T) {
name string
failingRoute types . Route
routes types . Routes
isConnected types . NodeConnectedMap
isConnected map [ types . NodeID ] bool
want * failover
} {
{
@ -1085,7 +1096,7 @@ func TestFailoverRoute(t *testing.T) {
r ( 1 , 1 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : false ,
2 : true ,
} ,
@ -1111,7 +1122,7 @@ func TestFailoverRoute(t *testing.T) {
r ( 2 , 2 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 3 , 3 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : true ,
2 : true ,
3 : true ,
@ -1128,7 +1139,7 @@ func TestFailoverRoute(t *testing.T) {
r ( 1 , 1 , ipp ( "10.0.0.0/24" ) , true , true ) ,
r ( 2 , 4 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : true ,
4 : false ,
} ,
@ -1142,7 +1153,7 @@ func TestFailoverRoute(t *testing.T) {
r ( 2 , 4 , ipp ( "10.0.0.0/24" ) , true , false ) ,
r ( 3 , 2 , ipp ( "10.0.0.0/24" ) , true , false ) ,
} ,
isConnected : types . NodeConnectedMap {
isConnected : map [ types . NodeID ] bool {
1 : false ,
2 : true ,
4 : false ,
@ -1172,7 +1183,7 @@ func TestFailoverRoute(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
gotf := failoverRoute ( tt . isConnected , & tt . failingRoute , tt . routes )
gotf := failoverRoute ( smap ( tt . isConnected ) , & tt . failingRoute , tt . routes )
if tt . want == nil && gotf != nil {
t . Fatalf ( "expected nil, got %+v" , gotf )