1
0
mirror of https://github.com/juanfont/headscale.git synced 2024-12-20 19:09:07 +01:00

Set tags as part of handleAuthKeyCommon

This commit is contained in:
Benjamin George Roberts 2022-08-25 20:43:15 +10:00
parent 6faa1d2e4a
commit ac18723dd4
5 changed files with 75 additions and 6 deletions

View File

@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
expiration = request.GetExpiration().AsTime()
}
if len(request.AclTags) > 0 {
for _, tag := range request.AclTags {
err := validateTag(tag)
if err != nil {
return &v1.CreatePreAuthKeyResponse{
PreAuthKey: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
}
preAuthKey, err := api.h.CreatePreAuthKey(
request.GetNamespace(),
request.GetReusable(),

View File

@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
"24h",
"--output",
"json",
"--tags",
"tag:test1,tag:test2",
},
[]string{},
)
@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)
// Test that tags are present
for i := 0; i < count; i++ {
assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"})
}
// Expire three keys
for i := 0; i < 3; i++ {
_, err := ExecuteCommand(

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err
}
for _, tag := range aclTags {
if !strings.HasPrefix(tag, "tag:") {
return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag)
}
}
now := time.Now().UTC()
kstr, err := h.generateKey()
if err != nil {
@ -77,12 +84,17 @@ func (h *Headscale) CreatePreAuthKey(
}
if len(aclTags) > 0 {
seenTags := map[string]bool{}
for _, tag := range aclTags {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to create key tag in the database: %w",
err,
)
if seenTags[tag] == false {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
}
seenTags[tag] = true
}
}
}
@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {
if len(key.AclTags) > 0 {
for idx := range key.AclTags {
protoKey.AclTags[idx] = key.AclTags[0].Tag
protoKey.AclTags[idx] = key.AclTags[idx].Tag
}
}

View File

@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
_, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
}
func (*Suite) TestPreAuthKeyAclTags(c *check.C) {
namespace, err := app.CreateNamespace("test8")
c.Assert(err, check.IsNil)
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil)
listedPaks, err := app.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil)
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
}

View File

@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon(
return
}
aclTags := pak.toProto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags)
}
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Strs("aclTags", aclTags).
Err(err).
Msg("Failed to set tags after refreshing machine")
return
}
} else {
now := time.Now().UTC()
@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon(
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags,
}
machine, err = h.RegisterMachine(