From 5b0be44391bfac5952f50a8078789e0b18f59c32 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Wed, 4 Oct 2023 15:24:46 -0600 Subject: [PATCH] New Mutable Authorized Entry Cache (#4451) Signed-off-by: Andrew Harding Signed-off-by: Faisal Memon --- go.mod | 1 + go.sum | 2 + pkg/server/authorizedentries/agent.go | 26 + pkg/server/authorizedentries/agent_test.go | 64 +++ pkg/server/authorizedentries/aliases.go | 49 ++ pkg/server/authorizedentries/aliases_test.go | 70 +++ pkg/server/authorizedentries/cache.go | 293 +++++++++++ pkg/server/authorizedentries/cache_test.go | 456 ++++++++++++++++++ pkg/server/authorizedentries/entries.go | 45 ++ pkg/server/authorizedentries/entries_test.go | 59 +++ pkg/server/authorizedentries/recordpool.go | 22 + pkg/server/authorizedentries/selectorset.go | 28 ++ pkg/server/authorizedentries/stringset.go | 28 ++ pkg/server/cache/entrycache/fullcache_test.go | 75 ++- 14 files changed, 1179 insertions(+), 39 deletions(-) create mode 100644 pkg/server/authorizedentries/agent.go create mode 100644 pkg/server/authorizedentries/agent_test.go create mode 100644 pkg/server/authorizedentries/aliases.go create mode 100644 pkg/server/authorizedentries/aliases_test.go create mode 100644 pkg/server/authorizedentries/cache.go create mode 100644 pkg/server/authorizedentries/cache_test.go create mode 100644 pkg/server/authorizedentries/entries.go create mode 100644 pkg/server/authorizedentries/entries_test.go create mode 100644 pkg/server/authorizedentries/recordpool.go create mode 100644 pkg/server/authorizedentries/selectorset.go create mode 100644 pkg/server/authorizedentries/stringset.go diff --git a/go.mod b/go.mod index 1ab7cf3d83..d1cb3b6f62 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/gofrs/uuid v4.4.0+incompatible github.com/gofrs/uuid/v5 v5.0.0 github.com/golang/protobuf v1.5.3 + github.com/google/btree v1.1.2 github.com/google/go-cmp v0.5.9 github.com/google/go-containerregistry v0.16.1 github.com/google/go-tpm v0.9.0 diff --git a/go.sum b/go.sum index 41b92a65dd..829df276c8 100644 --- a/go.sum +++ b/go.sum @@ -996,6 +996,8 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/certificate-transparency-go v1.1.6 h1:SW5K3sr7ptST/pIvNkSVWMiJqemRmkjJPPT0jzXdOOY= github.com/google/certificate-transparency-go v1.1.6/go.mod h1:0OJjOsOk+wj6aYQgP7FU0ioQ0AJUmnWPFMqTjQeazPQ= github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= diff --git a/pkg/server/authorizedentries/agent.go b/pkg/server/authorizedentries/agent.go new file mode 100644 index 0000000000..2a58326350 --- /dev/null +++ b/pkg/server/authorizedentries/agent.go @@ -0,0 +1,26 @@ +package authorizedentries + +type agentRecord struct { + ID string + + // ExpiresAt is seconds since unix epoch. Using intead of time.Time for + // reduced memory usage and better cache locality. + ExpiresAt int64 + + Selectors selectorSet +} + +func agentRecordByID(a, b agentRecord) bool { + return a.ID < b.ID +} + +func agentRecordByExpiresAt(a, b agentRecord) bool { + switch { + case a.ExpiresAt < b.ExpiresAt: + return true + case a.ExpiresAt > b.ExpiresAt: + return false + default: + return a.ID < b.ID + } +} diff --git a/pkg/server/authorizedentries/agent_test.go b/pkg/server/authorizedentries/agent_test.go new file mode 100644 index 0000000000..884ee2f37d --- /dev/null +++ b/pkg/server/authorizedentries/agent_test.go @@ -0,0 +1,64 @@ +package authorizedentries + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAgentRecordSize(t *testing.T) { + // The motivation for this test is to bring awareness and visibility into + // how much size the record occupies. We want to minimize the size to + // increase cache locality in the btree. + require.Equal(t, uintptr(32), unsafe.Sizeof(agentRecord{})) +} + +func TestAgentRecordByID(t *testing.T) { + assertLess := func(lesser, greater agentRecord) { + t.Helper() + assert.Truef(t, agentRecordByID(lesser, greater), "expected A%sE%sA%sE%s", greater.ID, greater.ExpiresAt, lesser.ID, lesser.ExpiresAt) + } + + // ExpiresAt is irrelevant. + records := []agentRecord{ + agentRecord{ID: "1", ExpiresAt: 9999}, + agentRecord{ID: "2", ExpiresAt: 8888}, + } + + lesser := agentRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } + + // Since there should only be one agent record by ID, the ExpiresAt field + // is ignored for purposes of placement in the btree. + assert.False(t, agentRecordByID(agentRecord{ID: "FOO", ExpiresAt: 1}, agentRecord{ID: "FOO", ExpiresAt: 2})) + assert.False(t, agentRecordByID(agentRecord{ID: "FOO", ExpiresAt: 2}, agentRecord{ID: "FOO", ExpiresAt: 1})) +} + +func TestAgentRecordByExpiresAt(t *testing.T) { + assertLess := func(lesser, greater agentRecord) { + t.Helper() + assert.Truef(t, agentRecordByExpiresAt(lesser, greater), "expected A%sE%dA%sE%d", greater.ID, greater.ExpiresAt, lesser.ID, lesser.ExpiresAt) + } + + records := []agentRecord{ + agentRecord{ID: "1"}, + agentRecord{ID: "2"}, + agentRecord{ID: "1", ExpiresAt: 1}, + agentRecord{ID: "2", ExpiresAt: 1}, + agentRecord{ID: "1", ExpiresAt: 2}, + agentRecord{ID: "2", ExpiresAt: 2}, + } + + lesser := agentRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } +} diff --git a/pkg/server/authorizedentries/aliases.go b/pkg/server/authorizedentries/aliases.go new file mode 100644 index 0000000000..42afc0551e --- /dev/null +++ b/pkg/server/authorizedentries/aliases.go @@ -0,0 +1,49 @@ +package authorizedentries + +type aliasRecord struct { + // EntryID is the ID of the registration entry that defines this node + // alias. + EntryID string + + // AliasID is the SPIFFE ID of nodes that match this alias. + AliasID string + + // Selector is the specific selector we use to fan out to this record + // during the crawl. + Selector Selector + + // AllSelectors is here out of convenience to verify that the agent + // possesses a superset of the alias's selectors and is therefore + // authorized for the alias. + AllSelectors selectorSet +} + +func aliasRecordByEntryID(a, b aliasRecord) bool { + switch { + case a.EntryID < b.EntryID: + return true + case a.EntryID > b.EntryID: + return false + case a.Selector.Type < b.Selector.Type: + return true + case a.Selector.Type > b.Selector.Type: + return false + default: + return a.Selector.Value < b.Selector.Value + } +} + +func aliasRecordBySelector(a, b aliasRecord) bool { + switch { + case a.Selector.Type < b.Selector.Type: + return true + case a.Selector.Type > b.Selector.Type: + return false + case a.Selector.Value < b.Selector.Value: + return true + case a.Selector.Value > b.Selector.Value: + return false + default: + return a.EntryID < b.EntryID + } +} diff --git a/pkg/server/authorizedentries/aliases_test.go b/pkg/server/authorizedentries/aliases_test.go new file mode 100644 index 0000000000..106d7302cf --- /dev/null +++ b/pkg/server/authorizedentries/aliases_test.go @@ -0,0 +1,70 @@ +package authorizedentries + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAliasRecordSize(t *testing.T) { + // The motivation for this test is to bring awareness and visibility into + // how much size the record occupies. We want to minimize the size to + // increase cache locality in the btree. + require.Equal(t, uintptr(72), unsafe.Sizeof(aliasRecord{})) +} + +func TestAliasRecordByEntryID(t *testing.T) { + assertLess := func(lesser, greater aliasRecord) { + t.Helper() + assert.Truef(t, aliasRecordByEntryID(lesser, greater), "expected E%sP%sE%sP%s", greater.EntryID, greater.Selector, lesser.EntryID, lesser.Selector) + } + + records := []aliasRecord{ + aliasRecord{EntryID: "1"}, + aliasRecord{EntryID: "1", Selector: Selector{Type: "1", Value: "1"}}, + aliasRecord{EntryID: "1", Selector: Selector{Type: "1", Value: "2"}}, + aliasRecord{EntryID: "1", Selector: Selector{Type: "2", Value: "1"}}, + aliasRecord{EntryID: "1", Selector: Selector{Type: "2", Value: "2"}}, + aliasRecord{EntryID: "2"}, + aliasRecord{EntryID: "2", Selector: Selector{Type: "1", Value: "1"}}, + aliasRecord{EntryID: "2", Selector: Selector{Type: "1", Value: "2"}}, + aliasRecord{EntryID: "2", Selector: Selector{Type: "2", Value: "1"}}, + aliasRecord{EntryID: "2", Selector: Selector{Type: "2", Value: "2"}}, + } + + lesser := aliasRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } +} + +func TestAliasRecordBySelector(t *testing.T) { + assertLess := func(lesser, greater aliasRecord) { + t.Helper() + assert.True(t, aliasRecordBySelector(lesser, greater), "expected P%sE%sP%sE%s", greater.Selector, greater.EntryID, lesser.Selector, lesser.EntryID) + } + + records := []aliasRecord{ + aliasRecord{Selector: Selector{Type: "1", Value: "1"}}, + aliasRecord{Selector: Selector{Type: "1", Value: "1"}, EntryID: "1"}, + aliasRecord{Selector: Selector{Type: "1", Value: "1"}, EntryID: "2"}, + aliasRecord{Selector: Selector{Type: "1", Value: "2"}, EntryID: "1"}, + aliasRecord{Selector: Selector{Type: "1", Value: "2"}, EntryID: "2"}, + aliasRecord{Selector: Selector{Type: "2", Value: "1"}}, + aliasRecord{Selector: Selector{Type: "2", Value: "1"}, EntryID: "1"}, + aliasRecord{Selector: Selector{Type: "2", Value: "1"}, EntryID: "2"}, + aliasRecord{Selector: Selector{Type: "2", Value: "2"}}, + aliasRecord{Selector: Selector{Type: "2", Value: "2"}, EntryID: "1"}, + aliasRecord{Selector: Selector{Type: "2", Value: "2"}, EntryID: "2"}, + } + lesser := aliasRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } +} diff --git a/pkg/server/authorizedentries/cache.go b/pkg/server/authorizedentries/cache.go new file mode 100644 index 0000000000..293789577e --- /dev/null +++ b/pkg/server/authorizedentries/cache.go @@ -0,0 +1,293 @@ +package authorizedentries + +import ( + "fmt" + "sync" + "time" + + "github.com/google/btree" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + "github.com/spiffe/spire/pkg/common/idutil" +) + +const ( + // We can tweak these degrees to try and get optimal L1 cache use but + // it's probably not worth it unless we have benchmarks showing that it + // is a problem at scale in production. Initial benchmarking by myself + // at similar scale to some of our bigger, existing deployments didn't + // seem to yield much difference. As such, these values are probably an + // ok jumping off point. + agentRecordDegree = 32 + aliasRecordDegree = 32 + entryDegree = 32 +) + +type Selector struct { + Type string + Value string +} + +func (s Selector) String() string { + return s.Type + ":" + s.Value +} + +type Cache struct { + mu sync.RWMutex + + agentsByID *btree.BTreeG[agentRecord] + agentsByExpiresAt *btree.BTreeG[agentRecord] + + aliasesByEntryID *btree.BTreeG[aliasRecord] + aliasesBySelector *btree.BTreeG[aliasRecord] + + entriesByEntryID *btree.BTreeG[entryRecord] + entriesByParentID *btree.BTreeG[entryRecord] +} + +func NewCache() *Cache { + return &Cache{ + agentsByID: btree.NewG(agentRecordDegree, agentRecordByID), + agentsByExpiresAt: btree.NewG(agentRecordDegree, agentRecordByExpiresAt), + aliasesByEntryID: btree.NewG(aliasRecordDegree, aliasRecordByEntryID), + aliasesBySelector: btree.NewG(aliasRecordDegree, aliasRecordBySelector), + entriesByEntryID: btree.NewG(entryDegree, entryRecordByEntryID), + entriesByParentID: btree.NewG(entryDegree, entryRecordByParentID), + } +} + +func (c *Cache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry { + c.mu.RLock() + defer c.mu.RUnlock() + + // Load up the agent selectors. If the agent info does not exist, it is + // likely that the cache is still catching up to a recent attestation. + // Since the calling agent has already been authorized and authenticated, + // it is safe to continue with the authorized entry crawl to obtain entries + // that are directly parented against the agent. Any entries that would be + // obtained via node aliasing will not be returned until the cache is + // updated with the node selectors for the agent. + agent, _ := c.agentsByID.Get(agentRecord{ID: agentID.String()}) + + parentSeen := allocStringSet() + defer freeStringSet(parentSeen) + + records := allocRecordSlice() + defer freeRecordSlice(records) + + records = c.appendDescendents(records, agentID.String(), parentSeen) + + agentAliases := c.getAgentAliases(agent.Selectors) + for _, alias := range agentAliases { + records = c.appendDescendents(records, alias.AliasID, parentSeen) + } + + return cloneEntriesFromRecords(records) +} + +func (c *Cache) UpdateEntry(entry *types.Entry) { + c.mu.Lock() + defer c.mu.Unlock() + + c.removeEntry(entry.Id) + c.updateEntry(entry) +} + +func (c *Cache) RemoveEntry(entryID string) { + c.mu.Lock() + defer c.mu.Unlock() + + c.removeEntry(entryID) +} + +func (c *Cache) UpdateAgent(agentID string, expiresAt time.Time, selectors []*types.Selector) { + c.mu.Lock() + defer c.mu.Unlock() + + agent := agentRecord{ + ID: agentID, + ExpiresAt: expiresAt.Unix(), + Selectors: selectorSetFromProto(selectors), + } + + // Need to delete existing record from the ExpiresAt index first. Use + // the ID index to locate the existing record. + if existing, exists := c.agentsByID.Get(agent); exists { + c.agentsByExpiresAt.Delete(existing) + } + + c.agentsByID.ReplaceOrInsert(agent) + c.agentsByExpiresAt.ReplaceOrInsert(agent) +} + +func (c *Cache) RemoveAgent(agentID string) { + c.mu.Lock() + defer c.mu.Unlock() + if agent, exists := c.agentsByID.Get(agentRecord{ID: agentID}); exists { + c.agentsByID.Delete(agent) + c.agentsByExpiresAt.Delete(agent) + } +} + +func (c *Cache) PruneExpiredAgents() int { + now := time.Now().Unix() + pruned := 0 + + c.mu.Lock() + defer c.mu.Unlock() + for { + record, ok := c.agentsByExpiresAt.Min() + if !ok || record.ExpiresAt > now { + return pruned + } + c.agentsByID.Delete(record) + c.agentsByExpiresAt.Delete(record) + pruned++ + } +} + +func (c *Cache) appendDescendents(records []entryRecord, parentID string, parentSeen stringSet) []entryRecord { + if _, ok := parentSeen[parentID]; ok { + return records + } + parentSeen[parentID] = struct{}{} + + lenBefore := len(records) + records = c.appendEntryRecordsForParentID(records, parentID) + // Crawl the children that were appended to get their descendents + for _, entry := range records[lenBefore:] { + records = c.appendDescendents(records, entry.SPIFFEID, parentSeen) + } + return records +} + +func (c *Cache) appendEntryRecordsForParentID(records []entryRecord, parentID string) []entryRecord { + pivot := entryRecord{ParentID: parentID} + c.entriesByParentID.AscendGreaterOrEqual(pivot, func(record entryRecord) bool { + if record.ParentID != parentID { + return false + } + records = append(records, record) + return true + }) + return records +} + +func (c *Cache) getAgentAliases(agentSelectors selectorSet) []aliasRecord { + // Keep track of which aliases have already been evaluated. + aliasesSeen := allocStringSet() + defer freeStringSet(aliasesSeen) + + // Figure out which aliases the agent belongs to. + var aliasIDs []aliasRecord + for agentSelector := range agentSelectors { + pivot := aliasRecord{Selector: agentSelector} + c.aliasesBySelector.AscendGreaterOrEqual(pivot, func(record aliasRecord) bool { + if record.Selector != agentSelector { + return false + } + if _, ok := aliasesSeen[record.EntryID]; ok { + return true + } + aliasesSeen[record.EntryID] = struct{}{} + if isSubset(record.AllSelectors, agentSelectors) { + aliasIDs = append(aliasIDs, record) + } + return true + }) + } + return aliasIDs +} + +func (c *Cache) updateEntry(entry *types.Entry) { + if isNodeAlias(entry) { + ar := aliasRecord{ + EntryID: entry.Id, + AliasID: spiffeIDFromProto(entry.SpiffeId), + AllSelectors: selectorSetFromProto(entry.Selectors), + } + for selector := range ar.AllSelectors { + ar.Selector = selector + c.aliasesByEntryID.ReplaceOrInsert(ar) + c.aliasesBySelector.ReplaceOrInsert(ar) + } + return + } + + er := entryRecord{ + EntryID: entry.Id, + SPIFFEID: spiffeIDFromProto(entry.SpiffeId), + ParentID: spiffeIDFromProto(entry.ParentId), + // For quick cloning at the end of the crawl so we don't have to have + // a separate data structure for looking up entries by id. + EntryCloneOnly: entry, + } + c.entriesByParentID.ReplaceOrInsert(er) + c.entriesByEntryID.ReplaceOrInsert(er) +} + +func (c *Cache) removeEntry(entryID string) { + entryPivot := entryRecord{EntryID: entryID} + + var entryRecordsToDelete []entryRecord + c.entriesByEntryID.AscendGreaterOrEqual(entryPivot, func(record entryRecord) bool { + if record.EntryID != entryID { + return false + } + entryRecordsToDelete = append(entryRecordsToDelete, record) + return true + }) + + for _, record := range entryRecordsToDelete { + c.entriesByEntryID.Delete(record) + c.entriesByParentID.Delete(record) + } + + if len(entryRecordsToDelete) > 0 { + // entry was a normal workload registration. No need to search the aliases. + return + } + + var aliasRecordsToDelete []aliasRecord + aliasPivot := aliasRecord{EntryID: entryID} + c.aliasesByEntryID.AscendGreaterOrEqual(aliasPivot, func(record aliasRecord) bool { + if record.EntryID != entryID { + return false + } + aliasRecordsToDelete = append(aliasRecordsToDelete, record) + return true + }) + + for _, record := range aliasRecordsToDelete { + c.aliasesByEntryID.Delete(record) + c.aliasesBySelector.Delete(record) + } +} + +func (c *Cache) stats() cacheStats { + return cacheStats{ + AgentsByID: c.agentsByID.Len(), + AgentsByExpiresAt: c.agentsByExpiresAt.Len(), + AliasesByEntryID: c.aliasesByEntryID.Len(), + AliasesBySelector: c.aliasesBySelector.Len(), + EntriesByEntryID: c.entriesByEntryID.Len(), + EntriesByParentID: c.entriesByParentID.Len(), + } +} + +func spiffeIDFromProto(id *types.SPIFFEID) string { + return fmt.Sprintf("spiffe://%s%s", id.TrustDomain, id.Path) +} + +func isNodeAlias(e *types.Entry) bool { + return e.ParentId.Path == idutil.ServerIDPath +} + +type cacheStats struct { + AgentsByID int + AgentsByExpiresAt int + AliasesByEntryID int + AliasesBySelector int + EntriesByEntryID int + EntriesByParentID int +} diff --git a/pkg/server/authorizedentries/cache_test.go b/pkg/server/authorizedentries/cache_test.go new file mode 100644 index 0000000000..a80abf2051 --- /dev/null +++ b/pkg/server/authorizedentries/cache_test.go @@ -0,0 +1,456 @@ +package authorizedentries + +import ( + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + "github.com/spiffe/spire/pkg/common/idutil" + "github.com/spiffe/spire/pkg/server/api" + "github.com/spiffe/spire/test/spiretest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + td = spiffeid.RequireTrustDomainFromString("domain.test") + server = spiffeid.RequireFromPath(td, idutil.ServerIDPath) + agent1 = spiffeid.RequireFromPath(td, "/spire/agent/1") + agent2 = spiffeid.RequireFromPath(td, "/spire/agent/2") + agent3 = spiffeid.RequireFromPath(td, "/spire/agent/3") + agent4 = spiffeid.RequireFromPath(td, "/spire/agent/4") + delegatee = spiffeid.RequireFromPath(td, "/delegatee") + alias1 = spiffeid.RequireFromPath(td, "/alias/1") + alias2 = spiffeid.RequireFromPath(td, "/alias/2") + sel1 = &types.Selector{Type: "S", Value: "1"} + sel2 = &types.Selector{Type: "S", Value: "2"} + sel3 = &types.Selector{Type: "S", Value: "3"} + now = time.Now().Truncate(time.Second) +) + +func TestGetAuthorizedEntries(t *testing.T) { + t.Run("empty cache", func(t *testing.T) { + testCache().assertAuthorizedEntries(t, agent1) + }) + + t.Run("agent not attested still returns direct children", func(t *testing.T) { + var ( + directChild = makeWorkload(agent1) + ) + testCache(). + withEntries(directChild). + assertAuthorizedEntries(t, agent1, directChild) + }) + + t.Run("directly via agent", func(t *testing.T) { + workload1 := makeWorkload(agent1) + workload2 := makeWorkload(agent2) + testCache(). + withAgent(agent1, sel1). + withEntries(workload1, workload2). + assertAuthorizedEntries(t, agent1, workload1) + }) + + t.Run("entry removed", func(t *testing.T) { + workload := makeWorkload(agent1) + cache := testCache(). + withAgent(agent1, sel1). + withEntries(workload).hydrate() + cache.RemoveEntry(workload.Id) + assertAuthorizedEntries(t, cache, agent1) + }) + + t.Run("indirectly via delegated workload", func(t *testing.T) { + var ( + delegateeEntry = makeDelegatee(agent1, delegatee) + workloadEntry = makeWorkload(delegatee) + someOtherEntry = makeWorkload(agent2) + ) + + testCache(). + withAgent(agent1, sel1). + withEntries(delegateeEntry, workloadEntry, someOtherEntry). + assertAuthorizedEntries(t, agent1, delegateeEntry, workloadEntry) + }) + + t.Run("indirectly via alias", func(t *testing.T) { + var ( + aliasEntry = makeAlias(alias1, sel1, sel2) + workloadEntry = makeWorkload(alias1) + ) + + test := testCache(). + withEntries(workloadEntry, aliasEntry). + withAgent(agent1, sel1). + withAgent(agent2, sel1, sel2). + withAgent(agent3, sel1, sel2, sel3) + + t.Run("agent has strict selector subset", func(t *testing.T) { + // Workload entry not available through alias since the agent + // does not have a superset of the alias selectors. + test.assertAuthorizedEntries(t, agent1) + }) + + t.Run("agent has selector match", func(t *testing.T) { + // Workload entry is available through alias since the agent + // has a non-strict superset of the alias selectors. + test.assertAuthorizedEntries(t, agent2, workloadEntry) + }) + + t.Run("agent has strict selector superset", func(t *testing.T) { + // Workload entry is available through alias since the agent + // has a strict superset of the alias selectors. + test.assertAuthorizedEntries(t, agent3, workloadEntry) + }) + }) + + t.Run("alias removed", func(t *testing.T) { + var ( + aliasEntry = makeAlias(alias1, sel1, sel2) + workloadEntry = makeWorkload(alias1) + ) + + cache := testCache(). + withEntries(workloadEntry, aliasEntry). + withAgent(agent1, sel1, sel2). + hydrate() + + cache.RemoveEntry(aliasEntry.Id) + assertAuthorizedEntries(t, cache, agent1) + }) + + t.Run("agent removed", func(t *testing.T) { + var ( + aliasEntry = makeAlias(alias1, sel1, sel2) + workloadEntry = makeWorkload(alias1) + ) + + cache := testCache(). + withEntries(workloadEntry, aliasEntry). + withAgent(agent1, sel1, sel2). + hydrate() + + cache.RemoveAgent(agent1.String()) + assertAuthorizedEntries(t, cache, agent1) + }) + + t.Run("agent pruned after expiry", func(t *testing.T) { + var ( + aliasEntry = makeAlias(alias1, sel1, sel2) + workloadEntry = makeWorkload(alias1) + ) + + cache := testCache(). + withEntries(workloadEntry, aliasEntry). + withExpiredAgent(agent1, time.Hour, sel1, sel2). + withExpiredAgent(agent2, time.Hour, sel1, sel2). + withExpiredAgent(agent3, time.Hour*2, sel1, sel2). + withAgent(agent4, sel1, sel2). + hydrate() + assertAuthorizedEntries(t, cache, agent1, workloadEntry) + assertAuthorizedEntries(t, cache, agent2, workloadEntry) + assertAuthorizedEntries(t, cache, agent3, workloadEntry) + assertAuthorizedEntries(t, cache, agent4, workloadEntry) + + assert.Equal(t, 3, cache.PruneExpiredAgents()) + + assertAuthorizedEntries(t, cache, agent1) + assertAuthorizedEntries(t, cache, agent2) + assertAuthorizedEntries(t, cache, agent3) + assertAuthorizedEntries(t, cache, agent4, workloadEntry) + }) +} + +func TestCacheInternalStats(t *testing.T) { + // This test asserts that the internal indexes are properly maintained + // across various operations. The motivation is to ensure that as the cache + // is updated that we are appropriately inserting and removing records from + // the indexees. + t.Run("pristine", func(t *testing.T) { + cache := NewCache() + require.Zero(t, cache.stats()) + }) + + t.Run("entries and aliases", func(t *testing.T) { + entry1 := makeWorkload(agent1) + entry2a := makeWorkload(agent2) + + // Version b will change to an alias instead + entry2b := makeAlias(alias1, sel1, sel2) + entry2b.Id = entry2a.Id + + cache := NewCache() + cache.UpdateEntry(entry1) + require.Equal(t, cacheStats{ + EntriesByEntryID: 1, + EntriesByParentID: 1, + }, cache.stats()) + + cache.UpdateEntry(entry2a) + require.Equal(t, cacheStats{ + EntriesByEntryID: 2, + EntriesByParentID: 2, + }, cache.stats()) + + cache.UpdateEntry(entry2b) + require.Equal(t, cacheStats{ + EntriesByEntryID: 1, + EntriesByParentID: 1, + AliasesByEntryID: 2, // one for each selector + AliasesBySelector: 2, // one for each selector + }, cache.stats()) + + cache.RemoveEntry(entry1.Id) + require.Equal(t, cacheStats{ + AliasesByEntryID: 2, // one for each selector + AliasesBySelector: 2, // one for each selector + }, cache.stats()) + + cache.RemoveEntry(entry2b.Id) + require.Zero(t, cache.stats()) + + // Remove again and make sure nothing happens. + cache.RemoveEntry(entry2b.Id) + require.Zero(t, cache.stats()) + }) + + t.Run("agents", func(t *testing.T) { + cache := NewCache() + cache.UpdateAgent(agent1.String(), now.Add(time.Hour), []*types.Selector{sel1}) + require.Equal(t, cacheStats{ + AgentsByID: 1, + AgentsByExpiresAt: 1, + }, cache.stats()) + + cache.UpdateAgent(agent2.String(), now.Add(time.Hour*2), []*types.Selector{sel2}) + require.Equal(t, cacheStats{ + AgentsByID: 2, + AgentsByExpiresAt: 2, + }, cache.stats()) + + cache.UpdateAgent(agent2.String(), now.Add(time.Hour*3), []*types.Selector{sel2}) + require.Equal(t, cacheStats{ + AgentsByID: 2, + AgentsByExpiresAt: 2, + }, cache.stats()) + + cache.RemoveAgent(agent1.String()) + require.Equal(t, cacheStats{ + AgentsByID: 1, + AgentsByExpiresAt: 1, + }, cache.stats()) + + cache.RemoveAgent(agent2.String()) + require.Zero(t, cache.stats()) + }) +} + +func testCache() *cacheTest { + return &cacheTest{ + entries: make(map[string]*types.Entry), + agents: make(map[spiffeid.ID]agentInfo), + } +} + +type cacheTest struct { + entries map[string]*types.Entry + agents map[spiffeid.ID]agentInfo +} + +type agentInfo struct { + ExpiresAt time.Time + Selectors []*types.Selector +} + +func (a *cacheTest) pickAgent() spiffeid.ID { + for agent := range a.agents { + return agent + } + return spiffeid.ID{} +} + +func (a *cacheTest) withEntries(entries ...*types.Entry) *cacheTest { + for _, entry := range entries { + a.entries[entry.Id] = entry + } + return a +} + +func (a *cacheTest) withAgent(node spiffeid.ID, selectors ...*types.Selector) *cacheTest { + expiresAt := now.Add(time.Hour * time.Duration(1+len(a.agents))) + a.agents[node] = agentInfo{ + ExpiresAt: expiresAt, + Selectors: append([]*types.Selector(nil), selectors...), + } + return a +} + +func (a *cacheTest) withExpiredAgent(node spiffeid.ID, expiredBy time.Duration, selectors ...*types.Selector) *cacheTest { + expiresAt := now.Add(-expiredBy) + a.agents[node] = agentInfo{ + ExpiresAt: expiresAt, + Selectors: append([]*types.Selector(nil), selectors...), + } + return a +} + +func (a *cacheTest) hydrate() *Cache { + cache := NewCache() + for _, entry := range a.entries { + cache.UpdateEntry(entry) + } + for agent, info := range a.agents { + cache.UpdateAgent(agent.String(), info.ExpiresAt, info.Selectors) + } + return cache +} + +func (a *cacheTest) assertAuthorizedEntries(t *testing.T, agent spiffeid.ID, expectEntries ...*types.Entry) { + t.Helper() + assertAuthorizedEntries(t, a.hydrate(), agent, expectEntries...) +} + +func makeAlias(alias spiffeid.ID, selectors ...*types.Selector) *types.Entry { + return &types.Entry{ + Id: fmt.Sprintf("alias-%d(spiffeid=%s)", makeEntryIDPrefix(), alias), + ParentId: api.ProtoFromID(server), + SpiffeId: api.ProtoFromID(alias), + Selectors: selectors, + } +} + +func makeDelegatee(parent, delegatee spiffeid.ID) *types.Entry { + return &types.Entry{ + Id: fmt.Sprintf("delegatee-%d(parent=%s,spiffeid=%s)", makeEntryIDPrefix(), parent, delegatee), + ParentId: api.ProtoFromID(parent), + SpiffeId: api.ProtoFromID(delegatee), + Selectors: []*types.Selector{{Type: "not", Value: "relevant"}}, + } +} + +func makeWorkload(parent spiffeid.ID) *types.Entry { + return &types.Entry{ + Id: fmt.Sprintf("workload-%d(parent=%s)", makeEntryIDPrefix(), parent), + ParentId: api.ProtoFromID(parent), + SpiffeId: &types.SPIFFEID{TrustDomain: "domain.test", Path: "/workload"}, + Selectors: []*types.Selector{{Type: "not", Value: "relevant"}}, + } +} + +var nextEntryIDPrefix int32 + +func makeEntryIDPrefix() int32 { + return atomic.AddInt32(&nextEntryIDPrefix, 1) +} + +// BenchmarkGetAuthorizedEntriesInMemory was ported from the old full entry +// cache and some of the bugs fixed. +func BenchmarkGetAuthorizedEntriesInMemory(b *testing.B) { + test := testCache() + + staticSelector1 := &types.Selector{Type: "static", Value: "static-1"} + staticSelector2 := &types.Selector{Type: "static", Value: "static-2"} + + const numAgents = 50000 + for i := 0; i < numAgents; i++ { + test.withAgent(spiffeid.RequireFromPathf(td, "/agent-%d", i), staticSelector1) + } + + aliasID1 := api.ProtoFromID(alias1) + aliasID2 := api.ProtoFromID(alias2) + + test.withEntries( + // Alias + &types.Entry{ + Id: "alias1", + SpiffeId: aliasID1, + ParentId: &types.SPIFFEID{TrustDomain: "domain.test", Path: idutil.ServerIDPath}, + Selectors: []*types.Selector{staticSelector1}, + }, + // False alias + &types.Entry{ + Id: "alias2", + SpiffeId: aliasID2, + ParentId: &types.SPIFFEID{TrustDomain: "domain.test", Path: idutil.ServerIDPath}, + Selectors: []*types.Selector{staticSelector2}, + }, + ) + + for i := 0; i < 300; i++ { + test.withEntries(&types.Entry{ + Id: fmt.Sprintf("alias1-workload-%d", i), + SpiffeId: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: fmt.Sprintf("/workload%d", i), + }, + ParentId: aliasID1, + Selectors: []*types.Selector{ + {Type: "unix", Value: fmt.Sprintf("uid:%d", i)}, + }, + }) + } + + for i := 0; i < 300; i++ { + test.withEntries(&types.Entry{ + Id: fmt.Sprintf("alias2-workload-%d", i), + SpiffeId: &types.SPIFFEID{ + TrustDomain: "domain.test", + Path: fmt.Sprintf("/workload%d", i), + }, + ParentId: aliasID2, + Selectors: []*types.Selector{ + {Type: "unix", Value: fmt.Sprintf("uid:%d", i)}, + }, + }) + } + + cache := test.hydrate() + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.GetAuthorizedEntries(test.pickAgent()) + } +} + +func assertAuthorizedEntries(tb testing.TB, cache *Cache, agentID spiffeid.ID, wantEntries ...*types.Entry) { + tb.Helper() + + entriesMap := func(entries []*types.Entry) map[string]*types.Entry { + m := make(map[string]*types.Entry) + for _, entry := range entries { + m[entry.Id] = entry + } + return m + } + + wantMap := entriesMap(wantEntries) + gotMap := entriesMap(cache.GetAuthorizedEntries(agentID)) + + for id, want := range wantMap { + got, ok := gotMap[id] + if !ok { + assert.Fail(tb, "expected entry not returned", "expected entry %q", id) + continue + } + + // Make sure the contents are equivalent. + spiretest.AssertProtoEqual(tb, want, got) + + // The pointer should not be equivalent. The cache should be cloning + // the entries before returning. + if want == got { + assert.Fail(tb, "entry proto was not cloned before return") + continue + } + } + + // Assert there were not unexpected entries returned. + for id := range gotMap { + if _, ok := wantMap[id]; !ok { + assert.Fail(tb, "unexpected entry returned", "unexpected entry %q", id) + continue + } + } +} diff --git a/pkg/server/authorizedentries/entries.go b/pkg/server/authorizedentries/entries.go new file mode 100644 index 0000000000..47bb84e676 --- /dev/null +++ b/pkg/server/authorizedentries/entries.go @@ -0,0 +1,45 @@ +package authorizedentries + +import ( + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" + "google.golang.org/protobuf/proto" +) + +type entryRecord struct { + EntryID string + ParentID string + SPIFFEID string + + // Pointer to the entry. For cloning only after the end of the crawl. + EntryCloneOnly *types.Entry +} + +func entryRecordByEntryID(a, b entryRecord) bool { + return a.EntryID < b.EntryID +} + +func entryRecordByParentID(a, b entryRecord) bool { + switch { + case a.ParentID < b.ParentID: + return true + case a.ParentID > b.ParentID: + return false + default: + return a.EntryID < b.EntryID + } +} + +func cloneEntriesFromRecords(entryRecords []entryRecord) []*types.Entry { + if len(entryRecords) == 0 { + return nil + } + cloned := make([]*types.Entry, 0, len(entryRecords)) + for _, entryRecord := range entryRecords { + cloned = append(cloned, cloneEntry(entryRecord.EntryCloneOnly)) + } + return cloned +} + +func cloneEntry(entry *types.Entry) *types.Entry { + return proto.Clone(entry).(*types.Entry) +} diff --git a/pkg/server/authorizedentries/entries_test.go b/pkg/server/authorizedentries/entries_test.go new file mode 100644 index 0000000000..653b3f3c16 --- /dev/null +++ b/pkg/server/authorizedentries/entries_test.go @@ -0,0 +1,59 @@ +package authorizedentries + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEntryRecordSize(t *testing.T) { + // The motivation for this test is to bring awareness and visibility into + // how much size the record occupies. We want to minimize the size to + // increase cache locality in the btree. + require.Equal(t, uintptr(56), unsafe.Sizeof(entryRecord{})) +} + +func TestEntryRecordByEntryID(t *testing.T) { + assertLess := func(lesser, greater entryRecord) { + t.Helper() + assert.Truef(t, entryRecordByEntryID(lesser, greater), "expected E%sP%sE%sP%s", greater.EntryID, greater.ParentID, lesser.EntryID, lesser.ParentID) + } + + // ParentID is irrelevant. + records := []entryRecord{ + entryRecord{EntryID: "1", ParentID: "2"}, + entryRecord{EntryID: "2", ParentID: "1"}, + } + + lesser := entryRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } +} + +func TestEntryRecordByParentID(t *testing.T) { + assertLess := func(lesser, greater entryRecord) { + t.Helper() + assert.True(t, entryRecordByParentID(lesser, greater), "expected P%sE%sP%sE%s", greater.ParentID, greater.EntryID, lesser.ParentID, lesser.EntryID) + } + + records := []entryRecord{ + entryRecord{ParentID: "1"}, + entryRecord{ParentID: "1", EntryID: "1"}, + entryRecord{ParentID: "1", EntryID: "2"}, + entryRecord{ParentID: "2"}, + entryRecord{ParentID: "2", EntryID: "1"}, + entryRecord{ParentID: "2", EntryID: "2"}, + } + + lesser := entryRecord{} + for _, greater := range records { + assertLess(lesser, greater) + lesser = greater + } +} diff --git a/pkg/server/authorizedentries/recordpool.go b/pkg/server/authorizedentries/recordpool.go new file mode 100644 index 0000000000..ac9ad82e48 --- /dev/null +++ b/pkg/server/authorizedentries/recordpool.go @@ -0,0 +1,22 @@ +package authorizedentries + +import "sync" + +var ( + // Stores pointers to record slices. See https://staticcheck.io/docs/checks#SA6002. + recordPool = sync.Pool{ + New: func() interface{} { + p := []entryRecord(nil) + return &p + }, + } +) + +func allocRecordSlice() []entryRecord { + return *recordPool.Get().(*[]entryRecord) +} + +func freeRecordSlice(records []entryRecord) { + records = records[:0] + recordPool.Put(&records) +} diff --git a/pkg/server/authorizedentries/selectorset.go b/pkg/server/authorizedentries/selectorset.go new file mode 100644 index 0000000000..3f6898ce44 --- /dev/null +++ b/pkg/server/authorizedentries/selectorset.go @@ -0,0 +1,28 @@ +package authorizedentries + +import ( + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" +) + +type selectorSet map[Selector]struct{} + +func selectorSetFromProto(selectors []*types.Selector) selectorSet { + set := make(selectorSet, len(selectors)) + for _, selector := range selectors { + set[Selector{Type: selector.Type, Value: selector.Value}] = struct{}{} + } + return set +} + +// Returns true if sub is a subset of whole +func isSubset(sub, whole selectorSet) bool { + if len(sub) > len(whole) { + return false + } + for s := range sub { + if _, ok := whole[s]; !ok { + return false + } + } + return true +} diff --git a/pkg/server/authorizedentries/stringset.go b/pkg/server/authorizedentries/stringset.go new file mode 100644 index 0000000000..eac488b60d --- /dev/null +++ b/pkg/server/authorizedentries/stringset.go @@ -0,0 +1,28 @@ +package authorizedentries + +import "sync" + +var ( + stringSetPool = sync.Pool{ + New: func() interface{} { + return make(stringSet) + }, + } +) + +type stringSet map[string]struct{} + +func allocStringSet() stringSet { + return stringSetPool.Get().(stringSet) +} + +func freeStringSet(set stringSet) { + clearStringSet(set) + stringSetPool.Put(set) +} + +func clearStringSet(set stringSet) { + for k := range set { + delete(set, k) + } +} diff --git a/pkg/server/cache/entrycache/fullcache_test.go b/pkg/server/cache/entrycache/fullcache_test.go index aa6bcdba7f..4b5911c1b0 100644 --- a/pkg/server/cache/entrycache/fullcache_test.go +++ b/pkg/server/cache/entrycache/fullcache_test.go @@ -48,7 +48,8 @@ func TestCache(t *testing.T) { ds := fakedatastore.New(t) ctx := context.Background() - const rootID = "spiffe://example.org/root" + rootID := spiffeid.RequireFromString("spiffe://example.org/root") + const serverID = "spiffe://example.org/spire/server" const numEntries = 5 entryIDs := make([]string, numEntries) @@ -56,7 +57,7 @@ func TestCache(t *testing.T) { entryIDURI := url.URL{ Scheme: spiffeScheme, Host: trustDomain, - Path: strconv.Itoa(i), + Path: "/" + strconv.Itoa(i), } entryIDs[i] = entryIDURI.String() @@ -80,12 +81,12 @@ func TestCache(t *testing.T) { entriesToCreate := []*common.RegistrationEntry{ { - ParentId: rootID, + ParentId: rootID.String(), SpiffeId: entryIDs[0], Selectors: irrelevantSelectors, }, { - ParentId: rootID, + ParentId: rootID.String(), SpiffeId: entryIDs[1], Selectors: irrelevantSelectors, }, @@ -122,18 +123,12 @@ func TestCache(t *testing.T) { createAttestedNode(t, ds, node) setNodeSelectors(ctx, t, ds, entryIDs[1], a1, b2) - entriesPb, err := api.RegistrationEntriesToProto(entries) - require.NoError(t, err) - cache, err := BuildFromDataStore(context.Background(), ds) assert.NoError(t, err) - actual := cache.GetAuthorizedEntries(spiffeid.RequireFromString(rootID)) - - // The node alias (entry 3) is not expected - expected := entriesPb[:3] - expected = append(expected, entriesPb[4]) - spiretest.AssertProtoListEqual(t, expected, actual) + expected := entries[:3] + expected = append(expected, entries[4]) + assertAuthorizedEntries(t, cache, rootID, expected...) } func TestCacheReturnsClonedEntries(t *testing.T) { @@ -237,28 +232,9 @@ func TestFullCacheNodeAliasing(t *testing.T) { cache, err := BuildFromDataStore(context.Background(), ds) assert.NoError(t, err) - sortEntries := func(es []*types.Entry) { - sort.Slice(es, func(a, b int) bool { - return es[a].Id < es[b].Id - }) - } - - assertAuthorizedEntries := func(agentID spiffeid.ID, entries ...*common.RegistrationEntry) { - t.Helper() - expected, err := api.RegistrationEntriesToProto(entries) - require.NoError(t, err) - - authorizedEntries := cache.GetAuthorizedEntries(agentID) - - sortEntries(expected) - sortEntries(authorizedEntries) - - spiretest.AssertProtoListEqual(t, expected, authorizedEntries) - } - - assertAuthorizedEntries(agentIDs[0], workloadEntries[:2]...) - assertAuthorizedEntries(agentIDs[1], workloadEntries[1]) - assertAuthorizedEntries(agentIDs[2], workloadEntries[2]) + assertAuthorizedEntries(t, cache, agentIDs[0], workloadEntries[:2]...) + assertAuthorizedEntries(t, cache, agentIDs[1], workloadEntries[1]) + assertAuthorizedEntries(t, cache, agentIDs[2], workloadEntries[2]) } func TestFullCacheExcludesNodeSelectorMappedEntriesForExpiredAgents(t *testing.T) { @@ -521,7 +497,8 @@ func BenchmarkBuildSQL(b *testing.B) { ds := newSQLPlugin(ctx, b) for _, entry := range allEntries { - e, _ := api.ProtoToRegistrationEntry(context.Background(), td, entry) + e, err := api.ProtoToRegistrationEntry(context.Background(), td, entry) + require.NoError(b, err) createRegistrationEntry(ctx, b, ds, e) } @@ -535,7 +512,8 @@ func BenchmarkBuildSQL(b *testing.B) { } createAttestedNode(b, ds, node) - ss, _ := api.SelectorsFromProto(agent.Selectors) + ss, err := api.SelectorsFromProto(agent.Selectors) + require.NoError(b, err) setNodeSelectors(ctx, b, ds, agentIDStr, ss...) } @@ -744,7 +722,7 @@ func buildBenchmarkData() ([]*types.Entry, []Agent) { Id: fmt.Sprintf("workload%d", i), SpiffeId: &types.SPIFFEID{ TrustDomain: "domain.test", - Path: fmt.Sprintf("workload%d", i), + Path: fmt.Sprintf("/workload%d", i), }, ParentId: aliasID1, Selectors: []*types.Selector{ @@ -759,7 +737,7 @@ func buildBenchmarkData() ([]*types.Entry, []Agent) { Id: fmt.Sprintf("workload%d", i), SpiffeId: &types.SPIFFEID{ TrustDomain: "domain.test", - Path: fmt.Sprintf("workload%d", i), + Path: fmt.Sprintf("/workload%d", i), }, ParentId: aliasID2, Selectors: []*types.Selector{ @@ -816,3 +794,22 @@ func newSQLPlugin(ctx context.Context, tb testing.TB) datastore.DataStore { return p } + +func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, entries ...*common.RegistrationEntry) { + tb.Helper() + expected, err := api.RegistrationEntriesToProto(entries) + require.NoError(tb, err) + + authorizedEntries := cache.GetAuthorizedEntries(agentID) + + sortEntries(expected) + sortEntries(authorizedEntries) + + spiretest.AssertProtoListEqual(tb, expected, authorizedEntries) +} + +func sortEntries(es []*types.Entry) { + sort.Slice(es, func(a, b int) bool { + return es[a].Id < es[b].Id + }) +}