From 27b600117c8db30690c4905e40df2360967e0d4f Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 15 Nov 2023 17:44:42 -0300 Subject: [PATCH 01/17] add filtering options to count command Signed-off-by: FedeNQ --- cmd/spire-server/cli/entry/count.go | 91 ++++++++++++++++++- cmd/spire-server/cli/entry/util_posix_test.go | 14 +++ .../telemetry/server/datastore/wrapper.go | 4 +- .../server/datastore/wrapper_test.go | 2 +- pkg/server/api/debug/v1/service.go | 2 +- pkg/server/api/entry/v1/service.go | 80 +++++++++++++++- pkg/server/datastore/datastore.go | 11 ++- pkg/server/datastore/sqlstore/sqlstore.go | 32 ++++++- .../datastore/sqlstore/sqlstore_test.go | 4 +- test/fakes/fakedatastore/fakedatastore.go | 4 +- 10 files changed, 229 insertions(+), 15 deletions(-) diff --git a/cmd/spire-server/cli/entry/count.go b/cmd/spire-server/cli/entry/count.go index c095e76899..cf01ec8107 100644 --- a/cmd/spire-server/cli/entry/count.go +++ b/cmd/spire-server/cli/entry/count.go @@ -7,12 +7,36 @@ import ( "github.com/mitchellh/cli" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" + "google.golang.org/protobuf/types/known/wrapperspb" ) type countCommand struct { + // Type and value are delimited by a colon (:) + // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" + selectors StringsFlag + + // Workload parent spiffeID + parentID string + + // Workload spiffeID + spiffeID string + + // Entry hint + hint string + + // List of SPIFFE IDs of trust domains the registration entry is federated with + federatesWith StringsFlag + + // Match used when filtering by federates with + matchFederatesWithOn string + + // Match used when filtering by selectors + matchSelectorsOn string + printer cliprinter.Printer env *commoncli.Env } @@ -39,7 +63,64 @@ func (*countCommand) Synopsis() string { // Run counts attested entries func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient util.ServerClient) error { entryClient := serverClient.NewEntryClient() - countResponse, err := entryClient.CountEntries(ctx, &entryv1.CountEntriesRequest{}) + + filter := &entryv1.CountEntriesRequest_Filter{} + if c.parentID != "" { + id, err := idStringToProto(c.parentID) + if err != nil { + return fmt.Errorf("error parsing parent ID %q: %w", c.parentID, err) + } + filter.ByParentId = id + } + + if c.spiffeID != "" { + id, err := idStringToProto(c.spiffeID) + if err != nil { + return fmt.Errorf("error parsing SPIFFE ID %q: %w", c.spiffeID, err) + } + filter.BySpiffeId = id + } + + if len(c.selectors) != 0 { + matchSelectorBehavior, err := parseToSelectorMatch(c.matchSelectorsOn) + if err != nil { + return err + } + + selectors := make([]*types.Selector, len(c.selectors)) + for i, sel := range c.selectors { + selector, err := util.ParseSelector(sel) + if err != nil { + return fmt.Errorf("error parsing selectors: %w", err) + } + selectors[i] = selector + } + filter.BySelectors = &types.SelectorMatch{ + Selectors: selectors, + Match: matchSelectorBehavior, + } + } + + if len(c.federatesWith) > 0 { + matchFederatesWithBehavior, err := parseToFederatesWithMatch(c.matchFederatesWithOn) + if err != nil { + return err + } + + filter.ByFederatesWith = &types.FederatesWithMatch{ + TrustDomains: c.federatesWith, + Match: matchFederatesWithBehavior, + } + } + + if c.hint != "" { + filter.ByHint = wrapperspb.String(c.hint) + } + + countResponse, err := entryClient.CountEntries(ctx, &entryv1.CountEntriesRequest{ + Filter: filter, + }) + if err != nil { return err } @@ -48,6 +129,14 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } func (c *countCommand) AppendFlags(fs *flag.FlagSet) { + fs.StringVar(&c.parentID, "parentID", "", "The Parent ID of the records to count") + fs.StringVar(&c.spiffeID, "spiffeID", "", "The SPIFFE ID of the records to count") + fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.Var(&c.federatesWith, "federatesWith", "SPIFFE ID of a trust domain an entry is federate with. Can be used more than once") + fs.StringVar(&c.matchFederatesWithOn, "matchFederatesWithOn", "superset", "The match mode used when filtering by federates with. Options: exact, any, superset and subset") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") + fs.StringVar(&c.hint, "hint", "", "The Hint of the records to show (optional)") + cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, c.prettyPrintCount) } diff --git a/cmd/spire-server/cli/entry/util_posix_test.go b/cmd/spire-server/cli/entry/util_posix_test.go index d3825c1298..8e9f7be66a 100644 --- a/cmd/spire-server/cli/entry/util_posix_test.go +++ b/cmd/spire-server/cli/entry/util_posix_test.go @@ -112,9 +112,23 @@ const ( Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` countUsage = `Usage of entry count: + -federatesWith value + SPIFFE ID of a trust domain an entry is federate with. Can be used more than once + -hint string + The Hint of the records to show (optional) + -matchFederatesWithOn string + The match mode used when filtering by federates with. Options: exact, any, superset and subset (default "superset") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value Desired output format (pretty, json); default: pretty. + -parentID string + The Parent ID of the records to count + -selector value + A colon-delimited type:value selector. Can be used more than once -socketPath string Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") + -spiffeID string + The SPIFFE ID of the records to count ` ) diff --git a/pkg/common/telemetry/server/datastore/wrapper.go b/pkg/common/telemetry/server/datastore/wrapper.go index 645b8501a2..3b01392ee4 100644 --- a/pkg/common/telemetry/server/datastore/wrapper.go +++ b/pkg/common/telemetry/server/datastore/wrapper.go @@ -198,10 +198,10 @@ func (w metricsWrapper) CountBundles(ctx context.Context) (_ int32, err error) { return w.ds.CountBundles(ctx) } -func (w metricsWrapper) CountRegistrationEntries(ctx context.Context) (_ int32, err error) { +func (w metricsWrapper) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (_ int32, err error) { callCounter := StartCountRegistrationCall(w.m) defer callCounter.Done(&err) - return w.ds.CountRegistrationEntries(ctx) + return w.ds.CountRegistrationEntries(ctx, req) } func (w metricsWrapper) PruneAttestedNodesEvents(ctx context.Context, olderThan time.Duration) (err error) { diff --git a/pkg/common/telemetry/server/datastore/wrapper_test.go b/pkg/common/telemetry/server/datastore/wrapper_test.go index d085ee34d4..eadfd8c6e4 100644 --- a/pkg/common/telemetry/server/datastore/wrapper_test.go +++ b/pkg/common/telemetry/server/datastore/wrapper_test.go @@ -326,7 +326,7 @@ func (ds *fakeDataStore) CountBundles(context.Context) (int32, error) { return 0, ds.err } -func (ds *fakeDataStore) CountRegistrationEntries(context.Context) (int32, error) { +func (ds *fakeDataStore) CountRegistrationEntries(context.Context, *datastore.CountRegistrationEntriesRequest) (int32, error) { return 0, ds.err } diff --git a/pkg/server/api/debug/v1/service.go b/pkg/server/api/debug/v1/service.go index c1ab073b19..85d4d64e29 100644 --- a/pkg/server/api/debug/v1/service.go +++ b/pkg/server/api/debug/v1/service.go @@ -83,7 +83,7 @@ func (s *Service) GetInfo(ctx context.Context, _ *debugv1.GetInfoRequest) (*debu return nil, api.MakeErr(log, codes.Internal, "failed to count agents", err) } - entries, err := s.ds.CountRegistrationEntries(ctx) + entries, err := s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to count entries", err) } diff --git a/pkg/server/api/entry/v1/service.go b/pkg/server/api/entry/v1/service.go index 2ff683af8c..363b0ae8b9 100644 --- a/pkg/server/api/entry/v1/service.go +++ b/pkg/server/api/entry/v1/service.go @@ -61,8 +61,66 @@ func RegisterService(s grpc.ServiceRegistrar, service *Service) { } // CountEntries returns the total number of entries. -func (s *Service) CountEntries(ctx context.Context, _ *entryv1.CountEntriesRequest) (*entryv1.CountEntriesResponse, error) { - count, err := s.ds.CountRegistrationEntries(ctx) +func (s *Service) CountEntries(ctx context.Context, req *entryv1.CountEntriesRequest) (*entryv1.CountEntriesResponse, error) { + log := rpccontext.Logger(ctx) + countReq := &datastore.CountRegistrationEntriesRequest{} + + if req.Filter != nil { + rpccontext.AddRPCAuditFields(ctx, fieldsFromCountEntryFilter(ctx, s.td, req.Filter)) + if req.Filter.ByHint != nil { + countReq.ByHint = req.Filter.ByHint.GetValue() + } + + if req.Filter.ByParentId != nil { + parentID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, req.Filter.ByParentId) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed parent ID filter", err) + } + countReq.ByParentID = parentID.String() + } + + if req.Filter.BySpiffeId != nil { + spiffeID, err := api.TrustDomainWorkloadIDFromProto(ctx, s.td, req.Filter.BySpiffeId) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed SPIFFE ID filter", err) + } + countReq.BySpiffeID = spiffeID.String() + } + + if req.Filter.BySelectors != nil { + dsSelectors, err := api.SelectorsFromProto(req.Filter.BySelectors.Selectors) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed selectors filter", err) + } + if len(dsSelectors) == 0 { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed selectors filter", errors.New("empty selector set")) + } + countReq.BySelectors = &datastore.BySelectors{ + Match: datastore.MatchBehavior(req.Filter.BySelectors.Match), + Selectors: dsSelectors, + } + } + + if req.Filter.ByFederatesWith != nil { + trustDomains := make([]string, 0, len(req.Filter.ByFederatesWith.TrustDomains)) + for _, tdStr := range req.Filter.ByFederatesWith.TrustDomains { + td, err := spiffeid.TrustDomainFromString(tdStr) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed federates with filter", err) + } + trustDomains = append(trustDomains, td.IDString()) + } + if len(trustDomains) == 0 { + return nil, api.MakeErr(log, codes.InvalidArgument, "malformed federates with filter", errors.New("empty trust domain set")) + } + countReq.ByFederatesWith = &datastore.ByFederatesWith{ + Match: datastore.MatchBehavior(req.Filter.ByFederatesWith.Match), + TrustDomains: trustDomains, + } + } + } + + count, err := s.ds.CountRegistrationEntries(ctx, countReq) if err != nil { log := rpccontext.Logger(ctx) return nil, api.MakeErr(log, codes.Internal, "failed to count entries", err) @@ -728,6 +786,24 @@ func fieldsFromListEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fil return fields } +func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, filter *entryv1.CountEntriesRequest_Filter) logrus.Fields { + fields := logrus.Fields{} + + if filter.ByParentId != nil { + if parentID, err := api.TrustDomainMemberIDFromProto(ctx, td, filter.ByParentId); err == nil { + fields[telemetry.ParentID] = parentID.String() + } + } + + if filter.BySpiffeId != nil { + if id, err := api.TrustDomainWorkloadIDFromProto(ctx, td, filter.BySpiffeId); err == nil { + fields[telemetry.SPIFFEID] = id.String() + } + } + + return fields +} + func sortEntriesByID(entries []*types.Entry) { sort.Slice(entries, func(a, b int) bool { return entries[a].Id < entries[b].Id diff --git a/pkg/server/datastore/datastore.go b/pkg/server/datastore/datastore.go index db96a1976a..9a014ce1d8 100644 --- a/pkg/server/datastore/datastore.go +++ b/pkg/server/datastore/datastore.go @@ -31,7 +31,7 @@ type DataStore interface { RevokeJWTKey(ctx context.Context, trustDomainID string, authorityID string) (*common.PublicKey, error) // Entries - CountRegistrationEntries(context.Context) (int32, error) + CountRegistrationEntries(context.Context, *CountRegistrationEntriesRequest) (int32, error) CreateRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, error) CreateOrReturnRegistrationEntry(context.Context, *common.RegistrationEntry) (*common.RegistrationEntry, bool, error) DeleteRegistrationEntry(ctx context.Context, entryID string) (*common.RegistrationEntry, error) @@ -242,6 +242,15 @@ type ListFederationRelationshipsResponse struct { Pagination *Pagination } +type CountRegistrationEntriesRequest struct { + DataConsistency DataConsistency + ByParentID string + BySelectors *BySelectors + BySpiffeID string + ByFederatesWith *ByFederatesWith + ByHint string +} + type BundleEndpointType string const ( diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 465d73e745..b2c033781b 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -446,9 +446,25 @@ func (ds *Plugin) FetchRegistrationEntry(ctx context.Context, } // CountRegistrationEntries counts all registrations (pagination available) -func (ds *Plugin) CountRegistrationEntries(ctx context.Context) (count int32, err error) { +func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (count int32, err error) { + if hasFilters(req) { + var actDb = ds.db + if req.DataConsistency == datastore.TolerateStale && ds.roDb != nil { + actDb = ds.roDb + } + resp, err := listRegistrationEntries(ctx, actDb, ds.log, &datastore.ListRegistrationEntriesRequest{ + DataConsistency: req.DataConsistency, + ByParentID: req.ByParentID, + BySelectors: req.BySelectors, + BySpiffeID: req.BySpiffeID, + ByFederatesWith: req.ByFederatesWith, + ByHint: req.ByHint, + }) + return int32(len(resp.Entries)), err + } + if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { - count, err = countRegistrationEntries(tx) + count, err = countRegistrationEntries(tx, req) return err }); err != nil { return 0, err @@ -457,6 +473,16 @@ func (ds *Plugin) CountRegistrationEntries(ctx context.Context) (count int32, er return count, nil } +func hasFilters(req *datastore.CountRegistrationEntriesRequest) bool { + if req.ByParentID != "" || req.ByHint != "" || req.BySpiffeID != "" { + return true + } + if req.ByFederatesWith != nil || req.BySelectors != nil { + return true + } + return false +} + // ListRegistrationEntries lists all registrations (pagination available) func (ds *Plugin) ListRegistrationEntries(ctx context.Context, req *datastore.ListRegistrationEntriesRequest, @@ -2619,7 +2645,7 @@ ORDER BY selector_id, dns_name_id return query, []any{entryID}, nil } -func countRegistrationEntries(tx *gorm.DB) (int32, error) { +func countRegistrationEntries(tx *gorm.DB, req *datastore.CountRegistrationEntriesRequest) (int32, error) { var count int if err := tx.Model(&RegisteredEntry{}).Count(&count).Error; err != nil { return 0, sqlError.Wrap(err) diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index e7f04203eb..53708aaaa9 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -509,7 +509,7 @@ func (s *PluginSuite) TestCountAttestedNodes() { func (s *PluginSuite) TestCountRegistrationEntries() { // Count empty registration entries - count, err := s.ds.CountRegistrationEntries(ctx) + count, err := s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(0), count) @@ -531,7 +531,7 @@ func (s *PluginSuite) TestCountRegistrationEntries() { s.Require().NoError(err) // Count all - count, err = s.ds.CountRegistrationEntries(ctx) + count, err = s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(2), count) } diff --git a/test/fakes/fakedatastore/fakedatastore.go b/test/fakes/fakedatastore/fakedatastore.go index b0f8d89440..ef026bc960 100644 --- a/test/fakes/fakedatastore/fakedatastore.go +++ b/test/fakes/fakedatastore/fakedatastore.go @@ -238,11 +238,11 @@ func (s *DataStore) GetNodeSelectors(ctx context.Context, spiffeID string, dataC return selectors, err } -func (s *DataStore) CountRegistrationEntries(ctx context.Context) (int32, error) { +func (s *DataStore) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (int32, error) { if err := s.getNextError(); err != nil { return 0, err } - return s.ds.CountRegistrationEntries(ctx) + return s.ds.CountRegistrationEntries(ctx, req) } func (s *DataStore) CreateRegistrationEntry(ctx context.Context, entry *common.RegistrationEntry) (*common.RegistrationEntry, error) { From 2291905fd9de9f891dda75384626a64a36b9e0bd Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Fri, 17 Nov 2023 10:29:57 -0300 Subject: [PATCH 02/17] add more fields Signed-off-by: FedeNQ --- pkg/server/api/entry/v1/service.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pkg/server/api/entry/v1/service.go b/pkg/server/api/entry/v1/service.go index 363b0ae8b9..a57cf8d5d5 100644 --- a/pkg/server/api/entry/v1/service.go +++ b/pkg/server/api/entry/v1/service.go @@ -789,6 +789,10 @@ func fieldsFromListEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fil func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, filter *entryv1.CountEntriesRequest_Filter) logrus.Fields { fields := logrus.Fields{} + if filter.ByHint != nil { + fields[telemetry.Hint] = filter.ByHint.Value + } + if filter.ByParentId != nil { if parentID, err := api.TrustDomainMemberIDFromProto(ctx, td, filter.ByParentId); err == nil { fields[telemetry.ParentID] = parentID.String() @@ -801,6 +805,16 @@ func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fi } } + if filter.BySelectors != nil { + fields[telemetry.BySelectorMatch] = filter.BySelectors.Match.String() + fields[telemetry.BySelectors] = api.SelectorFieldFromProto(filter.BySelectors.Selectors) + } + + if filter.ByFederatesWith != nil { + fields[telemetry.FederatesWithMatch] = filter.ByFederatesWith.Match.String() + fields[telemetry.FederatesWith] = strings.Join(filter.ByFederatesWith.TrustDomains, ",") + } + return fields } From c834729b6c880a2da7728a8b4a9c7631406fb3e6 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 6 Dec 2023 13:47:34 -0300 Subject: [PATCH 03/17] Add filtering to entry & agent count/show/list commands Signed-off-by: FedeNQ --- .../cli/agent/agent_posix_test.go | 20 ++++ cmd/spire-server/cli/agent/agent_test.go | 75 ++++++++++++ cmd/spire-server/cli/agent/count.go | 84 +++++++++++++- cmd/spire-server/cli/agent/list.go | 56 ++++++++- cmd/spire-server/cli/agent/util.go | 20 ++++ cmd/spire-server/cli/entry/count.go | 6 + cmd/spire-server/cli/entry/show.go | 2 + cmd/spire-server/cli/entry/show_test.go | 24 +++- cmd/spire-server/cli/entry/util_posix_test.go | 2 + doc/spire_server.md | 26 ++++- pkg/common/cli/flags.go | 22 ++++ .../telemetry/server/datastore/wrapper.go | 4 +- .../server/datastore/wrapper_test.go | 2 +- pkg/server/api/agent/v1/service.go | 75 ++++++++++-- pkg/server/api/debug/v1/service.go | 3 +- pkg/server/api/entry/v1/service.go | 16 +++ pkg/server/datastore/datastore.go | 13 ++- pkg/server/datastore/sqlstore/sqlstore.go | 109 +++++++++++++++--- .../datastore/sqlstore/sqlstore_test.go | 4 +- test/fakes/fakedatastore/fakedatastore.go | 4 +- 20 files changed, 513 insertions(+), 54 deletions(-) create mode 100644 cmd/spire-server/cli/agent/util.go diff --git a/cmd/spire-server/cli/agent/agent_posix_test.go b/cmd/spire-server/cli/agent/agent_posix_test.go index a1e6568bc6..8697fa6607 100644 --- a/cmd/spire-server/cli/agent/agent_posix_test.go +++ b/cmd/spire-server/cli/agent/agent_posix_test.go @@ -14,6 +14,14 @@ var ( Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` listUsage = `Usage of agent list: + -attestationType string + The SPIFFE ID of the nodes to list + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + A date that indicates the time it should expired before, (format: YYYY-MM-DD) -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value @@ -40,8 +48,20 @@ var ( The SPIFFE ID of the agent to evict (agent identity) ` countUsage = `Usage of agent count: + -attestationType string + The SPIFFE ID of the nodes to count + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + A date that indicates the time it should expired before, (format: YYYY-MM-DD) + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value Desired output format (pretty, json); default: pretty. + -selector value + A colon-delimited type:value selector. Can be used more than once -socketPath string Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` diff --git a/cmd/spire-server/cli/agent/agent_test.go b/cmd/spire-server/cli/agent/agent_test.go index bf63ba0fd3..2eab7544d3 100644 --- a/cmd/spire-server/cli/agent/agent_test.go +++ b/cmd/spire-server/cli/agent/agent_test.go @@ -225,6 +225,24 @@ func TestCount(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, + { + name: "Count by expiresBefore: date is too long", + args: []string{"-expiresBefore", "2001-01-011"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: date is too long\n", + }, + { + name: "Count by expiresBefore: date is too long", + args: []string{"-expiresBefore", "2001-01-0"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: date is too short\n", + }, + { + name: "Count by expiresBefore: month out of range", + args: []string{"-expiresBefore", "2001-13-05"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: parsing time \"2001-13-05\": month out of range\n", + }, } { for _, format := range availableFormats { t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { @@ -389,6 +407,45 @@ func TestList(t *testing.T) { expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, }, + { + name: "by expiresBefore", + args: []string{"-expiresBefore", "2000-01-01"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByExpiresBefore: "2000-01-01", + }, + PageSize: 1000, + }, + existentAgents: testAgents, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, + }, + { + name: "by banned", + args: []string{"-banned", "true"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByBanned: wrapperspb.Bool(true), + }, + PageSize: 1000, + }, + existentAgents: testAgentsWithBanned, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/banned", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/banned"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":true,"can_reattest":false}],"next_page_token":""}`, + }, + { + name: "by canReattest", + args: []string{"-canReattest", "true"}, + expectReq: &agentv1.ListAgentsRequest{ + Filter: &agentv1.ListAgentsRequest_Filter{ + ByCanReattest: wrapperspb.Bool(true), + }, + PageSize: 1000, + }, + existentAgents: testAgents, + expectedStdoutPretty: "Found 1 attested agent:\n\nSPIFFE ID : spiffe://example.org/spire/agent/agent1", + expectedStdoutJSON: `{"agents":[{"id":{"trust_domain":"example.org","path":"/spire/agent/agent1"},"attestation_type":"","x509svid_serial_number":"","x509svid_expires_at":"0","selectors":[],"banned":false,"can_reattest":true}],"next_page_token":""}`, + }, { name: "List by selectors: Invalid matcher", args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "NO-MATCHER"}, @@ -407,6 +464,24 @@ func TestList(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, + { + name: "List by expiresBefore: date is too long", + args: []string{"-expiresBefore", "2001-01-011"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: date is too long\n", + }, + { + name: "List by expiresBefore: date is too long", + args: []string{"-expiresBefore", "2001-01-0"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: date is too short\n", + }, + { + name: "List by expiresBefore: month out of range", + args: []string{"-expiresBefore", "2001-13-05"}, + expectedReturnCode: 1, + expectedStderr: "Error: date is not valid: parsing time \"2001-13-05\": month out of range\n", + }, } { for _, format := range availableFormats { t.Run(fmt.Sprintf("%s using %s format", tt.name, format), func(t *testing.T) { diff --git a/cmd/spire-server/cli/agent/count.go b/cmd/spire-server/cli/agent/count.go index 4b46f77f0d..a3e3649c45 100644 --- a/cmd/spire-server/cli/agent/count.go +++ b/cmd/spire-server/cli/agent/count.go @@ -8,13 +8,35 @@ import ( "github.com/mitchellh/cli" agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" + "google.golang.org/protobuf/types/known/wrapperspb" ) type countCommand struct { - env *commoncli.Env + // Type and value are delimited by a colon (:) + // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" + selectors commoncli.StringsFlag + + // Match used when filtering by selectors + matchSelectorsOn string + + // Filters agents to those that are banned. + banned commoncli.BoolFlag + + // Filters agents by those expires before. + expiresBefore string + + // Filters agents to those matching the attestation type. + attestationType string + + // Filters agents that can re-attest. + canReattest commoncli.BoolFlag + + env *commoncli.Env + printer cliprinter.Printer } @@ -39,8 +61,60 @@ func (*countCommand) Synopsis() string { // Run counts attested agents func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient util.ServerClient) error { + filter := &agentv1.CountAgentsRequest_Filter{} + if len(c.selectors) > 0 { + matchBehavior, err := parseToSelectorMatch(c.matchSelectorsOn) + if err != nil { + return err + } + + selectors := make([]*types.Selector, len(c.selectors)) + for i, sel := range c.selectors { + selector, err := util.ParseSelector(sel) + if err != nil { + return fmt.Errorf("error parsing selector %q: %w", sel, err) + } + selectors[i] = selector + } + filter.BySelectorMatch = &types.SelectorMatch{ + Selectors: selectors, + Match: matchBehavior, + } + } + + if c.expiresBefore != "" { + err := validate(c.expiresBefore) + if err != nil { + return fmt.Errorf("date is not valid: %w", err) + } + filter.ByExpiresBefore = c.expiresBefore + } + + if c.attestationType != "" { + filter.ByAttestationType = c.attestationType + } + + // 0: all, 1: can't reattest, 2: can reattest + if c.canReattest == 1 { + filter.ByCanReattest = wrapperspb.Bool(false) + } + if c.canReattest == 2 { + filter.ByCanReattest = wrapperspb.Bool(true) + } + + // 0: all, 1: no-banned, 2: banned + if c.banned == 1 { + filter.ByBanned = wrapperspb.Bool(false) + } + if c.banned == 2 { + filter.ByBanned = wrapperspb.Bool(true) + } + agentClient := serverClient.NewAgentClient() - countResponse, err := agentClient.CountAgents(ctx, &agentv1.CountAgentsRequest{}) + + countResponse, err := agentClient.CountAgents(ctx, &agentv1.CountAgentsRequest{ + Filter: filter, + }) if err != nil { return err } @@ -49,6 +123,12 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } func (c *countCommand) AppendFlags(fs *flag.FlagSet) { + fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.StringVar(&c.attestationType, "attestationType", "", "The SPIFFE ID of the nodes to count") + fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") + fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "A date that indicates the time it should expired before, (format: YYYY-MM-DD)") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintCount) } diff --git a/cmd/spire-server/cli/agent/list.go b/cmd/spire-server/cli/agent/list.go index 80b66989ec..d891a69077 100644 --- a/cmd/spire-server/cli/agent/list.go +++ b/cmd/spire-server/cli/agent/list.go @@ -14,16 +14,32 @@ import ( commoncli "github.com/spiffe/spire/pkg/common/cli" "github.com/spiffe/spire/pkg/common/cliprinter" "github.com/spiffe/spire/pkg/common/idutil" + "google.golang.org/protobuf/types/known/wrapperspb" ) type listCommand struct { - env *commoncli.Env // Type and value are delimited by a colon (:) // ex. "unix:uid:1000" or "spiffe_id:spiffe://example.org/foo" selectors commoncli.StringsFlag - // Match used when filtering agents by selectors + + // Match used when filtering by selectors matchSelectorsOn string - printer cliprinter.Printer + + // Filters agents to those that are banned. + banned commoncli.BoolFlag + + // Filters agents by those expires before. + expiresBefore string + + // Filters agents to those matching the attestation type. + attestationType string + + // Filters agents that can re-attest. + canReattest commoncli.BoolFlag + + env *commoncli.Env + + printer cliprinter.Printer } // NewListCommand creates a new "list" subcommand for "agent" command. @@ -68,6 +84,34 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut } } + if c.expiresBefore != "" { + err := validate(c.expiresBefore) + if err != nil { + return fmt.Errorf("date is not valid: %w", err) + } + filter.ByExpiresBefore = c.expiresBefore + } + + if c.attestationType != "" { + filter.ByAttestationType = c.attestationType + } + + // 0: all, 1: can't reattest, 2: can reattest + if c.canReattest == 1 { + filter.ByCanReattest = wrapperspb.Bool(false) + } + if c.canReattest == 2 { + filter.ByCanReattest = wrapperspb.Bool(true) + } + + // 0: all, 1: no-banned, 2: banned + if c.banned == 1 { + filter.ByBanned = wrapperspb.Bool(false) + } + if c.banned == 2 { + filter.ByBanned = wrapperspb.Bool(true) + } + agentClient := serverClient.NewAgentClient() pageToken := "" @@ -91,8 +135,12 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut } func (c *listCommand) AppendFlags(fs *flag.FlagSet) { - fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") + fs.StringVar(&c.attestationType, "attestationType", "", "The SPIFFE ID of the nodes to list") + fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") + fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "A date that indicates the time it should expired before, (format: YYYY-MM-DD)") + fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintAgents) } diff --git a/cmd/spire-server/cli/agent/util.go b/cmd/spire-server/cli/agent/util.go new file mode 100644 index 0000000000..e26ceb8b1c --- /dev/null +++ b/cmd/spire-server/cli/agent/util.go @@ -0,0 +1,20 @@ +package agent + +import ( + "fmt" + "time" +) + +func validate(s string) error { + if len(s) < 10 { + return fmt.Errorf("date is too short") + } + if len(s) > 10 { + return fmt.Errorf("date is too long") + } + _, err := time.Parse("2006-01-02", s) + if err != nil { + return err + } + return nil +} diff --git a/cmd/spire-server/cli/entry/count.go b/cmd/spire-server/cli/entry/count.go index cf01ec8107..06a5d748f5 100644 --- a/cmd/spire-server/cli/entry/count.go +++ b/cmd/spire-server/cli/entry/count.go @@ -31,6 +31,9 @@ type countCommand struct { // List of SPIFFE IDs of trust domains the registration entry is federated with federatesWith StringsFlag + // Whether or not the entry is for a downstream SPIRE server + downstream bool + // Match used when filtering by federates with matchFederatesWithOn string @@ -101,6 +104,8 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } } + filter.ByDownstream = wrapperspb.Bool(c.downstream) + if len(c.federatesWith) > 0 { matchFederatesWithBehavior, err := parseToFederatesWithMatch(c.matchFederatesWithOn) if err != nil { @@ -131,6 +136,7 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u func (c *countCommand) AppendFlags(fs *flag.FlagSet) { fs.StringVar(&c.parentID, "parentID", "", "The Parent ID of the records to count") fs.StringVar(&c.spiffeID, "spiffeID", "", "The SPIFFE ID of the records to count") + fs.BoolVar(&c.downstream, "downstream", false, "A boolean value that, when set, indicates that the entry describes a downstream SPIRE server") fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") fs.Var(&c.federatesWith, "federatesWith", "SPIFFE ID of a trust domain an entry is federate with. Can be used more than once") fs.StringVar(&c.matchFederatesWithOn, "matchFederatesWithOn", "superset", "The match mode used when filtering by federates with. Options: exact, any, superset and subset") diff --git a/cmd/spire-server/cli/entry/show.go b/cmd/spire-server/cli/entry/show.go index 5cbfdf10e8..94f5503854 100644 --- a/cmd/spire-server/cli/entry/show.go +++ b/cmd/spire-server/cli/entry/show.go @@ -175,6 +175,8 @@ func (c *showCommand) fetchEntries(ctx context.Context, client entryv1.EntryClie filter.ByHint = wrapperspb.String(c.hint) } + filter.ByDownstream = wrapperspb.Bool(c.downstream) + pageToken := "" for { diff --git a/cmd/spire-server/cli/entry/show_test.go b/cmd/spire-server/cli/entry/show_test.go index 418889cb3f..d2fee9b325 100644 --- a/cmd/spire-server/cli/entry/show_test.go +++ b/cmd/spire-server/cli/entry/show_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestShowHelp(t *testing.T) { @@ -61,7 +62,9 @@ func TestShow(t *testing.T) { name: "List all entries (empty filter)", expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, - Filter: &entryv1.ListEntriesRequest_Filter{}, + Filter: &entryv1.ListEntriesRequest_Filter{ + ByDownstream: wrapperspb.Bool(false), + }, }, fakeListResp: fakeRespAll, expOutPretty: fmt.Sprintf("Found 4 entries\n%s%s%s%s", @@ -103,7 +106,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFather, @@ -124,7 +128,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespDaughter, @@ -152,6 +157,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -173,6 +179,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_EXACT, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -194,6 +201,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -215,6 +223,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_SUBSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -236,6 +245,7 @@ func TestShow(t *testing.T) { }, Match: types.SelectorMatch_MATCH_ANY, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespFatherDaughter, @@ -260,7 +270,8 @@ func TestShow(t *testing.T) { expListReq: &entryv1.ListEntriesRequest{ PageSize: listEntriesRequestPageSize, Filter: &entryv1.ListEntriesRequest_Filter{ - BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), }, }, serverErr: status.Error(codes.Internal, "internal server error"), @@ -276,6 +287,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -294,6 +306,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_EXACT, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -312,6 +325,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_ANY, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -330,6 +344,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUPERSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, @@ -348,6 +363,7 @@ func TestShow(t *testing.T) { TrustDomains: []string{"spiffe://domain.test"}, Match: types.FederatesWithMatch_MATCH_SUBSET, }, + ByDownstream: wrapperspb.Bool(false), }, }, fakeListResp: fakeRespMotherDaughter, diff --git a/cmd/spire-server/cli/entry/util_posix_test.go b/cmd/spire-server/cli/entry/util_posix_test.go index 8e9f7be66a..d55647482c 100644 --- a/cmd/spire-server/cli/entry/util_posix_test.go +++ b/cmd/spire-server/cli/entry/util_posix_test.go @@ -112,6 +112,8 @@ const ( Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` countUsage = `Usage of entry count: + -downstream + A boolean value that, when set, indicates that the entry describes a downstream SPIRE server -federatesWith value SPIFFE ID of a trust domain an entry is federate with. Can be used more than once -hint string diff --git a/doc/spire_server.md b/doc/spire_server.md index 180d75ffb9..3190d3eb86 100644 --- a/doc/spire_server.md +++ b/doc/spire_server.md @@ -352,9 +352,15 @@ Updates registration entries. Displays the total number of registration entries. -| Command | Action | Default | -|:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| Command | Action | Default | +|:-----------------|:-------------------------------------------------------------------------------------------------|:-----------------------------------| +| `-downstream` | A boolean value that, when set, indicates that the entry describes a downstream SPIRE server | | +| `-federatesWith` | SPIFFE ID of a trust domain an entry is federate with. Can be used more than once | | +| `-parentID` | The Parent ID of the records to count. | | +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| `-spiffeID` | The SPIFFE ID of the records to count. | | + ### `spire-server entry delete` @@ -508,7 +514,11 @@ Displays the total number of attested nodes. | Command | Action | Default | |:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | +| `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | +| `-expiresBefore` | A date that indicates the time it should expired before. (format: YYYY-MM-DD) | | +| `-spiffeID` | The SPIFFE ID of the records to count. ### `spire-server agent evict` @@ -525,7 +535,13 @@ Displays attested nodes. | Command | Action | Default | |:--------------|:------------------------------------|:-----------------------------------| -| `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | +| Command | Action | Default | +|:--------------|:------------------------------------|:-----------------------------------| +| `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | +| `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | +| `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | +| `-expiresBefore` | A date that indicates the time it should expired before. (format: YYYY-MM-DD)| | +| `-attestationType` | Filters agents to those matching the attestation type. | | ### `spire-server agent show` diff --git a/pkg/common/cli/flags.go b/pkg/common/cli/flags.go index 0b336e7f46..13a1132791 100644 --- a/pkg/common/cli/flags.go +++ b/pkg/common/cli/flags.go @@ -45,3 +45,25 @@ func (s *StringsFlag) Set(val string) error { *s = append(*s, val) return nil } + +// BoolFlag is used to define 3 possible states: true, false, or all +// take care that false=1, and true=2 +type BoolFlag int + +func (b *BoolFlag) String() string { + return "" +} + +func (b *BoolFlag) Set(val string) error { + if val == "false" { + *b = 1 + return nil + } + if val == "true" { + *b = 2 + return nil + } + // if the value received isn't true or false, it will set the default value + *b = 0 + return nil +} diff --git a/pkg/common/telemetry/server/datastore/wrapper.go b/pkg/common/telemetry/server/datastore/wrapper.go index 3b01392ee4..d97ce5bf88 100644 --- a/pkg/common/telemetry/server/datastore/wrapper.go +++ b/pkg/common/telemetry/server/datastore/wrapper.go @@ -186,10 +186,10 @@ func (w metricsWrapper) ListRegistrationEntriesEvents(ctx context.Context, req * return w.ds.ListRegistrationEntriesEvents(ctx, req) } -func (w metricsWrapper) CountAttestedNodes(ctx context.Context) (_ int32, err error) { +func (w metricsWrapper) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (_ int32, err error) { callCounter := StartCountNodeCall(w.m) defer callCounter.Done(&err) - return w.ds.CountAttestedNodes(ctx) + return w.ds.CountAttestedNodes(ctx, req) } func (w metricsWrapper) CountBundles(ctx context.Context) (_ int32, err error) { diff --git a/pkg/common/telemetry/server/datastore/wrapper_test.go b/pkg/common/telemetry/server/datastore/wrapper_test.go index eadfd8c6e4..8c7d7fcc24 100644 --- a/pkg/common/telemetry/server/datastore/wrapper_test.go +++ b/pkg/common/telemetry/server/datastore/wrapper_test.go @@ -318,7 +318,7 @@ func (ds *fakeDataStore) AppendBundle(context.Context, *common.Bundle) (*common. return &common.Bundle{}, ds.err } -func (ds *fakeDataStore) CountAttestedNodes(context.Context) (int32, error) { +func (ds *fakeDataStore) CountAttestedNodes(context.Context, *datastore.CountAttestedNodesRequest) (int32, error) { return 0, ds.err } diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index 509ab874cd..572dce74f5 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -71,8 +71,42 @@ func RegisterService(s grpc.ServiceRegistrar, service *Service) { } // CountAgents returns the total number of agents. -func (s *Service) CountAgents(ctx context.Context, _ *agentv1.CountAgentsRequest) (*agentv1.CountAgentsResponse, error) { - count, err := s.ds.CountAttestedNodes(ctx) +func (s *Service) CountAgents(ctx context.Context, req *agentv1.CountAgentsRequest) (*agentv1.CountAgentsResponse, error) { + + log := rpccontext.Logger(ctx) + + countReq := &datastore.CountAttestedNodesRequest{} + + // Parse proto filter into datastore request + if req.Filter != nil { + filter := req.Filter + rpccontext.AddRPCAuditFields(ctx, fieldsFromCountAgentsRequest(filter)) + + if filter.ByBanned != nil { + countReq.ByBanned = &req.Filter.ByBanned.Value + } + if filter.ByCanReattest != nil { + countReq.ByCanReattest = &req.Filter.ByCanReattest.Value + } + + countReq.ByAttestationType = filter.ByAttestationType + + // err is verified previously + countReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) + + if filter.BySelectorMatch != nil { + selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) + if err != nil { + return nil, api.MakeErr(log, codes.InvalidArgument, "failed to parse selectors", err) + } + countReq.BySelectorMatch = &datastore.BySelectors{ + Match: datastore.MatchBehavior(filter.BySelectorMatch.Match), + Selectors: selectors, + } + } + } + + count, err := s.ds.CountAttestedNodes(ctx, countReq) if err != nil { log := rpccontext.Logger(ctx) return nil, api.MakeErr(log, codes.Internal, "failed to count agents", err) @@ -94,20 +128,18 @@ func (s *Service) ListAgents(ctx context.Context, req *agentv1.ListAgentsRequest // Parse proto filter into datastore request if req.Filter != nil { filter := req.Filter - rpccontext.AddRPCAuditFields(ctx, fieldsFromFilterRequest(filter)) + rpccontext.AddRPCAuditFields(ctx, fieldsFromListAgentsRequest(filter)) - var byBanned *bool if filter.ByBanned != nil { - byBanned = &filter.ByBanned.Value + listReq.ByBanned = &req.Filter.ByBanned.Value } - var byCanReattest *bool if filter.ByCanReattest != nil { - byCanReattest = &filter.ByCanReattest.Value + listReq.ByCanReattest = &req.Filter.ByCanReattest.Value } listReq.ByAttestationType = filter.ByAttestationType - listReq.ByBanned = byBanned - listReq.ByCanReattest = byCanReattest + // err is verified previously + listReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) if filter.BySelectorMatch != nil { selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) @@ -687,7 +719,30 @@ func getAttestAgentResponse(spiffeID spiffeid.ID, certificates []*x509.Certifica } } -func fieldsFromFilterRequest(filter *agentv1.ListAgentsRequest_Filter) logrus.Fields { +func fieldsFromListAgentsRequest(filter *agentv1.ListAgentsRequest_Filter) logrus.Fields { + fields := logrus.Fields{} + + if filter.ByAttestationType != "" { + fields[telemetry.NodeAttestorType] = filter.ByAttestationType + } + + if filter.ByBanned != nil { + fields[telemetry.ByBanned] = filter.ByBanned.Value + } + + if filter.ByCanReattest != nil { + fields[telemetry.ByCanReattest] = filter.ByCanReattest.Value + } + + if filter.BySelectorMatch != nil { + fields[telemetry.BySelectorMatch] = filter.BySelectorMatch.Match.String() + fields[telemetry.BySelectors] = api.SelectorFieldFromProto(filter.BySelectorMatch.Selectors) + } + + return fields +} + +func fieldsFromCountAgentsRequest(filter *agentv1.CountAgentsRequest_Filter) logrus.Fields { fields := logrus.Fields{} if filter.ByAttestationType != "" { diff --git a/pkg/server/api/debug/v1/service.go b/pkg/server/api/debug/v1/service.go index 85d4d64e29..163be708d3 100644 --- a/pkg/server/api/debug/v1/service.go +++ b/pkg/server/api/debug/v1/service.go @@ -78,11 +78,10 @@ func (s *Service) GetInfo(ctx context.Context, _ *debugv1.GetInfoRequest) (*debu // Update cache when expired or does not exists if s.getInfoResp.ts.IsZero() || s.clock.Now().Sub(s.getInfoResp.ts) >= cacheExpiry { - nodes, err := s.ds.CountAttestedNodes(ctx) + nodes, err := s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to count agents", err) } - entries, err := s.ds.CountRegistrationEntries(ctx, &datastore.CountRegistrationEntriesRequest{}) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to count entries", err) diff --git a/pkg/server/api/entry/v1/service.go b/pkg/server/api/entry/v1/service.go index a57cf8d5d5..180d66c241 100644 --- a/pkg/server/api/entry/v1/service.go +++ b/pkg/server/api/entry/v1/service.go @@ -118,6 +118,10 @@ func (s *Service) CountEntries(ctx context.Context, req *entryv1.CountEntriesReq TrustDomains: trustDomains, } } + + if req.Filter.ByDownstream != nil { + countReq.ByDownstream = &req.Filter.ByDownstream.Value + } } count, err := s.ds.CountRegistrationEntries(ctx, countReq) @@ -197,6 +201,10 @@ func (s *Service) ListEntries(ctx context.Context, req *entryv1.ListEntriesReque TrustDomains: trustDomains, } } + + if req.Filter.ByDownstream != nil { + listReq.ByDownstream = &req.Filter.ByDownstream.Value + } } dsResp, err := s.ds.ListRegistrationEntries(ctx, listReq) @@ -783,6 +791,10 @@ func fieldsFromListEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fil fields[telemetry.FederatesWith] = strings.Join(filter.ByFederatesWith.TrustDomains, ",") } + if filter.ByDownstream != nil { + fields[telemetry.Downstream] = &filter.ByDownstream.Value + } + return fields } @@ -815,6 +827,10 @@ func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fi fields[telemetry.FederatesWith] = strings.Join(filter.ByFederatesWith.TrustDomains, ",") } + if filter.ByDownstream != nil { + fields[telemetry.Downstream] = &filter.ByDownstream.Value + } + return fields } diff --git a/pkg/server/datastore/datastore.go b/pkg/server/datastore/datastore.go index 9a014ce1d8..4b81e85cf5 100644 --- a/pkg/server/datastore/datastore.go +++ b/pkg/server/datastore/datastore.go @@ -46,7 +46,7 @@ type DataStore interface { GetLatestRegistrationEntryEventID(ctx context.Context) (uint, error) // Nodes - CountAttestedNodes(context.Context) (int32, error) + CountAttestedNodes(context.Context, *CountAttestedNodesRequest) (int32, error) CreateAttestedNode(context.Context, *common.AttestedNode) (*common.AttestedNode, error) DeleteAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) FetchAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) @@ -206,6 +206,7 @@ type ListRegistrationEntriesRequest struct { Pagination *Pagination ByFederatesWith *ByFederatesWith ByHint string + ByDownstream *bool } type CAJournal struct { @@ -242,6 +243,15 @@ type ListFederationRelationshipsResponse struct { Pagination *Pagination } +type CountAttestedNodesRequest struct { + ByAttestationType string + ByBanned *bool + ByExpiresBefore time.Time + BySelectorMatch *BySelectors + FetchSelectors bool + ByCanReattest *bool +} + type CountRegistrationEntriesRequest struct { DataConsistency DataConsistency ByParentID string @@ -249,6 +259,7 @@ type CountRegistrationEntriesRequest struct { BySpiffeID string ByFederatesWith *ByFederatesWith ByHint string + ByDownstream *bool } type BundleEndpointType string diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index b2c033781b..ea9d17f347 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -282,7 +282,18 @@ func (ds *Plugin) FetchAttestedNode(ctx context.Context, spiffeID string) (attes } // CountAttestedNodes counts all attested nodes -func (ds *Plugin) CountAttestedNodes(ctx context.Context) (count int32, err error) { +func (ds *Plugin) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (count int32, err error) { + if countAttestedNodesHasFilters(req) { + resp, err := listAttestedNodes(ctx, ds.db, ds.log, &datastore.ListAttestedNodesRequest{ + ByAttestationType: req.ByAttestationType, + ByBanned: req.ByBanned, + ByExpiresBefore: req.ByExpiresBefore, + BySelectorMatch: req.BySelectorMatch, + FetchSelectors: req.FetchSelectors, + ByCanReattest: req.ByCanReattest, + }) + return int32(len(resp.Nodes)), err + } if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { count, err = countAttestedNodes(tx) return err @@ -447,11 +458,12 @@ func (ds *Plugin) FetchRegistrationEntry(ctx context.Context, // CountRegistrationEntries counts all registrations (pagination available) func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (count int32, err error) { - if hasFilters(req) { + if countRegistrationEntriesHasFilters(req) { var actDb = ds.db if req.DataConsistency == datastore.TolerateStale && ds.roDb != nil { actDb = ds.roDb } + resp, err := listRegistrationEntries(ctx, actDb, ds.log, &datastore.ListRegistrationEntriesRequest{ DataConsistency: req.DataConsistency, ByParentID: req.ByParentID, @@ -459,6 +471,7 @@ func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.C BySpiffeID: req.BySpiffeID, ByFederatesWith: req.ByFederatesWith, ByHint: req.ByHint, + ByDownstream: req.ByDownstream, }) return int32(len(resp.Entries)), err } @@ -473,16 +486,6 @@ func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.C return count, nil } -func hasFilters(req *datastore.CountRegistrationEntriesRequest) bool { - if req.ByParentID != "" || req.ByHint != "" || req.BySpiffeID != "" { - return true - } - if req.ByFederatesWith != nil || req.BySelectors != nil { - return true - } - return false -} - // ListRegistrationEntries lists all registrations (pagination available) func (ds *Plugin) ListRegistrationEntries(ctx context.Context, req *datastore.ListRegistrationEntriesRequest, @@ -1541,6 +1544,16 @@ func countAttestedNodes(tx *gorm.DB) (int32, error) { return int32(count), nil } +func countAttestedNodesHasFilters(req *datastore.CountAttestedNodesRequest) bool { + if req.ByAttestationType != "" || req.ByBanned != nil || !req.ByExpiresBefore.IsZero() { + return true + } + if req.BySelectorMatch != nil || !req.FetchSelectors || req.ByCanReattest != nil { + return true + } + return false +} + func listAttestedNodes(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.ListAttestedNodesRequest) (*datastore.ListAttestedNodesResponse, error) { if req.Pagination != nil && req.Pagination.PageSize == 0 { return nil, status.Error(codes.InvalidArgument, "cannot paginate with pagesize = 0") @@ -1731,7 +1744,6 @@ func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAt resp.Pagination.Token = strconv.FormatUint(lastEID, 10) } } - return resp, nil } @@ -1789,7 +1801,6 @@ func buildListAttestedNodesQueryCTE(req *datastore.ListAttestedNodesRequest, dbT builder.WriteString("\t\tAND data_type = ?\n") args = append(args, req.ByAttestationType) } - // Filter by banned, an Attestation Node is banned when serial number is empty. // This filter allows 3 outputs: // - nil: returns all @@ -1802,8 +1813,11 @@ func buildListAttestedNodesQueryCTE(req *datastore.ListAttestedNodesRequest, dbT builder.WriteString("\t\tAND serial_number <> ''\n") } } - - // Filter by CanReattest. This is similar to ByBanned + // Filter by canReattest, + // This filter allows 3 outputs: + // - nil: returns all + // - true: returns nodes with canReattest=true + // - false: returns nodes with canReattest=false if req.ByCanReattest != nil { if *req.ByCanReattest { builder.WriteString("\t\tAND can_reattest = true\n") @@ -1951,7 +1965,6 @@ SELECT } builder.WriteString("\n) ORDER BY id ASC\n") - return builder.String(), args, nil } @@ -2654,6 +2667,16 @@ func countRegistrationEntries(tx *gorm.DB, req *datastore.CountRegistrationEntri return int32(count), nil } +func countRegistrationEntriesHasFilters(req *datastore.CountRegistrationEntriesRequest) bool { + if req.ByParentID != "" || req.ByHint != "" || req.BySpiffeID != "" { + return true + } + if req.ByFederatesWith != nil || req.BySelectors != nil || req.ByDownstream != nil { + return true + } + return false +} + func listRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) { if req.Pagination != nil && req.Pagination.PageSize == 0 { return nil, status.Error(codes.InvalidArgument, "cannot paginate with pagesize = 0") @@ -2810,6 +2833,7 @@ func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseT } func buildListRegistrationEntriesQuery(dbType string, supportsCTE bool, req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { + //TODO: check how to add downstream to all querys switch dbType { case SQLite: // The SQLite3 queries unconditionally leverage CTE since the @@ -2831,8 +2855,12 @@ func buildListRegistrationEntriesQuery(dbType string, supportsCTE bool, req *dat func buildListRegistrationEntriesQuerySQLite3(req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { builder := new(strings.Builder) - filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, SQLite, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -2864,9 +2892,17 @@ SELECT FROM registered_entries `) + if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION @@ -2915,6 +2951,11 @@ func buildListRegistrationEntriesQueryPostgreSQL(req *datastore.ListRegistration builder := new(strings.Builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, PostgreSQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -2949,6 +2990,13 @@ FROM if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION ALL @@ -3042,6 +3090,11 @@ LEFT JOIN `) filtered, args, err := appendListRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -3049,7 +3102,13 @@ LEFT JOIN if filtered { builder.WriteString(")") } - + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") return builder.String(), args, nil @@ -3059,6 +3118,11 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn builder := new(strings.Builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + if err != nil { return "", nil, err } @@ -3093,6 +3157,13 @@ FROM if filtered { builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } builder.WriteString(` UNION diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index 53708aaaa9..288850801e 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -478,7 +478,7 @@ func (s *PluginSuite) TestCountBundles() { func (s *PluginSuite) TestCountAttestedNodes() { // Count empty attested nodes - count, err := s.ds.CountAttestedNodes(ctx) + count, err := s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(0), count) @@ -502,7 +502,7 @@ func (s *PluginSuite) TestCountAttestedNodes() { s.Require().NoError(err) // Count all - count, err = s.ds.CountAttestedNodes(ctx) + count, err = s.ds.CountAttestedNodes(ctx, &datastore.CountAttestedNodesRequest{}) s.Require().NoError(err) s.Require().Equal(int32(2), count) } diff --git a/test/fakes/fakedatastore/fakedatastore.go b/test/fakes/fakedatastore/fakedatastore.go index ef026bc960..404958983e 100644 --- a/test/fakes/fakedatastore/fakedatastore.go +++ b/test/fakes/fakedatastore/fakedatastore.go @@ -121,11 +121,11 @@ func (s *DataStore) PruneBundle(ctx context.Context, trustDomainID string, expir return s.ds.PruneBundle(ctx, trustDomainID, expiresBefore) } -func (s *DataStore) CountAttestedNodes(ctx context.Context) (int32, error) { +func (s *DataStore) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (int32, error) { if err := s.getNextError(); err != nil { return 0, err } - return s.ds.CountAttestedNodes(ctx) + return s.ds.CountAttestedNodes(ctx, req) } func (s *DataStore) CreateAttestedNode(ctx context.Context, node *common.AttestedNode) (*common.AttestedNode, error) { From 745716d5d711fdccb1267e76321085aaa3dc7ff7 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 6 Dec 2023 14:35:02 -0300 Subject: [PATCH 04/17] fix lint Signed-off-by: FedeNQ --- cmd/spire-server/cli/entry/count.go | 2 +- cmd/spire-server/cli/entry/util_posix_test.go | 2 +- pkg/server/api/agent/v1/service.go | 1 - pkg/server/datastore/sqlstore/sqlstore.go | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cmd/spire-server/cli/entry/count.go b/cmd/spire-server/cli/entry/count.go index 06a5d748f5..540453983d 100644 --- a/cmd/spire-server/cli/entry/count.go +++ b/cmd/spire-server/cli/entry/count.go @@ -141,7 +141,7 @@ func (c *countCommand) AppendFlags(fs *flag.FlagSet) { fs.Var(&c.federatesWith, "federatesWith", "SPIFFE ID of a trust domain an entry is federate with. Can be used more than once") fs.StringVar(&c.matchFederatesWithOn, "matchFederatesWithOn", "superset", "The match mode used when filtering by federates with. Options: exact, any, superset and subset") fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") - fs.StringVar(&c.hint, "hint", "", "The Hint of the records to show (optional)") + fs.StringVar(&c.hint, "hint", "", "The Hint of the records to count (optional)") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, c.prettyPrintCount) } diff --git a/cmd/spire-server/cli/entry/util_posix_test.go b/cmd/spire-server/cli/entry/util_posix_test.go index d55647482c..7b04cb3f96 100644 --- a/cmd/spire-server/cli/entry/util_posix_test.go +++ b/cmd/spire-server/cli/entry/util_posix_test.go @@ -117,7 +117,7 @@ const ( -federatesWith value SPIFFE ID of a trust domain an entry is federate with. Can be used more than once -hint string - The Hint of the records to show (optional) + The Hint of the records to count (optional) -matchFederatesWithOn string The match mode used when filtering by federates with. Options: exact, any, superset and subset (default "superset") -matchSelectorsOn string diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index 572dce74f5..ce14ca841a 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -72,7 +72,6 @@ func RegisterService(s grpc.ServiceRegistrar, service *Service) { // CountAgents returns the total number of agents. func (s *Service) CountAgents(ctx context.Context, req *agentv1.CountAgentsRequest) (*agentv1.CountAgentsResponse, error) { - log := rpccontext.Logger(ctx) countReq := &datastore.CountAttestedNodesRequest{} diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index ea9d17f347..02485c1958 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -2658,7 +2658,7 @@ ORDER BY selector_id, dns_name_id return query, []any{entryID}, nil } -func countRegistrationEntries(tx *gorm.DB, req *datastore.CountRegistrationEntriesRequest) (int32, error) { +func countRegistrationEntries(tx *gorm.DB, _ *datastore.CountRegistrationEntriesRequest) (int32, error) { var count int if err := tx.Model(&RegisteredEntry{}).Count(&count).Error; err != nil { return 0, sqlError.Wrap(err) From b58a0d09f27a2cbcffcc2c408250d1d367800f8f Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Thu, 7 Dec 2023 11:43:21 -0300 Subject: [PATCH 05/17] add more unit test Signed-off-by: FedeNQ --- cmd/spire-server/cli/entry/count.go | 2 +- cmd/spire-server/cli/entry/count_test.go | 259 ++++++++++++++++++++++- doc/spire_server.md | 3 +- 3 files changed, 259 insertions(+), 5 deletions(-) diff --git a/cmd/spire-server/cli/entry/count.go b/cmd/spire-server/cli/entry/count.go index 540453983d..ce3f8153f6 100644 --- a/cmd/spire-server/cli/entry/count.go +++ b/cmd/spire-server/cli/entry/count.go @@ -141,7 +141,7 @@ func (c *countCommand) AppendFlags(fs *flag.FlagSet) { fs.Var(&c.federatesWith, "federatesWith", "SPIFFE ID of a trust domain an entry is federate with. Can be used more than once") fs.StringVar(&c.matchFederatesWithOn, "matchFederatesWithOn", "superset", "The match mode used when filtering by federates with. Options: exact, any, superset and subset") fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") - fs.StringVar(&c.hint, "hint", "", "The Hint of the records to count (optional)") + fs.StringVar(&c.hint, "hint", "", "The Hint of the records to count (optional)") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, c.prettyPrintCount) } diff --git a/cmd/spire-server/cli/entry/count_test.go b/cmd/spire-server/cli/entry/count_test.go index cfff9ca6f7..e162bd4a96 100644 --- a/cmd/spire-server/cli/entry/count_test.go +++ b/cmd/spire-server/cli/entry/count_test.go @@ -5,9 +5,11 @@ import ( "testing" entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1" + "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestCountHelp(t *testing.T) { @@ -31,12 +33,262 @@ func TestCount(t *testing.T) { for _, tt := range []struct { name string args []string + expCountReq *entryv1.CountEntriesRequest fakeCountResp *entryv1.CountEntriesResponse serverErr error expOutPretty string expOutJSON string expErr string }{ + { + name: "Count all entries (empty filter)", + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp4, + expOutPretty: "4 registration entries", + expOutJSON: `{"count":4}`, + }, + { + name: "Count by parentID", + args: []string{"-parentID", "spiffe://example.org/father"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByParentId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/father"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp2, + expOutPretty: "2 registration entries", + expOutJSON: `{"count":2}`, + }, + { + name: "Count by parent ID using invalid ID", + args: []string{"-parentID", "invalid-id"}, + expErr: "Error: error parsing parent ID \"invalid-id\": scheme is missing or invalid\n", + }, + { + name: "Count by SPIFFE ID", + args: []string{"-spiffeID", "spiffe://example.org/daughter"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp2, + expOutPretty: "2 registration entries", + expOutJSON: `{"count":2}`, + }, + { + name: "Count by SPIFFE ID using invalid ID", + args: []string{"-spiffeID", "invalid-id"}, + expErr: "Error: error parsing SPIFFE ID \"invalid-id\": scheme is missing or invalid\n", + }, + { + name: "Count by selectors: default matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: exact matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "exact"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_EXACT, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: superset matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "superset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: subset matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "subset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_SUBSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: Any matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "any"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySelectors: &types.SelectorMatch{ + Selectors: []*types.Selector{ + {Type: "foo", Value: "bar"}, + {Type: "bar", Value: "baz"}, + }, + Match: types.SelectorMatch_MATCH_ANY, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by selectors: Invalid matcher", + args: []string{"-selector", "foo:bar", "-selector", "bar:baz", "-matchSelectorsOn", "NO-MATCHER"}, + expErr: "Error: match behavior \"NO-MATCHER\" unknown\n", + }, + { + name: "Count by selector using invalid selector", + args: []string{"-selector", "invalid-selector"}, + expErr: "Error: error parsing selectors: selector \"invalid-selector\" must be formatted as type:value\n", + }, + { + name: "Server error", + args: []string{"-spiffeID", "spiffe://example.org/daughter"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + BySpiffeId: &types.SPIFFEID{TrustDomain: "example.org", Path: "/daughter"}, + ByDownstream: wrapperspb.Bool(false), + }, + }, + serverErr: status.Error(codes.Internal, "internal server error"), + expErr: "Error: rpc error: code = Internal desc = internal server error\n", + }, + { + name: "Count by Federates With: default matcher", + args: []string{"-federatesWith", "spiffe://domain.test"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: exact matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "exact"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_EXACT, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: Any matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "any"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_ANY, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: superset matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "superset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUPERSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: subset matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "subset"}, + expCountReq: &entryv1.CountEntriesRequest{ + Filter: &entryv1.CountEntriesRequest_Filter{ + ByFederatesWith: &types.FederatesWithMatch{ + TrustDomains: []string{"spiffe://domain.test"}, + Match: types.FederatesWithMatch_MATCH_SUBSET, + }, + ByDownstream: wrapperspb.Bool(false), + }, + }, + fakeCountResp: fakeResp1, + expOutPretty: "1 registration entry", + expOutJSON: `{"count":1}`, + }, + { + name: "Count by Federates With: Invalid matcher", + args: []string{"-federatesWith", "spiffe://domain.test", "-matchFederatesWithOn", "NO-MATCHER"}, + expErr: "Error: match behavior \"NO-MATCHER\" unknown\n", + }, { name: "4 entries", fakeCountResp: fakeResp4, @@ -73,13 +325,16 @@ func TestCount(t *testing.T) { test.server.err = tt.serverErr test.server.countEntriesResp = tt.fakeCountResp - rc := test.client.Run(test.args(tt.args...)) + args := tt.args + args = append(args, "-output", format) + + rc := test.client.Run(test.args(args...)) if tt.expErr != "" { require.Equal(t, 1, rc) require.Equal(t, tt.expErr, test.stderr.String()) return } - requireOutputBasedOnFormat(t, test.stdout.String(), format, tt.expOutPretty, tt.expOutJSON) + requireOutputBasedOnFormat(t, format, test.stdout.String(), tt.expOutPretty, tt.expOutJSON) require.Equal(t, 0, rc) }) } diff --git a/doc/spire_server.md b/doc/spire_server.md index 3190d3eb86..949f91b149 100644 --- a/doc/spire_server.md +++ b/doc/spire_server.md @@ -361,7 +361,6 @@ Displays the total number of registration entries. | `-socketPath` | Path to the SPIRE Server API socket | /tmp/spire-server/private/api.sock | | `-spiffeID` | The SPIFFE ID of the records to count. | | - ### `spire-server entry delete` Deletes a specified registration entry. @@ -518,7 +517,7 @@ Displays the total number of attested nodes. | `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | | `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | | `-expiresBefore` | A date that indicates the time it should expired before. (format: YYYY-MM-DD) | | -| `-spiffeID` | The SPIFFE ID of the records to count. +| `-spiffeID` | The SPIFFE ID of the records to count. | | ### `spire-server agent evict` From 3ed4f7e041d24dda5363ee4954d9222af1d63a79 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Tue, 27 Feb 2024 16:07:23 -0300 Subject: [PATCH 06/17] Change count & list for entries Signed-off-by: FedeNQ --- .../cli/agent/agent_posix_test.go | 8 +- cmd/spire-server/cli/agent/agent_test.go | 28 +- cmd/spire-server/cli/agent/count.go | 8 +- cmd/spire-server/cli/agent/list.go | 7 +- cmd/spire-server/cli/agent/util.go | 20 - doc/spire_server.md | 6 +- pkg/common/cli/flags.go | 14 +- pkg/server/api/agent/v1/service.go | 18 +- pkg/server/datastore/sqlstore/sqlstore.go | 902 ++++++++++++------ 9 files changed, 639 insertions(+), 372 deletions(-) delete mode 100644 cmd/spire-server/cli/agent/util.go diff --git a/cmd/spire-server/cli/agent/agent_posix_test.go b/cmd/spire-server/cli/agent/agent_posix_test.go index 8697fa6607..43ac27f7ad 100644 --- a/cmd/spire-server/cli/agent/agent_posix_test.go +++ b/cmd/spire-server/cli/agent/agent_posix_test.go @@ -15,13 +15,13 @@ var ( ` listUsage = `Usage of agent list: -attestationType string - The SPIFFE ID of the nodes to list + Filter by attestation type, like join_token or x509pop. -banned value Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. -canReattest value Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. -expiresBefore string - A date that indicates the time it should expired before, (format: YYYY-MM-DD) + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value @@ -49,13 +49,13 @@ var ( ` countUsage = `Usage of agent count: -attestationType string - The SPIFFE ID of the nodes to count + Filter by attestation type, like join_token or x509pop. -banned value Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. -canReattest value Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. -expiresBefore string - A date that indicates the time it should expired before, (format: YYYY-MM-DD) + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value diff --git a/cmd/spire-server/cli/agent/agent_test.go b/cmd/spire-server/cli/agent/agent_test.go index 2eab7544d3..6110a88bfd 100644 --- a/cmd/spire-server/cli/agent/agent_test.go +++ b/cmd/spire-server/cli/agent/agent_test.go @@ -225,18 +225,6 @@ func TestCount(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, - { - name: "Count by expiresBefore: date is too long", - args: []string{"-expiresBefore", "2001-01-011"}, - expectedReturnCode: 1, - expectedStderr: "Error: date is not valid: date is too long\n", - }, - { - name: "Count by expiresBefore: date is too long", - args: []string{"-expiresBefore", "2001-01-0"}, - expectedReturnCode: 1, - expectedStderr: "Error: date is not valid: date is too short\n", - }, { name: "Count by expiresBefore: month out of range", args: []string{"-expiresBefore", "2001-13-05"}, @@ -409,10 +397,10 @@ func TestList(t *testing.T) { }, { name: "by expiresBefore", - args: []string{"-expiresBefore", "2000-01-01"}, + args: []string{"-expiresBefore", "2000-01-01 15:04:05 -0700 -07"}, expectReq: &agentv1.ListAgentsRequest{ Filter: &agentv1.ListAgentsRequest_Filter{ - ByExpiresBefore: "2000-01-01", + ByExpiresBefore: "2000-01-01 15:04:05 -0700 -07", }, PageSize: 1000, }, @@ -464,18 +452,6 @@ func TestList(t *testing.T) { expectedReturnCode: 1, expectedStderr: common.AddrError, }, - { - name: "List by expiresBefore: date is too long", - args: []string{"-expiresBefore", "2001-01-011"}, - expectedReturnCode: 1, - expectedStderr: "Error: date is not valid: date is too long\n", - }, - { - name: "List by expiresBefore: date is too long", - args: []string{"-expiresBefore", "2001-01-0"}, - expectedReturnCode: 1, - expectedStderr: "Error: date is not valid: date is too short\n", - }, { name: "List by expiresBefore: month out of range", args: []string{"-expiresBefore", "2001-13-05"}, diff --git a/cmd/spire-server/cli/agent/count.go b/cmd/spire-server/cli/agent/count.go index a3e3649c45..ef59d0de7b 100644 --- a/cmd/spire-server/cli/agent/count.go +++ b/cmd/spire-server/cli/agent/count.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "time" "github.com/mitchellh/cli" agentv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/agent/v1" @@ -83,7 +84,8 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u } if c.expiresBefore != "" { - err := validate(c.expiresBefore) + // Parse the time string into a time.Time object + _, err := time.Parse("2006-01-02 15:04:05 -0700 -07", c.expiresBefore) if err != nil { return fmt.Errorf("date is not valid: %w", err) } @@ -124,10 +126,10 @@ func (c *countCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient u func (c *countCommand) AppendFlags(fs *flag.FlagSet) { fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") - fs.StringVar(&c.attestationType, "attestationType", "", "The SPIFFE ID of the nodes to count") + fs.StringVar(&c.attestationType, "attestationType", "", "Filter by attestation type, like join_token or x509pop.") fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") - fs.StringVar(&c.expiresBefore, "expiresBefore", "", "A date that indicates the time it should expired before, (format: YYYY-MM-DD)") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "Filter by expiration time (format: \"2006-01-02 15:04:05 -0700 -07\")") fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintCount) } diff --git a/cmd/spire-server/cli/agent/list.go b/cmd/spire-server/cli/agent/list.go index d891a69077..8062294c43 100644 --- a/cmd/spire-server/cli/agent/list.go +++ b/cmd/spire-server/cli/agent/list.go @@ -85,7 +85,8 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut } if c.expiresBefore != "" { - err := validate(c.expiresBefore) + // Parse the time string into a time.Time object + _, err := time.Parse("2006-01-02 15:04:05 -0700 -07", c.expiresBefore) if err != nil { return fmt.Errorf("date is not valid: %w", err) } @@ -136,10 +137,10 @@ func (c *listCommand) Run(ctx context.Context, _ *commoncli.Env, serverClient ut func (c *listCommand) AppendFlags(fs *flag.FlagSet) { fs.Var(&c.selectors, "selector", "A colon-delimited type:value selector. Can be used more than once") - fs.StringVar(&c.attestationType, "attestationType", "", "The SPIFFE ID of the nodes to list") + fs.StringVar(&c.attestationType, "attestationType", "", "Filter by attestation type, like join_token or x509pop.") fs.Var(&c.canReattest, "canReattest", "Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all.") fs.Var(&c.banned, "banned", "Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all.") - fs.StringVar(&c.expiresBefore, "expiresBefore", "", "A date that indicates the time it should expired before, (format: YYYY-MM-DD)") + fs.StringVar(&c.expiresBefore, "expiresBefore", "", "Filter by expiration time (format: \"2006-01-02 15:04:05 -0700 -07\")") fs.StringVar(&c.matchSelectorsOn, "matchSelectorsOn", "superset", "The match mode used when filtering by selectors. Options: exact, any, superset and subset") cliprinter.AppendFlagWithCustomPretty(&c.printer, fs, c.env, prettyPrintAgents) } diff --git a/cmd/spire-server/cli/agent/util.go b/cmd/spire-server/cli/agent/util.go deleted file mode 100644 index e26ceb8b1c..0000000000 --- a/cmd/spire-server/cli/agent/util.go +++ /dev/null @@ -1,20 +0,0 @@ -package agent - -import ( - "fmt" - "time" -) - -func validate(s string) error { - if len(s) < 10 { - return fmt.Errorf("date is too short") - } - if len(s) > 10 { - return fmt.Errorf("date is too long") - } - _, err := time.Parse("2006-01-02", s) - if err != nil { - return err - } - return nil -} diff --git a/doc/spire_server.md b/doc/spire_server.md index 949f91b149..a30256573b 100644 --- a/doc/spire_server.md +++ b/doc/spire_server.md @@ -516,7 +516,7 @@ Displays the total number of attested nodes. | `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | | `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | | `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | -| `-expiresBefore` | A date that indicates the time it should expired before. (format: YYYY-MM-DD) | | +| `-expiresBefore` | Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") | | | `-spiffeID` | The SPIFFE ID of the records to count. | | ### `spire-server agent evict` @@ -539,8 +539,8 @@ Displays attested nodes. | `-selector` | A colon-delimited type:value selector. Can be used more than once to specify multiple selectors. | | | `-canReattest` | Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all | | | `-banned` | Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all | | -| `-expiresBefore` | A date that indicates the time it should expired before. (format: YYYY-MM-DD)| | -| `-attestationType` | Filters agents to those matching the attestation type. | | +| `-expiresBefore` | Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07")| | +| `-attestationType` | Filters agents to those matching the attestation type, like join_token or x509pop. | | ### `spire-server agent show` diff --git a/pkg/common/cli/flags.go b/pkg/common/cli/flags.go index 13a1132791..9770f821a2 100644 --- a/pkg/common/cli/flags.go +++ b/pkg/common/cli/flags.go @@ -46,24 +46,28 @@ func (s *StringsFlag) Set(val string) error { return nil } -// BoolFlag is used to define 3 possible states: true, false, or all -// take care that false=1, and true=2 +// BoolFlag is used to define 3 possible states: true, false, or all. +// Take care that false=1, and true=2 type BoolFlag int +const BoolFlagAll = 0 +const BoolFlagFalse = 1 +const BoolFlagTrue = 2 + func (b *BoolFlag) String() string { return "" } func (b *BoolFlag) Set(val string) error { if val == "false" { - *b = 1 + *b = BoolFlagFalse return nil } if val == "true" { - *b = 2 + *b = BoolFlagTrue return nil } // if the value received isn't true or false, it will set the default value - *b = 0 + *b = BoolFlagAll return nil } diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index ce14ca841a..4aa62cb6f5 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -88,10 +88,14 @@ func (s *Service) CountAgents(ctx context.Context, req *agentv1.CountAgentsReque countReq.ByCanReattest = &req.Filter.ByCanReattest.Value } - countReq.ByAttestationType = filter.ByAttestationType + if filter.ByAttestationType != "" { + countReq.ByAttestationType = filter.ByAttestationType + } // err is verified previously - countReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) + if filter.ByExpiresBefore != "" { + countReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) + } if filter.BySelectorMatch != nil { selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) @@ -136,9 +140,15 @@ func (s *Service) ListAgents(ctx context.Context, req *agentv1.ListAgentsRequest listReq.ByCanReattest = &req.Filter.ByCanReattest.Value } - listReq.ByAttestationType = filter.ByAttestationType + if filter.ByAttestationType != "" { + listReq.ByAttestationType = filter.ByAttestationType + } + // err is verified previously - listReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) + // countReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) + if filter.ByExpiresBefore != "" { + listReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) + } if filter.BySelectorMatch != nil { selectors, err := api.SelectorsFromProto(filter.BySelectorMatch.Selectors) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 02485c1958..1a15d467d2 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -458,32 +458,13 @@ func (ds *Plugin) FetchRegistrationEntry(ctx context.Context, // CountRegistrationEntries counts all registrations (pagination available) func (ds *Plugin) CountRegistrationEntries(ctx context.Context, req *datastore.CountRegistrationEntriesRequest) (count int32, err error) { - if countRegistrationEntriesHasFilters(req) { - var actDb = ds.db - if req.DataConsistency == datastore.TolerateStale && ds.roDb != nil { - actDb = ds.roDb - } - - resp, err := listRegistrationEntries(ctx, actDb, ds.log, &datastore.ListRegistrationEntriesRequest{ - DataConsistency: req.DataConsistency, - ByParentID: req.ByParentID, - BySelectors: req.BySelectors, - BySpiffeID: req.BySpiffeID, - ByFederatesWith: req.ByFederatesWith, - ByHint: req.ByHint, - ByDownstream: req.ByDownstream, - }) - return int32(len(resp.Entries)), err - } - - if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { - count, err = countRegistrationEntries(tx, req) - return err - }); err != nil { - return 0, err + var actDb = ds.db + if req.DataConsistency == datastore.TolerateStale && ds.roDb != nil { + actDb = ds.roDb } - return count, nil + resp, err := countRegistrationEntries(ctx, actDb, ds.log, req) + return resp, err } // ListRegistrationEntries lists all registrations (pagination available) @@ -2658,25 +2639,6 @@ ORDER BY selector_id, dns_name_id return query, []any{entryID}, nil } -func countRegistrationEntries(tx *gorm.DB, _ *datastore.CountRegistrationEntriesRequest) (int32, error) { - var count int - if err := tx.Model(&RegisteredEntry{}).Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) - } - - return int32(count), nil -} - -func countRegistrationEntriesHasFilters(req *datastore.CountRegistrationEntriesRequest) bool { - if req.ByParentID != "" || req.ByHint != "" || req.BySpiffeID != "" { - return true - } - if req.ByFederatesWith != nil || req.BySelectors != nil || req.ByDownstream != nil { - return true - } - return false -} - func listRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) { if req.Pagination != nil && req.Pagination.PageSize == 0 { return nil, status.Error(codes.InvalidArgument, "cannot paginate with pagesize = 0") @@ -2771,7 +2733,6 @@ func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseT return nil, sqlError.Wrap(err) } defer rows.Close() - var entries []*common.RegistrationEntry if req.Pagination != nil { entries = make([]*common.RegistrationEntry, 0, req.Pagination.PageSize) @@ -2864,85 +2825,8 @@ func buildListRegistrationEntriesQuerySQLite3(req *datastore.ListRegistrationEnt if err != nil { return "", nil, err } - if filtered { - builder.WriteString(")") - } - - builder.WriteString(` -SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl -FROM - registered_entries -`) - - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString(` -UNION - -SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL -FROM - bundles B -INNER JOIN - federated_registration_entries F -ON - B.id = F.bundle_id -`) - if filtered { - builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -UNION - -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL -FROM - dns_names -`) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -UNION -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL -FROM - selectors -`) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -ORDER BY e_id, selector_id, dns_name_id -;`) + buildQuerySQLite3(builder, filtered, downstream) return builder.String(), args, nil } @@ -2959,84 +2843,8 @@ func buildListRegistrationEntriesQueryPostgreSQL(req *datastore.ListRegistration if err != nil { return "", nil, err } - if filtered { - builder.WriteString(")") - } - - builder.WriteString(` -SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL ::integer AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL ::integer AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl -FROM - registered_entries -`) - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString(` -UNION ALL - -SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL -FROM - bundles B -INNER JOIN - federated_registration_entries F -ON - B.id = F.bundle_id -`) - if filtered { - builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -UNION ALL - -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL -FROM - dns_names -`) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -UNION ALL -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL -FROM - selectors -`) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` -ORDER BY e_id, selector_id, dns_name_id -;`) + buildQueryPostgreSQL(builder, filtered, downstream) return postgreSQLRebind(builder.String()), args, nil } @@ -3056,38 +2864,6 @@ func postgreSQLRebind(s string) string { func buildListRegistrationEntriesQueryMySQL(req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { builder := new(strings.Builder) - builder.WriteString(` -SELECT - E.id AS e_id, - E.entry_id AS entry_id, - E.spiffe_id, - E.parent_id, - E.ttl AS reg_ttl, - E.admin, - E.downstream, - E.expiry, - E.store_svid, - E.hint, - E.created_at, - S.id AS selector_id, - S.type AS selector_type, - S.value AS selector_value, - B.trust_domain, - D.id AS dns_name_id, - D.value AS dns_name, - E.revision_number, - E.jwt_svid_ttl AS reg_jwt_svid_ttl -FROM - registered_entries E -LEFT JOIN - (SELECT 1 AS joinItem UNION SELECT 2 UNION SELECT 3) AS joinItems ON TRUE -LEFT JOIN - selectors S ON joinItem=1 AND E.id=S.registered_entry_id -LEFT JOIN - dns_names D ON joinItem=2 AND E.id=D.registered_entry_id -LEFT JOIN - (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id -`) filtered, args, err := appendListRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) var downstream = false @@ -3099,17 +2875,7 @@ LEFT JOIN return "", nil, err } - if filtered { - builder.WriteString(")") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") + buildQueryMySQL(builder, filtered, downstream) return builder.String(), args, nil } @@ -3126,86 +2892,614 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn if err != nil { return "", nil, err } - if filtered { - builder.WriteString(")") - } + buildQueryMySQLCTE(builder, filtered, downstream) + return builder.String(), args, nil +} - builder.WriteString(` -SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl -FROM - registered_entries -`) - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } +// Count Registration Entries +func countRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { + if req.BySelectors != nil && len(req.BySelectors.Selectors) == 0 { + return 0, status.Error(codes.InvalidArgument, "cannot list by empty selector set") } - builder.WriteString(` -UNION -SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL -FROM - bundles B -INNER JOIN - federated_registration_entries F -ON - B.id = F.bundle_id -`) + // Exact/subset selector matching requires filtering out all registration + // entries returned by the query whose selectors are not fully represented + // in the request selectors. For this reason, it's possible that a paged + // query returns rows that are completely filtered out. If that happens, + // keep querying until a page gets at least one result. + resp, err := countRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, req) + + return resp, err +} + +func countRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseType string, supportsCTE bool, req *datastore.CountRegistrationEntriesRequest) (int32, error) { + query, args, err := buildCountRegistrationEntriesQuery(databaseType, supportsCTE, req) + if err != nil { + return 0, sqlError.Wrap(err) + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return 0, sqlError.Wrap(err) + } + defer rows.Close() + + var resp int32 = -1 + if rows.Next() { + rows.Scan(&resp) + } + + if err := rows.Err(); err != nil { + return 0, sqlError.Wrap(err) + } + return resp, nil +} + +func buildCountRegistrationEntriesQuery(dbType string, supportsCTE bool, req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { + switch dbType { + case SQLite: + // The SQLite3 queries unconditionally leverage CTE since the + // embedded version of SQLite3 supports CTE. + return buildCountRegistrationEntriesQuerySQLite3(req) + case PostgreSQL: + // The PostgreSQL queries unconditionally leverage CTE since all versions + // of PostgreSQL supported by the plugin support CTE. + return buildCountRegistrationEntriesQueryPostgreSQL(req) + case MySQL: + if supportsCTE { + return buildCountRegistrationEntriesQueryMySQLCTE(req) + } + return buildCountRegistrationEntriesQueryMySQL(req) + default: + return "", nil, sqlError.New("unsupported db type: %q", dbType) + } +} + +func buildCountRegistrationEntriesQuerySQLite3(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { + builder := new(strings.Builder) + builder.WriteString(`SELECT COUNT(*) FROM (`) + + filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, SQLite, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + + if err != nil { + return "", nil, err + } + + buildQuerySQLite3(builder, filtered, downstream) + + builder.WriteString(` + ) AS query_result + WHERE entry_id != '' + `) + + return builder.String(), args, nil +} + +func buildCountRegistrationEntriesQueryPostgreSQL(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { + builder := new(strings.Builder) + builder.WriteString(`SELECT COUNT(*) FROM (`) + + filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, PostgreSQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + + if err != nil { + return "", nil, err + } + + buildQueryPostgreSQL(builder, filtered, downstream) + + builder.WriteString(` + ) AS query_result + WHERE entry_id != '' + `) + + return postgreSQLRebind(builder.String()), args, nil +} + +func buildCountRegistrationEntriesQueryMySQL(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { + builder := new(strings.Builder) + builder.WriteString(`SELECT COUNT(*) FROM (`) + + filtered, args, err := appendCountRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + + if err != nil { + return "", nil, err + } + + buildQueryMySQL(builder, filtered, downstream) + + builder.WriteString(` + ) AS query_result + WHERE entry_id != '' + `) + + return builder.String(), args, nil +} + +func buildCountRegistrationEntriesQueryMySQLCTE(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { + builder := new(strings.Builder) + builder.WriteString(`SELECT COUNT(*) FROM (`) + + filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, MySQL, req) + var downstream = false + if req.ByDownstream != nil { + downstream = *req.ByDownstream + } + + if err != nil { + return "", nil, err + } + + buildQueryMySQLCTE(builder, filtered, downstream) + + builder.WriteString(` + ) AS query_result + WHERE entry_id != '' + `) + + return builder.String(), args, nil +} + +func appendCountRegistrationEntriesFilterQuery(filterExp string, builder *strings.Builder, dbType string, req *datastore.CountRegistrationEntriesRequest) (bool, []any, error) { + var args []any + + root := idFilterNode{idColumn: "id"} + if req.ByParentID != "" || req.BySpiffeID != "" { + subquery := new(strings.Builder) + subquery.WriteString("SELECT id AS e_id FROM registered_entries WHERE ") + if req.ByParentID != "" { + subquery.WriteString("parent_id = ?") + args = append(args, req.ByParentID) + } + if req.BySpiffeID != "" { + if req.ByParentID != "" { + subquery.WriteString(" AND ") + } + subquery.WriteString("spiffe_id = ?") + args = append(args, req.BySpiffeID) + } + root.children = append(root.children, idFilterNode{ + idColumn: "id", + query: []string{subquery.String()}, + }) + } + + if req.ByHint != "" { + root.children = append(root.children, idFilterNode{ + idColumn: "id", + query: []string{"SELECT id AS e_id FROM registered_entries WHERE hint = ?"}, + }) + args = append(args, req.ByHint) + } + + if req.BySelectors != nil && len(req.BySelectors.Selectors) > 0 { + switch req.BySelectors.Match { + case datastore.Subset, datastore.MatchAny: + // subset needs a union, so we need to group them and add the group + // as a child to the root. + if len(req.BySelectors.Selectors) < 2 { + root.children = append(root.children, idFilterNode{ + idColumn: "registered_entry_id", + query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, + }) + } else { + group := idFilterNode{ + idColumn: "e_id", + union: true, + } + for range req.BySelectors.Selectors { + group.children = append(group.children, idFilterNode{ + idColumn: "registered_entry_id", + query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, + }) + } + root.children = append(root.children, group) + } + case datastore.Exact, datastore.Superset: + // exact match does use an intersection, so we can just add these + // directly to the root idFilterNode, since it is already an intersection + for range req.BySelectors.Selectors { + root.children = append(root.children, idFilterNode{ + idColumn: "registered_entry_id", + query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, + }) + } + default: + return false, nil, errs.New("unhandled selectors match behavior %q", req.BySelectors.Match) + } + for _, selector := range req.BySelectors.Selectors { + args = append(args, selector.Type, selector.Value) + } + } + + if req.ByFederatesWith != nil && len(req.ByFederatesWith.TrustDomains) > 0 { + // Take the trust domains from the request without duplicates + tdSet := make(map[string]struct{}) + for _, td := range req.ByFederatesWith.TrustDomains { + tdSet[td] = struct{}{} + } + trustDomains := make([]string, 0, len(tdSet)) + for td := range tdSet { + trustDomains = append(trustDomains, td) + } + + // Exact/subset federates-with matching requires filtering out all registration + // entries whose federated trust domains are not fully represented in the request + filterNode := idFilterNode{ + idColumn: "E.id", + } + filterNode.query = append(filterNode.query, "SELECT E.id AS e_id") + filterNode.query = append(filterNode.query, "FROM registered_entries E") + filterNode.query = append(filterNode.query, "INNER JOIN federated_registration_entries FE ON FE.registered_entry_id = E.id") + filterNode.query = append(filterNode.query, "INNER JOIN bundles B ON B.id = FE.bundle_id") + filterNode.query = append(filterNode.query, "GROUP BY E.id") + filterNode.query = append(filterNode.query, "HAVING") + + sliceArg := buildSliceArg(len(trustDomains)) + addIsSubset := func() { + filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain NOT IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = 0 AND") + for _, td := range trustDomains { + args = append(args, td) + } + } + + switch req.ByFederatesWith.Match { + case datastore.Subset: + // Subset federates-with matching requires filtering out all registration + // entries that don't federate with even one trust domain in the request + addIsSubset() + filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) > 0") + for _, td := range trustDomains { + args = append(args, td) + } + case datastore.Exact: + // Exact federates-with matching requires filtering out all registration + // entries that don't federate with all the trust domains in the request + addIsSubset() + filterNode.query = append(filterNode.query, "\tCOUNT(DISTINCT CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = ?") + for _, td := range trustDomains { + args = append(args, td) + } + args = append(args, len(trustDomains)) + case datastore.MatchAny: + // MatchAny federates-with matching requires filtering out all registration + // entries that has at least one trust domain in the request + filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) > 0") + for _, td := range trustDomains { + args = append(args, td) + } + case datastore.Superset: + // SuperSet federates-with matching requires filtering out all registration + // entries has all trustdomains + filterNode.query = append(filterNode.query, "\tCOUNT(DISTINCT CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = ?") + for _, td := range trustDomains { + args = append(args, td) + } + args = append(args, len(trustDomains)) + + default: + return false, nil, errs.New("unhandled federates with match behavior %q", req.ByFederatesWith.Match) + } + root.children = append(root.children, filterNode) + } + + filtered := false + filter := func() { + if !filtered { + builder.WriteString(filterExp) + } + filtered = true + } + indentation := 1 + if len(root.children) > 0 { + filter() + root.Render(builder, dbType, indentation, true) + } + + return filtered, args, nil +} + +func buildQuerySQLite3(builder *strings.Builder, filtered bool, downstream bool) { + if filtered { + builder.WriteString(")") + } + + builder.WriteString(` + SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl + FROM + registered_entries + `) + + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` + UNION + + SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL + FROM + bundles B + INNER JOIN + federated_registration_entries F + ON + B.id = F.bundle_id + `) if filtered { builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") } builder.WriteString(` -UNION + UNION -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL -FROM - dns_names -`) + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL + FROM + dns_names + `) if filtered { builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") } builder.WriteString(` -UNION + UNION -SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL -FROM - selectors -`) + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL + FROM + selectors + `) if filtered { builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") } builder.WriteString(` -ORDER BY e_id, selector_id, dns_name_id -;`) + ORDER BY e_id, selector_id, dns_name_id + `) +} - return builder.String(), args, nil +func buildQueryPostgreSQL(builder *strings.Builder, filtered bool, downstream bool) { + if filtered { + builder.WriteString(")") + } + + builder.WriteString(` + SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL ::integer AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL ::integer AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl + FROM + registered_entries + `) + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` + UNION ALL + + SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL + FROM + bundles B + INNER JOIN + federated_registration_entries F + ON + B.id = F.bundle_id + `) + if filtered { + builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + UNION ALL + + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL + FROM + dns_names + `) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + UNION ALL + + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL + FROM + selectors + `) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + ORDER BY e_id, selector_id, dns_name_id + ;`) +} + +func buildQueryMySQLCTE(builder *strings.Builder, filtered bool, downstream bool) { + if filtered { + builder.WriteString(")") + } + + builder.WriteString(` + SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl + FROM + registered_entries + `) + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` + UNION + + SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL + FROM + bundles B + INNER JOIN + federated_registration_entries F + ON + B.id = F.bundle_id + `) + if filtered { + builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + UNION + + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL + FROM + dns_names + `) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + UNION + + SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL + FROM + selectors + `) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` + ORDER BY e_id, selector_id, dns_name_id + ;`) +} + +func buildQueryMySQL(builder *strings.Builder, filtered bool, downstream bool) { + + builder.WriteString(` + SELECT + E.id AS e_id, + E.entry_id AS entry_id, + E.spiffe_id, + E.parent_id, + E.ttl AS reg_ttl, + E.admin, + E.downstream, + E.expiry, + E.store_svid, + E.hint, + E.created_at, + S.id AS selector_id, + S.type AS selector_type, + S.value AS selector_value, + B.trust_domain, + D.id AS dns_name_id, + D.value AS dns_name, + E.revision_number, + E.jwt_svid_ttl AS reg_jwt_svid_ttl + FROM + registered_entries E + LEFT JOIN + (SELECT 1 AS joinItem UNION SELECT 2 UNION SELECT 3) AS joinItems ON TRUE + LEFT JOIN + selectors S ON joinItem=1 AND E.id=S.registered_entry_id + LEFT JOIN + dns_names D ON joinItem=2 AND E.id=D.registered_entry_id + LEFT JOIN + (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id + `) + + if filtered { + builder.WriteString(")") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") } type idFilterNode struct { From 6df58c7fd557a032120f335b7eb0a44b18c2259c Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Thu, 29 Feb 2024 16:51:29 -0300 Subject: [PATCH 07/17] rollback Signed-off-by: FedeNQ --- pkg/server/datastore/sqlstore/sqlstore.go | 344 ++-------------------- 1 file changed, 29 insertions(+), 315 deletions(-) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 1a15d467d2..32c1517ee4 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -2865,6 +2865,7 @@ func postgreSQLRebind(s string) string { func buildListRegistrationEntriesQueryMySQL(req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { builder := new(strings.Builder) + buildQueryMySQL(builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) var downstream = false if req.ByDownstream != nil { @@ -2875,7 +2876,19 @@ func buildListRegistrationEntriesQueryMySQL(req *datastore.ListRegistrationEntri return "", nil, err } - buildQueryMySQL(builder, filtered, downstream) + if filtered { + builder.WriteString(")") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` + ORDER BY e_id, selector_id, dns_name_id + `) return builder.String(), args, nil } @@ -2907,307 +2920,19 @@ func countRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLo // in the request selectors. For this reason, it's possible that a paged // query returns rows that are completely filtered out. If that happens, // keep querying until a page gets at least one result. - resp, err := countRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, req) - - return resp, err -} - -func countRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseType string, supportsCTE bool, req *datastore.CountRegistrationEntriesRequest) (int32, error) { - query, args, err := buildCountRegistrationEntriesQuery(databaseType, supportsCTE, req) - if err != nil { - return 0, sqlError.Wrap(err) - } - - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - return 0, sqlError.Wrap(err) - } - defer rows.Close() - - var resp int32 = -1 - if rows.Next() { - rows.Scan(&resp) - } - - if err := rows.Err(); err != nil { - return 0, sqlError.Wrap(err) - } - return resp, nil -} - -func buildCountRegistrationEntriesQuery(dbType string, supportsCTE bool, req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { - switch dbType { - case SQLite: - // The SQLite3 queries unconditionally leverage CTE since the - // embedded version of SQLite3 supports CTE. - return buildCountRegistrationEntriesQuerySQLite3(req) - case PostgreSQL: - // The PostgreSQL queries unconditionally leverage CTE since all versions - // of PostgreSQL supported by the plugin support CTE. - return buildCountRegistrationEntriesQueryPostgreSQL(req) - case MySQL: - if supportsCTE { - return buildCountRegistrationEntriesQueryMySQLCTE(req) - } - return buildCountRegistrationEntriesQueryMySQL(req) - default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) - } -} - -func buildCountRegistrationEntriesQuerySQLite3(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { - builder := new(strings.Builder) - builder.WriteString(`SELECT COUNT(*) FROM (`) - - filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, SQLite, req) - var downstream = false - if req.ByDownstream != nil { - downstream = *req.ByDownstream - } - - if err != nil { - return "", nil, err - } - - buildQuerySQLite3(builder, filtered, downstream) - - builder.WriteString(` - ) AS query_result - WHERE entry_id != '' - `) - - return builder.String(), args, nil -} - -func buildCountRegistrationEntriesQueryPostgreSQL(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { - builder := new(strings.Builder) - builder.WriteString(`SELECT COUNT(*) FROM (`) - - filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, PostgreSQL, req) - var downstream = false - if req.ByDownstream != nil { - downstream = *req.ByDownstream - } - - if err != nil { - return "", nil, err - } - - buildQueryPostgreSQL(builder, filtered, downstream) - - builder.WriteString(` - ) AS query_result - WHERE entry_id != '' - `) - - return postgreSQLRebind(builder.String()), args, nil -} - -func buildCountRegistrationEntriesQueryMySQL(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { - builder := new(strings.Builder) - builder.WriteString(`SELECT COUNT(*) FROM (`) - - filtered, args, err := appendCountRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) - var downstream = false - if req.ByDownstream != nil { - downstream = *req.ByDownstream - } - - if err != nil { - return "", nil, err - } - - buildQueryMySQL(builder, filtered, downstream) - - builder.WriteString(` - ) AS query_result - WHERE entry_id != '' - `) - - return builder.String(), args, nil -} - -func buildCountRegistrationEntriesQueryMySQLCTE(req *datastore.CountRegistrationEntriesRequest) (string, []any, error) { - builder := new(strings.Builder) - builder.WriteString(`SELECT COUNT(*) FROM (`) - - filtered, args, err := appendCountRegistrationEntriesFilterQuery("\nWITH listing AS (\n", builder, MySQL, req) - var downstream = false - if req.ByDownstream != nil { - downstream = *req.ByDownstream - } - - if err != nil { - return "", nil, err - } - - buildQueryMySQLCTE(builder, filtered, downstream) - - builder.WriteString(` - ) AS query_result - WHERE entry_id != '' - `) - - return builder.String(), args, nil -} - -func appendCountRegistrationEntriesFilterQuery(filterExp string, builder *strings.Builder, dbType string, req *datastore.CountRegistrationEntriesRequest) (bool, []any, error) { - var args []any - - root := idFilterNode{idColumn: "id"} - if req.ByParentID != "" || req.BySpiffeID != "" { - subquery := new(strings.Builder) - subquery.WriteString("SELECT id AS e_id FROM registered_entries WHERE ") - if req.ByParentID != "" { - subquery.WriteString("parent_id = ?") - args = append(args, req.ByParentID) - } - if req.BySpiffeID != "" { - if req.ByParentID != "" { - subquery.WriteString(" AND ") - } - subquery.WriteString("spiffe_id = ?") - args = append(args, req.BySpiffeID) - } - root.children = append(root.children, idFilterNode{ - idColumn: "id", - query: []string{subquery.String()}, - }) - } - - if req.ByHint != "" { - root.children = append(root.children, idFilterNode{ - idColumn: "id", - query: []string{"SELECT id AS e_id FROM registered_entries WHERE hint = ?"}, - }) - args = append(args, req.ByHint) - } - - if req.BySelectors != nil && len(req.BySelectors.Selectors) > 0 { - switch req.BySelectors.Match { - case datastore.Subset, datastore.MatchAny: - // subset needs a union, so we need to group them and add the group - // as a child to the root. - if len(req.BySelectors.Selectors) < 2 { - root.children = append(root.children, idFilterNode{ - idColumn: "registered_entry_id", - query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, - }) - } else { - group := idFilterNode{ - idColumn: "e_id", - union: true, - } - for range req.BySelectors.Selectors { - group.children = append(group.children, idFilterNode{ - idColumn: "registered_entry_id", - query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, - }) - } - root.children = append(root.children, group) - } - case datastore.Exact, datastore.Superset: - // exact match does use an intersection, so we can just add these - // directly to the root idFilterNode, since it is already an intersection - for range req.BySelectors.Selectors { - root.children = append(root.children, idFilterNode{ - idColumn: "registered_entry_id", - query: []string{"SELECT registered_entry_id AS e_id FROM selectors WHERE type = ? AND value = ?"}, - }) - } - default: - return false, nil, errs.New("unhandled selectors match behavior %q", req.BySelectors.Match) - } - for _, selector := range req.BySelectors.Selectors { - args = append(args, selector.Type, selector.Value) - } - } - - if req.ByFederatesWith != nil && len(req.ByFederatesWith.TrustDomains) > 0 { - // Take the trust domains from the request without duplicates - tdSet := make(map[string]struct{}) - for _, td := range req.ByFederatesWith.TrustDomains { - tdSet[td] = struct{}{} - } - trustDomains := make([]string, 0, len(tdSet)) - for td := range tdSet { - trustDomains = append(trustDomains, td) - } - - // Exact/subset federates-with matching requires filtering out all registration - // entries whose federated trust domains are not fully represented in the request - filterNode := idFilterNode{ - idColumn: "E.id", - } - filterNode.query = append(filterNode.query, "SELECT E.id AS e_id") - filterNode.query = append(filterNode.query, "FROM registered_entries E") - filterNode.query = append(filterNode.query, "INNER JOIN federated_registration_entries FE ON FE.registered_entry_id = E.id") - filterNode.query = append(filterNode.query, "INNER JOIN bundles B ON B.id = FE.bundle_id") - filterNode.query = append(filterNode.query, "GROUP BY E.id") - filterNode.query = append(filterNode.query, "HAVING") - - sliceArg := buildSliceArg(len(trustDomains)) - addIsSubset := func() { - filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain NOT IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = 0 AND") - for _, td := range trustDomains { - args = append(args, td) - } - } - - switch req.ByFederatesWith.Match { - case datastore.Subset: - // Subset federates-with matching requires filtering out all registration - // entries that don't federate with even one trust domain in the request - addIsSubset() - filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) > 0") - for _, td := range trustDomains { - args = append(args, td) - } - case datastore.Exact: - // Exact federates-with matching requires filtering out all registration - // entries that don't federate with all the trust domains in the request - addIsSubset() - filterNode.query = append(filterNode.query, "\tCOUNT(DISTINCT CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = ?") - for _, td := range trustDomains { - args = append(args, td) - } - args = append(args, len(trustDomains)) - case datastore.MatchAny: - // MatchAny federates-with matching requires filtering out all registration - // entries that has at least one trust domain in the request - filterNode.query = append(filterNode.query, "\tCOUNT(CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) > 0") - for _, td := range trustDomains { - args = append(args, td) - } - case datastore.Superset: - // SuperSet federates-with matching requires filtering out all registration - // entries has all trustdomains - filterNode.query = append(filterNode.query, "\tCOUNT(DISTINCT CASE WHEN B.trust_domain IN "+sliceArg+" THEN B.trust_domain ELSE NULL END) = ?") - for _, td := range trustDomains { - args = append(args, td) - } - args = append(args, len(trustDomains)) - - default: - return false, nil, errs.New("unhandled federates with match behavior %q", req.ByFederatesWith.Match) - } - root.children = append(root.children, filterNode) - } - - filtered := false - filter := func() { - if !filtered { - builder.WriteString(filterExp) - } - filtered = true - } - indentation := 1 - if len(root.children) > 0 { - filter() - root.Render(builder, dbType, indentation, true) - } + resp, err := listRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, + &datastore.ListRegistrationEntriesRequest{ + DataConsistency: req.DataConsistency, + ByParentID: req.ByParentID, + BySelectors: req.BySelectors, + BySpiffeID: req.BySpiffeID, + ByFederatesWith: req.ByFederatesWith, + ByHint: req.ByHint, + ByDownstream: req.ByDownstream, + }, + ) - return filtered, args, nil + return int32(len(resp.Entries)), err } func buildQuerySQLite3(builder *strings.Builder, filtered bool, downstream bool) { @@ -3370,7 +3095,7 @@ func buildQueryPostgreSQL(builder *strings.Builder, filtered bool, downstream bo } builder.WriteString(` ORDER BY e_id, selector_id, dns_name_id - ;`) + `) } func buildQueryMySQLCTE(builder *strings.Builder, filtered bool, downstream bool) { @@ -3451,10 +3176,10 @@ func buildQueryMySQLCTE(builder *strings.Builder, filtered bool, downstream bool } builder.WriteString(` ORDER BY e_id, selector_id, dns_name_id - ;`) + `) } -func buildQueryMySQL(builder *strings.Builder, filtered bool, downstream bool) { +func buildQueryMySQL(builder *strings.Builder) { builder.WriteString(` SELECT @@ -3489,17 +3214,6 @@ func buildQueryMySQL(builder *strings.Builder, filtered bool, downstream bool) { (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id `) - if filtered { - builder.WriteString(")") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") } type idFilterNode struct { From a102dc25774367d9d91f780eb2e04638a70c4991 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Fri, 1 Mar 2024 10:31:04 -0300 Subject: [PATCH 08/17] fix Signed-off-by: FedeNQ --- pkg/server/datastore/sqlstore/sqlstore.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 1cc07eaee8..12b258ad00 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -2886,6 +2886,7 @@ func buildListRegistrationEntriesQueryPostgreSQL(req *datastore.ListRegistration func maybeRebind(dbType, query string) string { if isPostgresDbType(dbType) { return postgreSQLRebind(query) + } return query } @@ -3053,6 +3054,7 @@ func buildQuerySQLite3(builder *strings.Builder, filtered bool, downstream bool) func buildQueryPostgreSQL(builder *strings.Builder, filtered bool, downstream bool) { if filtered { builder.WriteString(")") + } builder.WriteString(` SELECT From c8ad89f285d00ea0d3aacad3d92df7e239edfe24 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Fri, 1 Mar 2024 11:00:05 -0300 Subject: [PATCH 09/17] fix lint Signed-off-by: FedeNQ --- pkg/server/datastore/sqlstore/sqlstore.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 12b258ad00..c4d69d30ea 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -2944,7 +2944,7 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn } // Count Registration Entries -func countRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { +func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { if req.BySelectors != nil && len(req.BySelectors.Selectors) == 0 { return 0, status.Error(codes.InvalidArgument, "cannot list by empty selector set") } @@ -3214,7 +3214,6 @@ func buildQueryMySQLCTE(builder *strings.Builder, filtered bool, downstream bool } func buildQueryMySQL(builder *strings.Builder) { - builder.WriteString(` SELECT E.id AS e_id, @@ -3247,7 +3246,6 @@ func buildQueryMySQL(builder *strings.Builder) { LEFT JOIN (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id `) - } type idFilterNode struct { From e8547ffc690a40c3d2ea10120138289f429a7f0d Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Mon, 4 Mar 2024 09:47:32 -0300 Subject: [PATCH 10/17] update go.mod & go.sum Signed-off-by: FedeNQ --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 621ab08957..556195ebdd 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,7 @@ require ( github.com/sigstore/sigstore v1.8.2 github.com/sirupsen/logrus v1.9.3 github.com/spiffe/go-spiffe/v2 v2.1.7 - github.com/spiffe/spire-api-sdk v1.2.5-0.20231107161112-ba57e0e943a2 + github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821 github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d github.com/stretchr/testify v1.8.4 github.com/uber-go/tally/v4 v4.1.11 diff --git a/go.sum b/go.sum index aeed5f9f89..ff54524386 100644 --- a/go.sum +++ b/go.sum @@ -1395,6 +1395,8 @@ github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoe github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= github.com/spiffe/spire-api-sdk v1.2.5-0.20231107161112-ba57e0e943a2 h1:EKSBig+9oEvyLUi80aE/88UHjoNCqlNGTFTjm02F+fk= github.com/spiffe/spire-api-sdk v1.2.5-0.20231107161112-ba57e0e943a2/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= +github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821 h1:ws5/mYxmiZtw/67nymx5hnSJo8Kx2Q1UkQqiSt8TU74= +github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d h1:LCRQGU6vOqKLfRrG+GJQrwMwDILcAddAEIf4/1PaSVc= github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d/go.mod h1:GA6o2PVLwyJdevT6KKt5ZXCY/ziAPna13y/seGk49Ik= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From f9ebb672551b1e87aad196f7d5d74fa2d5d78e25 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Mon, 4 Mar 2024 10:47:40 -0300 Subject: [PATCH 11/17] fix windows message Signed-off-by: FedeNQ --- .../cli/agent/agent_windows_test.go | 28 ++++++++++++++++--- go.sum | 2 -- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/cmd/spire-server/cli/agent/agent_windows_test.go b/cmd/spire-server/cli/agent/agent_windows_test.go index 965ab5c3a7..45ba7c7921 100644 --- a/cmd/spire-server/cli/agent/agent_windows_test.go +++ b/cmd/spire-server/cli/agent/agent_windows_test.go @@ -14,14 +14,22 @@ var ( Desired output format (pretty, json); default: pretty. ` listUsage = `Usage of agent list: + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") - -namedPipeName string - Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. -selector value A colon-delimited type:value selector. Can be used more than once + -socketPath string + Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` banUsage = `Usage of agent ban: -namedPipeName string @@ -40,10 +48,22 @@ var ( The SPIFFE ID of the agent to evict (agent identity) ` countUsage = `Usage of agent count: - -namedPipeName string - Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") + -attestationType string + Filter by attestation type, like join_token or x509pop. + -banned value + Filter based on string received, 'true': banned agents, 'false': not banned agents, other value will return all. + -canReattest value + Filter based on string received, 'true': agents that can reattest, 'false': agents that can't reattest, other value will return all. + -expiresBefore string + Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -output value Desired output format (pretty, json); default: pretty. + -selector value + A colon-delimited type:value selector. Can be used more than once + -socketPath string + Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` showUsage = `Usage of agent show: -namedPipeName string diff --git a/go.sum b/go.sum index ff54524386..3612d5aa40 100644 --- a/go.sum +++ b/go.sum @@ -1393,8 +1393,6 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV github.com/spiffe/go-spiffe/v2 v2.1.6/go.mod h1:eVDqm9xFvyqao6C+eQensb9ZPkyNEeaUbqbBpOhBnNk= github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= -github.com/spiffe/spire-api-sdk v1.2.5-0.20231107161112-ba57e0e943a2 h1:EKSBig+9oEvyLUi80aE/88UHjoNCqlNGTFTjm02F+fk= -github.com/spiffe/spire-api-sdk v1.2.5-0.20231107161112-ba57e0e943a2/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821 h1:ws5/mYxmiZtw/67nymx5hnSJo8Kx2Q1UkQqiSt8TU74= github.com/spiffe/spire-api-sdk v1.2.5-0.20240301205221-967353a5c821/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d h1:LCRQGU6vOqKLfRrG+GJQrwMwDILcAddAEIf4/1PaSVc= From 7751c40d87c6b732346c6717f4f095d7a14a2d3e Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Mon, 4 Mar 2024 11:59:34 -0300 Subject: [PATCH 12/17] update agent & entry message Signed-off-by: FedeNQ --- cmd/spire-server/cli/agent/agent_windows_test.go | 8 ++++---- cmd/spire-server/cli/entry/util_windows_test.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/cmd/spire-server/cli/agent/agent_windows_test.go b/cmd/spire-server/cli/agent/agent_windows_test.go index 45ba7c7921..924ecfb5e9 100644 --- a/cmd/spire-server/cli/agent/agent_windows_test.go +++ b/cmd/spire-server/cli/agent/agent_windows_test.go @@ -24,12 +24,12 @@ var ( Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") + -namedPipeName string + Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. -selector value A colon-delimited type:value selector. Can be used more than once - -socketPath string - Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` banUsage = `Usage of agent ban: -namedPipeName string @@ -58,12 +58,12 @@ var ( Filter by expiration time (format: "2006-01-02 15:04:05 -0700 -07") -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") + -namedPipeName string + Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. -selector value A colon-delimited type:value selector. Can be used more than once - -socketPath string - Path to the SPIRE Server API socket (default "/tmp/spire-server/private/api.sock") ` showUsage = `Usage of agent show: -namedPipeName string diff --git a/cmd/spire-server/cli/entry/util_windows_test.go b/cmd/spire-server/cli/entry/util_windows_test.go index 75e1d1929b..18f5c88af4 100644 --- a/cmd/spire-server/cli/entry/util_windows_test.go +++ b/cmd/spire-server/cli/entry/util_windows_test.go @@ -112,9 +112,25 @@ const ( Desired output format (pretty, json); default: pretty. ` countUsage = `Usage of entry count: + -downstream + A boolean value that, when set, indicates that the entry describes a downstream SPIRE server + -federatesWith value + SPIFFE ID of a trust domain an entry is federate with. Can be used more than once + -hint string + The Hint of the records to count (optional) + -matchFederatesWithOn string + The match mode used when filtering by federates with. Options: exact, any, superset and subset (default "superset") + -matchSelectorsOn string + The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. + -parentID string + The Parent ID of the records to count + -selector value + A colon-delimited type:value selector. Can be used more than once + -spiffeID string + The SPIFFE ID of the records to count ` ) From 268ccd0415cbfae902218273253dd9e33a0d9032 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Mon, 4 Mar 2024 12:38:44 -0300 Subject: [PATCH 13/17] update agent message Signed-off-by: FedeNQ --- cmd/spire-server/cli/agent/agent_windows_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/spire-server/cli/agent/agent_windows_test.go b/cmd/spire-server/cli/agent/agent_windows_test.go index 924ecfb5e9..7b98b75005 100644 --- a/cmd/spire-server/cli/agent/agent_windows_test.go +++ b/cmd/spire-server/cli/agent/agent_windows_test.go @@ -25,7 +25,7 @@ var ( -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string - Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") + Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. -selector value @@ -59,7 +59,7 @@ var ( -matchSelectorsOn string The match mode used when filtering by selectors. Options: exact, any, superset and subset (default "superset") -namedPipeName string - Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") + Pipe name of the SPIRE Server API named pipe (default "\\spire-server\\private\\api") -output value Desired output format (pretty, json); default: pretty. -selector value From efd0548992590437dfe865d39c167401e423eddc Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 27 Mar 2024 12:07:18 -0300 Subject: [PATCH 14/17] count entries & agent now uses pagination Signed-off-by: FedeNQ --- pkg/server/api/agent/v1/service.go | 1 - pkg/server/datastore/sqlstore/sqlstore.go | 107 ++++++++++++++++------ 2 files changed, 81 insertions(+), 27 deletions(-) diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index 408b63678e..1efcebf14f 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -143,7 +143,6 @@ func (s *Service) ListAgents(ctx context.Context, req *agentv1.ListAgentsRequest } // err is verified previously - // countReq.ByExpiresBefore, _ = time.Parse("2006-01-02", filter.ByExpiresBefore) if filter.ByExpiresBefore != "" { listReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) } diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index c4d69d30ea..d26d0b905d 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -312,15 +312,8 @@ func (ds *Plugin) FetchAttestedNode(ctx context.Context, spiffeID string) (attes // CountAttestedNodes counts all attested nodes func (ds *Plugin) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (count int32, err error) { if countAttestedNodesHasFilters(req) { - resp, err := listAttestedNodes(ctx, ds.db, ds.log, &datastore.ListAttestedNodesRequest{ - ByAttestationType: req.ByAttestationType, - ByBanned: req.ByBanned, - ByExpiresBefore: req.ByExpiresBefore, - BySelectorMatch: req.BySelectorMatch, - FetchSelectors: req.FetchSelectors, - ByCanReattest: req.ByCanReattest, - }) - return int32(len(resp.Nodes)), err + resp, err := countAttestedNodesWithFilters(ctx, ds.db, ds.log, req) + return resp, err } if err = ds.withReadTx(ctx, func(tx *gorm.DB) (err error) { count, err = countAttestedNodes(tx) @@ -1620,6 +1613,48 @@ func listAttestedNodes(ctx context.Context, db *sqlDB, log logrus.FieldLogger, r } } +func countAttestedNodesWithFilters(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountAttestedNodesRequest) (int32, error) { + if req.BySelectorMatch != nil && len(req.BySelectorMatch.Selectors) == 0 { + return -1, status.Error(codes.InvalidArgument, "cannot list by empty selectors set") + } + + var val int32 = 0 + listReq := &datastore.ListAttestedNodesRequest{ + ByAttestationType: req.ByAttestationType, + ByBanned: req.ByBanned, + ByExpiresBefore: req.ByExpiresBefore, + BySelectorMatch: req.BySelectorMatch, + FetchSelectors: req.FetchSelectors, + ByCanReattest: req.ByCanReattest, + Pagination: &datastore.Pagination{ + Token: "", + PageSize: 1000, + }, + } + for { + resp, err := listAttestedNodesOnce(ctx, db, listReq) + if err != nil { + return -1, err + } + + if len(resp.Nodes) == 0 { + return val, nil + } + + if req.BySelectorMatch != nil { + switch req.BySelectorMatch.Match { + case datastore.Exact, datastore.Subset: + resp.Nodes = filterNodesBySelectorSet(resp.Nodes, req.BySelectorMatch.Selectors) + default: + } + } + + val += int32(len(resp.Nodes)) + + listReq.Pagination = resp.Pagination + } +} + func createAttestedNodeEvent(tx *gorm.DB, spiffeID string) error { newAttestedNodeEvent := AttestedNodeEvent{ SpiffeID: spiffeID, @@ -2944,29 +2979,49 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn } // Count Registration Entries -func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { +func countRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { if req.BySelectors != nil && len(req.BySelectors.Selectors) == 0 { return 0, status.Error(codes.InvalidArgument, "cannot list by empty selector set") } - // Exact/subset selector matching requires filtering out all registration - // entries returned by the query whose selectors are not fully represented - // in the request selectors. For this reason, it's possible that a paged - // query returns rows that are completely filtered out. If that happens, - // keep querying until a page gets at least one result. - resp, err := listRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, - &datastore.ListRegistrationEntriesRequest{ - DataConsistency: req.DataConsistency, - ByParentID: req.ByParentID, - BySelectors: req.BySelectors, - BySpiffeID: req.BySpiffeID, - ByFederatesWith: req.ByFederatesWith, - ByHint: req.ByHint, - ByDownstream: req.ByDownstream, + var val int32 = 0 + listReq := &datastore.ListRegistrationEntriesRequest{ + DataConsistency: req.DataConsistency, + ByParentID: req.ByParentID, + BySelectors: req.BySelectors, + BySpiffeID: req.BySpiffeID, + ByFederatesWith: req.ByFederatesWith, + ByHint: req.ByHint, + ByDownstream: req.ByDownstream, + Pagination: &datastore.Pagination{ + Token: "", + PageSize: 1000, }, - ) + } + + for { + resp, err := listRegistrationEntriesOnce(ctx, db.raw, db.databaseType, db.supportsCTE, listReq) - return int32(len(resp.Entries)), err + if err != nil { + return -1, err + } + + if len(resp.Entries) == 0 { + return val, nil + } + + if req.BySelectors != nil { + switch req.BySelectors.Match { + case datastore.Exact, datastore.Subset: + resp.Entries = filterEntriesBySelectorSet(resp.Entries, req.BySelectors.Selectors) + default: + } + } + + val += int32(len(resp.Entries)) + + listReq.Pagination = resp.Pagination + } } func buildQuerySQLite3(builder *strings.Builder, filtered bool, downstream bool) { From 1475ee41f902ca0e55377bf923ed302013f7a830 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 27 Mar 2024 12:08:29 -0300 Subject: [PATCH 15/17] remove comment Signed-off-by: FedeNQ --- pkg/server/api/agent/v1/service.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/server/api/agent/v1/service.go b/pkg/server/api/agent/v1/service.go index 1efcebf14f..b9889c056b 100644 --- a/pkg/server/api/agent/v1/service.go +++ b/pkg/server/api/agent/v1/service.go @@ -90,7 +90,6 @@ func (s *Service) CountAgents(ctx context.Context, req *agentv1.CountAgentsReque countReq.ByAttestationType = filter.ByAttestationType } - // err is verified previously if filter.ByExpiresBefore != "" { countReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) } @@ -142,7 +141,6 @@ func (s *Service) ListAgents(ctx context.Context, req *agentv1.ListAgentsRequest listReq.ByAttestationType = filter.ByAttestationType } - // err is verified previously if filter.ByExpiresBefore != "" { listReq.ByExpiresBefore, _ = time.Parse("2006-01-02 15:04:05 -0700 -07", filter.ByExpiresBefore) } From a3bb1fe817c0ed3ba43d8d36cf16c9fd26dbf572 Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 27 Mar 2024 12:48:59 -0300 Subject: [PATCH 16/17] fix lint Signed-off-by: FedeNQ --- pkg/server/datastore/sqlstore/sqlstore.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index d26d0b905d..85c39752d0 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -1613,12 +1613,12 @@ func listAttestedNodes(ctx context.Context, db *sqlDB, log logrus.FieldLogger, r } } -func countAttestedNodesWithFilters(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountAttestedNodesRequest) (int32, error) { +func countAttestedNodesWithFilters(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountAttestedNodesRequest) (int32, error) { if req.BySelectorMatch != nil && len(req.BySelectorMatch.Selectors) == 0 { return -1, status.Error(codes.InvalidArgument, "cannot list by empty selectors set") } - var val int32 = 0 + var val int32 listReq := &datastore.ListAttestedNodesRequest{ ByAttestationType: req.ByAttestationType, ByBanned: req.ByBanned, @@ -2979,12 +2979,12 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn } // Count Registration Entries -func countRegistrationEntries(ctx context.Context, db *sqlDB, log logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { +func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogger, req *datastore.CountRegistrationEntriesRequest) (int32, error) { if req.BySelectors != nil && len(req.BySelectors.Selectors) == 0 { return 0, status.Error(codes.InvalidArgument, "cannot list by empty selector set") } - var val int32 = 0 + var val int32 listReq := &datastore.ListRegistrationEntriesRequest{ DataConsistency: req.DataConsistency, ByParentID: req.ByParentID, From 296a77fe77c582d75d0e0add66c29082b50e723a Mon Sep 17 00:00:00 2001 From: FedeNQ Date: Wed, 27 Mar 2024 14:46:05 -0300 Subject: [PATCH 17/17] rollback Signed-off-by: FedeNQ --- pkg/server/datastore/sqlstore/sqlstore.go | 553 +++++++++++----------- 1 file changed, 267 insertions(+), 286 deletions(-) diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 85c39752d0..096a59717d 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -2894,8 +2894,85 @@ func buildListRegistrationEntriesQuerySQLite3(req *datastore.ListRegistrationEnt if err != nil { return "", nil, err } + if filtered { + builder.WriteString(")") + } + + builder.WriteString(` +SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl +FROM + registered_entries +`) + + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` +UNION + +SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL +FROM + bundles B +INNER JOIN + federated_registration_entries F +ON + B.id = F.bundle_id +`) + if filtered { + builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION - buildQuerySQLite3(builder, filtered, downstream) +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL +FROM + dns_names +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION + +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL +FROM + selectors +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +ORDER BY e_id, selector_id, dns_name_id +;`) return builder.String(), args, nil } @@ -2912,8 +2989,84 @@ func buildListRegistrationEntriesQueryPostgreSQL(req *datastore.ListRegistration if err != nil { return "", nil, err } + if filtered { + builder.WriteString(")") + } - buildQueryPostgreSQL(builder, filtered, downstream) + builder.WriteString(` +SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL ::integer AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL ::integer AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl +FROM + registered_entries +`) + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` +UNION ALL + +SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL +FROM + bundles B +INNER JOIN + federated_registration_entries F +ON + B.id = F.bundle_id +`) + if filtered { + builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION ALL + +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL +FROM + dns_names +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION ALL + +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL +FROM + selectors +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +ORDER BY e_id, selector_id, dns_name_id +;`) return postgreSQLRebind(builder.String()), args, nil } @@ -2933,8 +3086,39 @@ func postgreSQLRebind(s string) string { func buildListRegistrationEntriesQueryMySQL(req *datastore.ListRegistrationEntriesRequest) (string, []any, error) { builder := new(strings.Builder) + builder.WriteString(` +SELECT + E.id AS e_id, + E.entry_id AS entry_id, + E.spiffe_id, + E.parent_id, + E.ttl AS reg_ttl, + E.admin, + E.downstream, + E.expiry, + E.store_svid, + E.hint, + E.created_at, + S.id AS selector_id, + S.type AS selector_type, + S.value AS selector_value, + B.trust_domain, + D.id AS dns_name_id, + D.value AS dns_name, + E.revision_number, + E.jwt_svid_ttl AS reg_jwt_svid_ttl +FROM + registered_entries E +LEFT JOIN + (SELECT 1 AS joinItem UNION SELECT 2 UNION SELECT 3) AS joinItems ON TRUE +LEFT JOIN + selectors S ON joinItem=1 AND E.id=S.registered_entry_id +LEFT JOIN + dns_names D ON joinItem=2 AND E.id=D.registered_entry_id +LEFT JOIN + (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id +`) - buildQueryMySQL(builder) filtered, args, err := appendListRegistrationEntriesFilterQuery("WHERE E.id IN (\n", builder, MySQL, req) var downstream = false if req.ByDownstream != nil { @@ -2955,9 +3139,7 @@ func buildListRegistrationEntriesQueryMySQL(req *datastore.ListRegistrationEntri builder.WriteString("\t\tAND downstream = true\n") } } - builder.WriteString(` - ORDER BY e_id, selector_id, dns_name_id - `) + builder.WriteString("\nORDER BY e_id, selector_id, dns_name_id\n;") return builder.String(), args, nil } @@ -2974,7 +3156,85 @@ func buildListRegistrationEntriesQueryMySQLCTE(req *datastore.ListRegistrationEn if err != nil { return "", nil, err } - buildQueryMySQLCTE(builder, filtered, downstream) + if filtered { + builder.WriteString(")") + } + + builder.WriteString(` +SELECT + id AS e_id, + entry_id, + spiffe_id, + parent_id, + ttl AS reg_ttl, + admin, + downstream, + expiry, + store_svid, + hint, + created_at, + NULL AS selector_id, + NULL AS selector_type, + NULL AS selector_value, + NULL AS trust_domain, + NULL AS dns_name_id, + NULL AS dns_name, + revision_number, + jwt_svid_ttl AS reg_jwt_svid_ttl +FROM + registered_entries +`) + if filtered { + builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") + } + if downstream { + if !filtered { + builder.WriteString("\t\tWHERE downstream = true\n") + } else { + builder.WriteString("\t\tAND downstream = true\n") + } + } + builder.WriteString(` +UNION + +SELECT + F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL +FROM + bundles B +INNER JOIN + federated_registration_entries F +ON + B.id = F.bundle_id +`) + if filtered { + builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION + +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL +FROM + dns_names +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +UNION + +SELECT + registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL +FROM + selectors +`) + if filtered { + builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") + } + builder.WriteString(` +ORDER BY e_id, selector_id, dns_name_id +;`) + return builder.String(), args, nil } @@ -3024,285 +3284,6 @@ func countRegistrationEntries(ctx context.Context, db *sqlDB, _ logrus.FieldLogg } } -func buildQuerySQLite3(builder *strings.Builder, filtered bool, downstream bool) { - if filtered { - builder.WriteString(")") - } - - builder.WriteString(` - SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl - FROM - registered_entries - `) - - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString(` - UNION - - SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL - FROM - bundles B - INNER JOIN - federated_registration_entries F - ON - B.id = F.bundle_id - `) - if filtered { - builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL - FROM - dns_names - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL - FROM - selectors - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - ORDER BY e_id, selector_id, dns_name_id - `) -} - -func buildQueryPostgreSQL(builder *strings.Builder, filtered bool, downstream bool) { - if filtered { - builder.WriteString(")") - } - - builder.WriteString(` - SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL ::integer AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL ::integer AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl - FROM - registered_entries - `) - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString(` - UNION ALL - - SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL - FROM - bundles B - INNER JOIN - federated_registration_entries F - ON - B.id = F.bundle_id - `) - if filtered { - builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION ALL - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL - FROM - dns_names - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION ALL - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL - FROM - selectors - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - ORDER BY e_id, selector_id, dns_name_id - `) -} - -func buildQueryMySQLCTE(builder *strings.Builder, filtered bool, downstream bool) { - if filtered { - builder.WriteString(")") - } - - builder.WriteString(` - SELECT - id AS e_id, - entry_id, - spiffe_id, - parent_id, - ttl AS reg_ttl, - admin, - downstream, - expiry, - store_svid, - hint, - created_at, - NULL AS selector_id, - NULL AS selector_type, - NULL AS selector_value, - NULL AS trust_domain, - NULL AS dns_name_id, - NULL AS dns_name, - revision_number, - jwt_svid_ttl AS reg_jwt_svid_ttl - FROM - registered_entries - `) - if filtered { - builder.WriteString("WHERE id IN (SELECT e_id FROM listing)\n") - } - if downstream { - if !filtered { - builder.WriteString("\t\tWHERE downstream = true\n") - } else { - builder.WriteString("\t\tAND downstream = true\n") - } - } - builder.WriteString(` - UNION - - SELECT - F.registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, B.trust_domain, NULL, NULL, NULL, NULL - FROM - bundles B - INNER JOIN - federated_registration_entries F - ON - B.id = F.bundle_id - `) - if filtered { - builder.WriteString("WHERE\n\tF.registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, value, NULL, NULL - FROM - dns_names - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - UNION - - SELECT - registered_entry_id, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, id, type, value, NULL, NULL, NULL, NULL, NULL - FROM - selectors - `) - if filtered { - builder.WriteString("WHERE registered_entry_id IN (SELECT e_id FROM listing)\n") - } - builder.WriteString(` - ORDER BY e_id, selector_id, dns_name_id - `) -} - -func buildQueryMySQL(builder *strings.Builder) { - builder.WriteString(` - SELECT - E.id AS e_id, - E.entry_id AS entry_id, - E.spiffe_id, - E.parent_id, - E.ttl AS reg_ttl, - E.admin, - E.downstream, - E.expiry, - E.store_svid, - E.hint, - E.created_at, - S.id AS selector_id, - S.type AS selector_type, - S.value AS selector_value, - B.trust_domain, - D.id AS dns_name_id, - D.value AS dns_name, - E.revision_number, - E.jwt_svid_ttl AS reg_jwt_svid_ttl - FROM - registered_entries E - LEFT JOIN - (SELECT 1 AS joinItem UNION SELECT 2 UNION SELECT 3) AS joinItems ON TRUE - LEFT JOIN - selectors S ON joinItem=1 AND E.id=S.registered_entry_id - LEFT JOIN - dns_names D ON joinItem=2 AND E.id=D.registered_entry_id - LEFT JOIN - (federated_registration_entries F INNER JOIN bundles B ON F.bundle_id=B.id) ON joinItem=3 AND E.id=F.registered_entry_id - `) -} - type idFilterNode struct { idColumn string