From cf552c48a2b22aafe66c015151f3c7dd08f28a2b Mon Sep 17 00:00:00 2001 From: Gabriele Gerbino Date: Tue, 11 Jul 2023 14:51:42 +0200 Subject: [PATCH] feat: support scoping plugins to consumer-groups --- file/builder.go | 27 +- file/builder_test.go | 42 ++- file/codegen/main.go | 1 + file/kong_json_schema.json | 4 +- file/types.go | 38 ++- go.mod | 2 +- go.sum | 4 +- state/builder.go | 21 ++ state/plugin.go | 48 ++- state/plugin_test.go | 55 ++- state/types.go | 10 + tests/integration/sync_test.go | 320 ++++++++++++++++++ tests/integration/test_utils.go | 6 + .../kong3x.yaml | 76 +++++ types/plugin.go | 22 +- utils/utils.go | 10 + 16 files changed, 622 insertions(+), 64 deletions(-) create mode 100644 tests/integration/testdata/sync/023-consumer-groups-scoped-plugins/kong3x.yaml diff --git a/file/builder.go b/file/builder.go index a0286e215..9c2908a17 100644 --- a/file/builder.go +++ b/file/builder.go @@ -116,6 +116,12 @@ func (b *stateBuilder) consumerGroups() { ConsumerGroup: &cg.ConsumerGroup, } + err := b.intermediate.ConsumerGroups.Add(state.ConsumerGroup{ConsumerGroup: cg.ConsumerGroup}) + if err != nil { + b.err = err + return + } + for _, plugin := range cg.Plugins { if utils.Empty(plugin.ID) { current, err := b.currentState.ConsumerGroupPlugins.Get( @@ -882,6 +888,18 @@ func (b *stateBuilder) plugins() { } p.Route = utils.GetRouteReference(r.Route) } + if p.ConsumerGroup != nil && !utils.Empty(p.ConsumerGroup.ID) { + cg, err := b.intermediate.ConsumerGroups.Get(*p.ConsumerGroup.ID) + if errors.Is(err, state.ErrNotFound) { + b.err = fmt.Errorf("consumer-group %v for plugin %v: %w", + p.ConsumerGroup.FriendlyName(), *p.Name, err) + return + } else if err != nil { + b.err = err + return + } + p.ConsumerGroup = utils.GetConsumerGroupReference(cg.ConsumerGroup) + } plugins = append(plugins, p) } if err := b.ingestPlugins(plugins); err != nil { @@ -997,9 +1015,9 @@ func (b *stateBuilder) ingestPlugins(plugins []FPlugin) error { for _, p := range plugins { p := p if utils.Empty(p.ID) { - cID, rID, sID := pluginRelations(&p.Plugin) + cID, rID, sID, cgID := pluginRelations(&p.Plugin) plugin, err := b.currentState.Plugins.GetByProp(*p.Name, - sID, rID, cID) + sID, rID, cID, cgID) if errors.Is(err, state.ErrNotFound) { p.ID = uuid() } else if err != nil { @@ -1044,7 +1062,7 @@ func (b *stateBuilder) fillPluginConfig(plugin *FPlugin) error { return nil } -func pluginRelations(plugin *kong.Plugin) (cID, rID, sID string) { +func pluginRelations(plugin *kong.Plugin) (cID, rID, sID, cgID string) { if plugin.Consumer != nil && !utils.Empty(plugin.Consumer.ID) { cID = *plugin.Consumer.ID } @@ -1054,6 +1072,9 @@ func pluginRelations(plugin *kong.Plugin) (cID, rID, sID string) { if plugin.Service != nil && !utils.Empty(plugin.Service.ID) { sID = *plugin.Service.ID } + if plugin.ConsumerGroup != nil && !utils.Empty(plugin.ConsumerGroup.ID) { + cgID = *plugin.ConsumerGroup.ID + } return } diff --git a/file/builder_test.go b/file/builder_test.go index 0d3557786..1b79c2886 100644 --- a/file/builder_test.go +++ b/file/builder_test.go @@ -293,6 +293,9 @@ func existingPluginState() *state.KongState { Route: &kong.Route{ ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, }, }) return s @@ -751,6 +754,9 @@ func Test_stateBuilder_ingestPlugins(t *testing.T) { Route: &kong.Route{ ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, }, }, }, @@ -780,6 +786,9 @@ func Test_stateBuilder_ingestPlugins(t *testing.T) { Route: &kong.Route{ ID: kong.String("700bc504-b2b1-4abd-bd38-cec92779659e"), }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("69ed4618-a653-4b54-8bb6-dc33bd6fe048"), + }, Config: kong.Configuration{}, }, }, @@ -805,11 +814,12 @@ func Test_pluginRelations(t *testing.T) { plugin *kong.Plugin } tests := []struct { - name string - args args - wantCID string - wantRID string - wantSID string + name string + args args + wantCID string + wantRID string + wantSID string + wantCGID string }{ { args: args{ @@ -817,9 +827,10 @@ func Test_pluginRelations(t *testing.T) { Name: kong.String("foo"), }, }, - wantCID: "", - wantRID: "", - wantSID: "", + wantCID: "", + wantRID: "", + wantSID: "", + wantCGID: "", }, { args: args{ @@ -834,16 +845,20 @@ func Test_pluginRelations(t *testing.T) { Service: &kong.Service{ ID: kong.String("sID"), }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cgID"), + }, }, }, - wantCID: "cID", - wantRID: "rID", - wantSID: "sID", + wantCID: "cID", + wantRID: "rID", + wantSID: "sID", + wantCGID: "cgID", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotCID, gotRID, gotSID := pluginRelations(tt.args.plugin) + gotCID, gotRID, gotSID, gotCGID := pluginRelations(tt.args.plugin) if gotCID != tt.wantCID { t.Errorf("pluginRelations() gotCID = %v, want %v", gotCID, tt.wantCID) } @@ -853,6 +868,9 @@ func Test_pluginRelations(t *testing.T) { if gotSID != tt.wantSID { t.Errorf("pluginRelations() gotSID = %v, want %v", gotSID, tt.wantSID) } + if gotCGID != tt.wantCGID { + t.Errorf("pluginRelations() gotCGID = %v, want %v", gotCGID, tt.wantCGID) + } }) } } diff --git a/file/codegen/main.go b/file/codegen/main.go index be480d51f..552b70d0b 100644 --- a/file/codegen/main.go +++ b/file/codegen/main.go @@ -97,6 +97,7 @@ func main() { schema.Definitions["FPlugin"].Properties["consumer"] = stringType schema.Definitions["FPlugin"].Properties["service"] = stringType schema.Definitions["FPlugin"].Properties["route"] = stringType + schema.Definitions["FPlugin"].Properties["consumer_group"] = stringType schema.Definitions["FService"].Properties["client_certificate"] = stringType diff --git a/file/kong_json_schema.json b/file/kong_json_schema.json index b287a13ff..5e68bbe5f 100644 --- a/file/kong_json_schema.json +++ b/file/kong_json_schema.json @@ -444,7 +444,6 @@ }, "groups": { "items": { - "$schema": "http://json-schema.org/draft-04/schema#", "$ref": "#/definitions/ConsumerGroup" }, "type": "array" @@ -600,6 +599,9 @@ "consumer": { "type": "string" }, + "consumer_group": { + "type": "string" + }, "created_at": { "type": "integer" }, diff --git a/file/types.go b/file/types.go index 7c6aa44c5..ffe364eb1 100644 --- a/file/types.go +++ b/file/types.go @@ -321,19 +321,20 @@ type FPlugin struct { // foo is a shadow type of Plugin. // It is used for custom marshalling of plugin. type foo struct { - CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` - ID *string `json:"id,omitempty" yaml:"id,omitempty"` - Name *string `json:"name,omitempty" yaml:"name,omitempty"` - InstanceName *string `json:"instance_name,omitempty" yaml:"instance_name,omitempty"` - Config kong.Configuration `json:"config,omitempty" yaml:"config,omitempty"` - Service string `json:"service,omitempty" yaml:",omitempty"` - Consumer string `json:"consumer,omitempty" yaml:",omitempty"` - Route string `json:"route,omitempty" yaml:",omitempty"` - Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` - RunOn *string `json:"run_on,omitempty" yaml:"run_on,omitempty"` - Ordering *kong.PluginOrdering `json:"ordering,omitempty" yaml:"ordering,omitempty"` - Protocols []*string `json:"protocols,omitempty" yaml:"protocols,omitempty"` - Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` + CreatedAt *int `json:"created_at,omitempty" yaml:"created_at,omitempty"` + ID *string `json:"id,omitempty" yaml:"id,omitempty"` + Name *string `json:"name,omitempty" yaml:"name,omitempty"` + InstanceName *string `json:"instance_name,omitempty" yaml:"instance_name,omitempty"` + Config kong.Configuration `json:"config,omitempty" yaml:"config,omitempty"` + Service string `json:"service,omitempty" yaml:",omitempty"` + Consumer string `json:"consumer,omitempty" yaml:",omitempty"` + ConsumerGroup string `json:"consumer_group,omitempty" yaml:",omitempty"` + Route string `json:"route,omitempty" yaml:",omitempty"` + Enabled *bool `json:"enabled,omitempty" yaml:"enabled,omitempty"` + RunOn *string `json:"run_on,omitempty" yaml:"run_on,omitempty"` + Ordering *kong.PluginOrdering `json:"ordering,omitempty" yaml:"ordering,omitempty"` + Protocols []*string `json:"protocols,omitempty" yaml:"protocols,omitempty"` + Tags []*string `json:"tags,omitempty" yaml:"tags,omitempty"` ConfigSource *string `json:"_config,omitempty" yaml:"_config,omitempty"` } @@ -379,6 +380,9 @@ func copyToFoo(p FPlugin) foo { if p.Plugin.Service != nil { f.Service = *p.Plugin.Service.ID } + if p.Plugin.ConsumerGroup != nil { + f.ConsumerGroup = *p.Plugin.ConsumerGroup.ID + } return f } @@ -428,6 +432,11 @@ func copyFromFoo(f foo, p *FPlugin) { ID: kong.String(f.Service), } } + if f.ConsumerGroup != "" { + p.ConsumerGroup = &kong.ConsumerGroup{ + ID: kong.String(f.ConsumerGroup), + } + } } // MarshalYAML is a custom marshal method to handle @@ -480,6 +489,9 @@ func (p FPlugin) sortKey() string { if p.Service != nil { key += *p.Service.ID } + if p.ConsumerGroup != nil { + key += *p.ConsumerGroup.ID + } return key } if p.ID != nil { diff --git a/go.mod b/go.mod index 9cbc1e044..78073707e 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hexops/gotextdiff v1.0.3 github.com/imdario/mergo v0.3.16 - github.com/kong/go-kong v0.44.0 + github.com/kong/go-kong v0.45.1-0.20230707124609-5236d86ec5d6 github.com/mitchellh/go-homedir v1.1.0 github.com/shirou/gopsutil/v3 v3.23.6 github.com/spf13/cobra v1.7.0 diff --git a/go.sum b/go.sum index 8cbab3c6d..208fab1e4 100644 --- a/go.sum +++ b/go.sum @@ -202,8 +202,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kong/go-kong v0.44.0 h1:1x3w/TYdJjIZ6c1j9HiYP8755c923XN2O6j3kEaUkTA= -github.com/kong/go-kong v0.44.0/go.mod h1:41Sot1N/n8UHBp+gE/6nOw3vuzoHbhMSyU/zOS7VzPE= +github.com/kong/go-kong v0.45.1-0.20230707124609-5236d86ec5d6 h1:VFPUX0r0dW+lEPJ0ytTQcY4WD2d+HMcW/lPXFFWyo9w= +github.com/kong/go-kong v0.45.1-0.20230707124609-5236d86ec5d6/go.mod h1:41Sot1N/n8UHBp+gE/6nOw3vuzoHbhMSyU/zOS7VzPE= github.com/kong/semver/v4 v4.0.1 h1:DIcNR8W3gfx0KabFBADPalxxsp+q/5COwIFkkhrFQ2Y= github.com/kong/semver/v4 v4.0.1/go.mod h1:LImQ0oT15pJvSns/hs2laLca2zcYoHu5EsSNY0J6/QA= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= diff --git a/state/builder.go b/state/builder.go index cf7e1c20a..0e7b43cbd 100644 --- a/state/builder.go +++ b/state/builder.go @@ -57,6 +57,18 @@ func ensureConsumer(kongState *KongState, consumerID string) (bool, *kong.Consum return true, utils.GetConsumerReference(c.Consumer), nil } +func ensureConsumerGroup(kongState *KongState, consumerGroupID string) (bool, *kong.ConsumerGroup, error) { + c, err := kongState.ConsumerGroups.Get(consumerGroupID) + if err != nil { + if errors.Is(err, ErrNotFound) { + return false, nil, nil + } + return false, nil, fmt.Errorf("looking up consumer-group %q: %w", consumerGroupID, err) + + } + return true, utils.GetConsumerGroupReference(c.ConsumerGroup), nil +} + func buildKong(kongState *KongState, raw *utils.KongRawState) error { for _, s := range raw.Services { err := kongState.Services.Add(Service{Service: *s}) @@ -282,6 +294,15 @@ func buildKong(kongState *KongState, raw *utils.KongRawState) error { p.Consumer = c } } + if p.ConsumerGroup != nil && !utils.Empty(p.ConsumerGroup.ID) { + ok, cg, err := ensureConsumerGroup(kongState, *p.ConsumerGroup.ID) + if err != nil { + return err + } + if ok { + p.ConsumerGroup = cg + } + } err := kongState.Plugins.Add(Plugin{Plugin: *p}) if err != nil { return fmt.Errorf("inserting plugins into state: %w", err) diff --git a/state/plugin.go b/state/plugin.go index 841ac1d6f..099893e95 100644 --- a/state/plugin.go +++ b/state/plugin.go @@ -12,10 +12,11 @@ import ( var errPluginNameRequired = fmt.Errorf("name of plugin required") const ( - pluginTableName = "plugin" - pluginsByServiceID = "pluginsByServiceID" - pluginsByRouteID = "pluginsByRouteID" - pluginsByConsumerID = "pluginsByConsumerID" + pluginTableName = "plugin" + pluginsByServiceID = "pluginsByServiceID" + pluginsByRouteID = "pluginsByRouteID" + pluginsByConsumerID = "pluginsByConsumerID" + pluginsByConsumerGroupID = "pluginsByConsumerGroupID" ) var pluginTableSchema = &memdb.TableSchema{ @@ -68,6 +69,18 @@ var pluginTableSchema = &memdb.TableSchema{ }, AllowMissing: true, }, + pluginsByConsumerGroupID: { + Name: pluginsByConsumerGroupID, + Indexer: &indexers.SubFieldIndexer{ + Fields: []indexers.Field{ + { + Struct: "ConsumerGroup", + Sub: "ID", + }, + }, + }, + AllowMissing: true, + }, // combined foreign fields // FIXME bug: collision if svc/route/consumer has the same ID // and same type of plugin is created. Consider the case when only @@ -92,6 +105,10 @@ var pluginTableSchema = &memdb.TableSchema{ Struct: "Consumer", Sub: "ID", }, + { + Struct: "ConsumerGroup", + Sub: "ID", + }, }, }, }, @@ -133,7 +150,7 @@ func insertPlugin(txn *memdb.Txn, plugin Plugin) error { } // err out if another plugin with exact same combination is present - sID, rID, cID := "", "", "" + sID, rID, cID, cgID := "", "", "", "" if plugin.Service != nil && !utils.Empty(plugin.Service.ID) { sID = *plugin.Service.ID } @@ -143,7 +160,10 @@ func insertPlugin(txn *memdb.Txn, plugin Plugin) error { if plugin.Consumer != nil && !utils.Empty(plugin.Consumer.ID) { cID = *plugin.Consumer.ID } - _, err = getPluginBy(txn, *plugin.Name, sID, rID, cID) + if plugin.ConsumerGroup != nil && !utils.Empty(plugin.ConsumerGroup.ID) { + cgID = *plugin.ConsumerGroup.ID + } + _, err = getPluginBy(txn, *plugin.Name, sID, rID, cID, cgID) if err == nil { return fmt.Errorf("inserting plugin %v: %w", plugin.Console(), ErrAlreadyExists) } else if !errors.Is(err, ErrNotFound) { @@ -194,7 +214,7 @@ func (k *PluginsCollection) GetAllByName(name string) ([]*Plugin, error) { return k.getAllPluginsBy("name", name) } -func getPluginBy(txn *memdb.Txn, name, svcID, routeID, consumerID string) ( +func getPluginBy(txn *memdb.Txn, name, svcID, routeID, consumerID, consumerGroupID string) ( *Plugin, error, ) { if name == "" { @@ -202,7 +222,7 @@ func getPluginBy(txn *memdb.Txn, name, svcID, routeID, consumerID string) ( } res, err := txn.First(pluginTableName, "fields", - name, svcID, routeID, consumerID) + name, svcID, routeID, consumerID, consumerGroupID) if err != nil { return nil, err } @@ -217,18 +237,18 @@ func getPluginBy(txn *memdb.Txn, name, svcID, routeID, consumerID string) ( } // GetByProp returns a plugin which matches all the properties passed in -// the arguments. If serviceID, routeID and consumerID are empty strings, then -// a global plugin is searched. +// the arguments. If serviceID, routeID, consumerID and consumerGroupID +// are empty strings, then a global plugin is searched. // Otherwise, a plugin with name and the supplied foreign references is // searched. // name is required. -func (k *PluginsCollection) GetByProp(name, serviceID, - routeID string, consumerID string, +func (k *PluginsCollection) GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID string, ) (*Plugin, error) { txn := k.db.Txn(false) defer txn.Abort() - return getPluginBy(txn, name, serviceID, routeID, consumerID) + return getPluginBy(txn, name, serviceID, routeID, consumerID, consumerGroupID) } func (k *PluginsCollection) getAllPluginsBy(index, identifier string) ( @@ -264,7 +284,7 @@ func (k *PluginsCollection) GetAllByServiceID(id string) ([]*Plugin, return k.getAllPluginsBy(pluginsByServiceID, id) } -// GetAllByRouteID returns all plugins referencing a service +// GetAllByRouteID returns all plugins referencing a route // by its id. func (k *PluginsCollection) GetAllByRouteID(id string) ([]*Plugin, error, diff --git a/state/plugin_test.go b/state/plugin_test.go index 5ba22171b..f808806c8 100644 --- a/state/plugin_test.go +++ b/state/plugin_test.go @@ -271,9 +271,25 @@ func TestPluginsCollection_Update(t *testing.T) { }, }, } + plugin4 := Plugin{ + Plugin: kong.Plugin{ + ID: kong.String("id4"), + Name: kong.String("key-auth"), + Route: &kong.Route{ + ID: kong.String("route1"), + }, + Service: &kong.Service{ + ID: kong.String("svc1"), + }, + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cg1"), + }, + }, + } k.Add(plugin1) k.Add(plugin2) k.Add(plugin3) + k.Add(plugin4) for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { @@ -362,6 +378,18 @@ func TestGetPluginByProp(t *testing.T) { }, }, }, + { + Plugin: kong.Plugin{ + ID: kong.String("5"), + Name: kong.String("key-auth"), + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("cg1"), + }, + Config: map[string]interface{}{ + "key5": "value5", + }, + }, + }, } assert := assert.New(t) collection := pluginsCollection() @@ -370,33 +398,38 @@ func TestGetPluginByProp(t *testing.T) { assert.Nil(collection.Add(p)) } - plugin, err := collection.GetByProp("", "", "", "") + plugin, err := collection.GetByProp("", "", "", "", "") assert.Nil(plugin) - assert.NotNil(err) + assert.Error(err) - plugin, err = collection.GetByProp("foo", "", "", "") + plugin, err = collection.GetByProp("foo", "", "", "", "") assert.Nil(plugin) assert.Equal(ErrNotFound, err) - plugin, err = collection.GetByProp("key-auth", "", "", "") - assert.Nil(err) + plugin, err = collection.GetByProp("key-auth", "", "", "", "") + assert.NoError(err) assert.NotNil(plugin) assert.Equal("value1", plugin.Config["key1"]) - plugin, err = collection.GetByProp("key-auth", "svc1", "", "") - assert.Nil(err) + plugin, err = collection.GetByProp("key-auth", "svc1", "", "", "") + assert.NoError(err) assert.NotNil(plugin) assert.Equal("value2", plugin.Config["key2"]) - plugin, err = collection.GetByProp("key-auth", "", "route1", "") - assert.Nil(err) + plugin, err = collection.GetByProp("key-auth", "", "route1", "", "") + assert.NoError(err) assert.NotNil(plugin) assert.Equal("value3", plugin.Config["key3"]) - plugin, err = collection.GetByProp("key-auth", "", "", "consumer1") - assert.Nil(err) + plugin, err = collection.GetByProp("key-auth", "", "", "consumer1", "") + assert.NoError(err) assert.NotNil(plugin) assert.Equal("value4", plugin.Config["key4"]) + + plugin, err = collection.GetByProp("key-auth", "", "", "", "cg1") + assert.NoError(err) + assert.NotNil(plugin) + assert.Equal("value5", plugin.Config["key5"]) } func TestPluginsInvalidType(t *testing.T) { diff --git a/state/types.go b/state/types.go index 88dcb5795..9bc923a6e 100644 --- a/state/types.go +++ b/state/types.go @@ -425,6 +425,9 @@ func (p1 *Plugin) Console() string { if p1.Consumer != nil { associations = append(associations, "consumer "+p1.Consumer.FriendlyName()) } + if p1.ConsumerGroup != nil { + associations = append(associations, "consumer-group "+p1.ConsumerGroup.FriendlyName()) + } if len(associations) > 0 { res += "for " } @@ -470,6 +473,7 @@ func (p1 *Plugin) EqualWithOpts(p2 *Plugin, ignoreID, p2Copy.Service = nil p2Copy.Route = nil p2Copy.Consumer = nil + p2Copy.ConsumerGroup = nil } if p1Copy.Service != nil { @@ -490,6 +494,12 @@ func (p1 *Plugin) EqualWithOpts(p2 *Plugin, ignoreID, if p2Copy.Consumer != nil { p2Copy.Consumer.Username = nil } + if p1Copy.ConsumerGroup != nil { + p1Copy.ConsumerGroup.Name = nil + } + if p2Copy.ConsumerGroup != nil { + p2Copy.ConsumerGroup.Name = nil + } return reflect.DeepEqual(p1Copy, p2Copy) } diff --git a/tests/integration/sync_test.go b/tests/integration/sync_test.go index 598266d7e..3a43eed1b 100644 --- a/tests/integration/sync_test.go +++ b/tests/integration/sync_test.go @@ -821,6 +821,167 @@ var ( Protocols: []*string{kong.String("http"), kong.String("https")}, }, } + + consumerGroupScopedPlugins = []*kong.Plugin{ + { + Name: kong.String("rate-limiting-advanced"), + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("77e6691d-67c0-446a-9401-27be2b141aae"), + }, + Config: kong.Configuration{ + "consumer_groups": nil, + "dictionary_name": string("kong_rate_limiting_counters"), + "disable_penalty": bool(false), + "enforce_consumer_groups": bool(false), + "error_code": float64(429), + "error_message": string("API rate limit exceeded"), + "header_name": nil, + "hide_client_headers": bool(false), + "identifier": string("consumer"), + "limit": []any{float64(10)}, + "namespace": string("gold"), + "path": nil, + "redis": map[string]any{ + "cluster_addresses": nil, + "connect_timeout": nil, + "database": float64(0), + "host": nil, + "keepalive_backlog": nil, + "keepalive_pool_size": float64(30), + "password": nil, + "port": nil, + "read_timeout": nil, + "send_timeout": nil, + "sentinel_addresses": nil, + "sentinel_master": nil, + "sentinel_password": nil, + "sentinel_role": nil, + "sentinel_username": nil, + "server_name": nil, + "ssl": false, + "ssl_verify": false, + "timeout": float64(2000), + "username": nil, + }, + "retry_after_jitter_max": float64(1), + "strategy": string("local"), + "sync_rate": float64(-1), + "window_size": []any{float64(60)}, + "window_type": string("sliding"), + }, + Enabled: kong.Bool(true), + Protocols: []*string{kong.String("grpc"), kong.String("grpcs"), kong.String("http"), kong.String("https")}, + }, + { + Name: kong.String("rate-limiting-advanced"), + ConsumerGroup: &kong.ConsumerGroup{ + ID: kong.String("5bcbd3a7-030b-4310-bd1d-2721ff85d236"), + }, + Config: kong.Configuration{ + "consumer_groups": nil, + "dictionary_name": string("kong_rate_limiting_counters"), + "disable_penalty": bool(false), + "enforce_consumer_groups": bool(false), + "error_code": float64(429), + "error_message": string("API rate limit exceeded"), + "header_name": nil, + "hide_client_headers": bool(false), + "identifier": string("consumer"), + "limit": []any{float64(7)}, + "namespace": string("silver"), + "path": nil, + "redis": map[string]any{ + "cluster_addresses": nil, + "connect_timeout": nil, + "database": float64(0), + "host": nil, + "keepalive_backlog": nil, + "keepalive_pool_size": float64(30), + "password": nil, + "port": nil, + "read_timeout": nil, + "send_timeout": nil, + "sentinel_addresses": nil, + "sentinel_master": nil, + "sentinel_password": nil, + "sentinel_role": nil, + "sentinel_username": nil, + "server_name": nil, + "ssl": false, + "ssl_verify": false, + "timeout": float64(2000), + "username": nil, + }, + "retry_after_jitter_max": float64(1), + "strategy": string("local"), + "sync_rate": float64(-1), + "window_size": []any{float64(60)}, + "window_type": string("sliding"), + }, + Enabled: kong.Bool(true), + Protocols: []*string{kong.String("grpc"), kong.String("grpcs"), kong.String("http"), kong.String("https")}, + }, + { + Name: kong.String("rate-limiting-advanced"), + Config: kong.Configuration{ + "consumer_groups": nil, + "dictionary_name": string("kong_rate_limiting_counters"), + "disable_penalty": bool(false), + "enforce_consumer_groups": bool(false), + "error_code": float64(429), + "error_message": string("API rate limit exceeded"), + "header_name": nil, + "hide_client_headers": bool(false), + "identifier": string("consumer"), + "limit": []any{float64(5)}, + "namespace": string("silver"), + "path": nil, + "redis": map[string]any{ + "cluster_addresses": nil, + "connect_timeout": nil, + "database": float64(0), + "host": nil, + "keepalive_backlog": nil, + "keepalive_pool_size": float64(30), + "password": nil, + "port": nil, + "read_timeout": nil, + "send_timeout": nil, + "sentinel_addresses": nil, + "sentinel_master": nil, + "sentinel_password": nil, + "sentinel_role": nil, + "sentinel_username": nil, + "server_name": nil, + "ssl": false, + "ssl_verify": false, + "timeout": float64(2000), + "username": nil, + }, + "retry_after_jitter_max": float64(1), + "strategy": string("local"), + "sync_rate": float64(-1), + "window_size": []any{float64(60)}, + "window_type": string("sliding"), + }, + Enabled: kong.Bool(true), + Protocols: []*string{kong.String("grpc"), kong.String("grpcs"), kong.String("http"), kong.String("https")}, + }, + { + Name: kong.String("key-auth"), + Config: kong.Configuration{ + "anonymous": nil, + "hide_credentials": false, + "key_in_body": false, + "key_in_header": true, + "key_in_query": true, + "key_names": []interface{}{"apikey"}, + "run_on_preflight": true, + }, + Enabled: kong.Bool(true), + Protocols: []*string{kong.String("http"), kong.String("https")}, + }, + } ) // test scope: @@ -3458,3 +3619,162 @@ func Test_Sync_UpdateWithExplicitIDsWithNoNames(t *testing.T) { }, }, ignoreFieldsIrrelevantForIDsTests) } + +// This test has 2 goals: +// - make sure consumer groups scoped plugins can be configured correctly in Kong +// - the actual consumer groups functionality works once set +// +// This is achieved via configuring: +// - 3 consumers: +// - 1 belonging to Gold Consumer Group +// - 1 belonging to Silver Consumer Group +// - 1 not belonging to any Consumer Group +// +// - 3 key-auths, one for each consumer +// - 1 global key-auth plugin +// - 2 consumer group +// - 1 global RLA plugin +// - 2 RLA plugins, scoped to the related consumer groups +// - 1 service pointing to mockbin.org +// - 1 route proxying the above service +// +// Once the configuration is verified to be matching in Kong, +// we then check whether the specific RLA configuration is correctly applied: consumers +// not belonging to the consumer group should be limited to 5 requests +// every 30s, while consumers belonging to the 'gold' and 'silver' consumer groups +// should be allowed to run respectively 10 and 7 requests in the same timeframe. +// In order to make sure this is the case, we run requests in a loop +// for all consumers and then check at what point they start to receive 429. +func Test_Sync_ConsumerGroupsScopedPlugins(t *testing.T) { + const ( + maxGoldRequestsNumber = 10 + maxSilverRequestsNumber = 7 + maxRegularRequestsNumber = 5 + ) + client, err := getTestClient() + if err != nil { + t.Errorf(err.Error()) + } + tests := []struct { + name string + kongFile string + expectedState utils.KongRawState + }{ + { + name: "creates consumer groups scoped plugins", + kongFile: "testdata/sync/023-consumer-groups-scoped-plugins/kong3x.yaml", + expectedState: utils.KongRawState{ + Consumers: consumerGroupsConsumers, + ConsumerGroups: []*kong.ConsumerGroupObject{ + { + ConsumerGroup: &kong.ConsumerGroup{ + Name: kong.String("silver"), + }, + Consumers: []*kong.Consumer{ + { + Username: kong.String("bar"), + }, + }, + }, + { + ConsumerGroup: &kong.ConsumerGroup{ + Name: kong.String("gold"), + }, + Consumers: []*kong.Consumer{ + { + Username: kong.String("foo"), + }, + }, + }, + }, + Plugins: consumerGroupScopedPlugins, + Services: svc1_207, + Routes: route1_20x, + KeyAuths: []*kong.KeyAuth{ + { + Consumer: &kong.Consumer{ + ID: kong.String("87095815-5395-454e-8c18-a11c9bc0ef04"), + }, + Key: kong.String("i-am-special"), + }, + { + Consumer: &kong.Consumer{ + ID: kong.String("5a5b9369-baeb-4faa-a902-c40ccdc2928e"), + }, + Key: kong.String("i-am-not-so-special"), + }, + { + Consumer: &kong.Consumer{ + ID: kong.String("e894ea9e-ad08-4acf-a960-5a23aa7701c7"), + }, + Key: kong.String("i-am-just-average"), + }, + }, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + runWhen(t, "enterprise", ">=3.4.0") + teardown := setup(t) + defer teardown(t) + + sync(tc.kongFile) + testKongState(t, client, false, tc.expectedState, nil) + + // Kong proxy may need a bit to be ready. + time.Sleep(time.Second * 10) + + // build simple http client + client := &http.Client{} + + // test 'foo' consumer (part of 'gold' group) + req, err := http.NewRequest("GET", "http://localhost:8000/r1", nil) + assert.NoError(t, err) + req.Header.Add("apikey", "i-am-special") + n := 0 + for n < 11 { + resp, err := client.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + if resp.StatusCode == http.StatusTooManyRequests { + break + } + n++ + } + assert.Equal(t, maxGoldRequestsNumber, n) + + // test 'bar' consumer (part of 'silver' group) + req, err = http.NewRequest("GET", "http://localhost:8000/r1", nil) + assert.NoError(t, err) + req.Header.Add("apikey", "i-am-not-so-special") + n = 0 + for n < 11 { + resp, err := client.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + if resp.StatusCode == http.StatusTooManyRequests { + break + } + n++ + } + assert.Equal(t, maxSilverRequestsNumber, n) + + // test 'baz' consumer (not part of any group) + req, err = http.NewRequest("GET", "http://localhost:8000/r1", nil) + assert.NoError(t, err) + req.Header.Add("apikey", "i-am-just-average") + n = 0 + for n < 11 { + resp, err := client.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() + if resp.StatusCode == http.StatusTooManyRequests { + break + } + n++ + } + assert.Equal(t, maxRegularRequestsNumber, n) + }) + } +} diff --git a/tests/integration/test_utils.go b/tests/integration/test_utils.go index 228db1e28..628b1fea9 100644 --- a/tests/integration/test_utils.go +++ b/tests/integration/test_utils.go @@ -129,6 +129,9 @@ func sortSlices(x, y interface{}) bool { if xEntity.Consumer != nil { xName += *xEntity.Consumer.ID } + if xEntity.ConsumerGroup != nil { + xName += *xEntity.ConsumerGroup.ID + } if yEntity.Route != nil { yName += *yEntity.Route.ID } @@ -138,6 +141,9 @@ func sortSlices(x, y interface{}) bool { if yEntity.Consumer != nil { yName += *yEntity.Consumer.ID } + if yEntity.ConsumerGroup != nil { + yName += *yEntity.ConsumerGroup.ID + } } return xName < yName } diff --git a/tests/integration/testdata/sync/023-consumer-groups-scoped-plugins/kong3x.yaml b/tests/integration/testdata/sync/023-consumer-groups-scoped-plugins/kong3x.yaml new file mode 100644 index 000000000..3f0af53ff --- /dev/null +++ b/tests/integration/testdata/sync/023-consumer-groups-scoped-plugins/kong3x.yaml @@ -0,0 +1,76 @@ +_format_version: "3.0" +services: +- connect_timeout: 60000 + id: 58076db2-28b6-423b-ba39-a797193017f7 + host: mockbin.org + name: svc1 + port: 80 + protocol: http + read_timeout: 60000 + retries: 5 + routes: + - name: r1 + id: 87b6a97e-f3f7-4c47-857a-7464cb9e202b + https_redirect_status_code: 301 + paths: + - /r1 + +consumer_groups: +- id: 5bcbd3a7-030b-4310-bd1d-2721ff85d236 + name: silver + consumers: + - username: bar + - username: baz +- id: 77e6691d-67c0-446a-9401-27be2b141aae + name: gold + consumers: + - username: foo +consumers: +- username: foo + keyauth_credentials: + - key: i-am-special + groups: + - name: gold +- username: bar + keyauth_credentials: + - key: i-am-not-so-special + groups: + - name: silver +- username: baz + keyauth_credentials: + - key: i-am-just-average +plugins: +- name: key-auth + enabled: true + protocols: + - http + - https +- name: rate-limiting-advanced + config: + namespace: silver + limit: + - 5 + retry_after_jitter_max: 1 + window_size: + - 60 + window_type: sliding +- name: rate-limiting-advanced + consumer_group: silver + config: + namespace: silver + limit: + - 7 + retry_after_jitter_max: 1 + window_size: + - 60 + window_type: sliding +- name: rate-limiting-advanced + consumer_group: gold + config: + namespace: gold + limit: + - 10 + retry_after_jitter_max: 1 + window_size: + - 60 + window_type: sliding diff --git a/types/plugin.go b/types/plugin.go index 44339c0f2..fc3a59525 100644 --- a/types/plugin.go +++ b/types/plugin.go @@ -26,6 +26,9 @@ func stripPluginReferencesName(plugin *state.Plugin) { if plugin.Plugin.Consumer != nil && plugin.Plugin.Consumer.Username != nil { plugin.Plugin.Consumer.Username = nil } + if plugin.Plugin.ConsumerGroup != nil && plugin.Plugin.ConsumerGroup.Name != nil { + plugin.Plugin.ConsumerGroup.Name = nil + } } func pluginFromStruct(arg crud.Event) *state.Plugin { @@ -111,9 +114,10 @@ func (d *pluginDiffer) Deletes(handler func(crud.Event) error) error { func (d *pluginDiffer) deletePlugin(plugin *state.Plugin) (*crud.Event, error) { plugin = &state.Plugin{Plugin: *plugin.DeepCopy()} name := *plugin.Name - serviceID, routeID, consumerID := foreignNames(plugin) - _, err := d.targetState.Plugins.GetByProp(name, serviceID, routeID, - consumerID) + serviceID, routeID, consumerID, consumerGroupID := foreignNames(plugin) + _, err := d.targetState.Plugins.GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID, + ) if errors.Is(err, state.ErrNotFound) { return &crud.Event{ Op: crud.Delete, @@ -151,9 +155,10 @@ func (d *pluginDiffer) CreateAndUpdates(handler func(crud.Event) error) error { func (d *pluginDiffer) createUpdatePlugin(plugin *state.Plugin) (*crud.Event, error) { plugin = &state.Plugin{Plugin: *plugin.DeepCopy()} name := *plugin.Name - serviceID, routeID, consumerID := foreignNames(plugin) - currentPlugin, err := d.currentState.Plugins.GetByProp(name, - serviceID, routeID, consumerID) + serviceID, routeID, consumerID, consumerGroupID := foreignNames(plugin) + currentPlugin, err := d.currentState.Plugins.GetByProp( + name, serviceID, routeID, consumerID, consumerGroupID, + ) if errors.Is(err, state.ErrNotFound) { // plugin not present, create it @@ -181,7 +186,7 @@ func (d *pluginDiffer) createUpdatePlugin(plugin *state.Plugin) (*crud.Event, er return nil, nil } -func foreignNames(p *state.Plugin) (serviceID, routeID, consumerID string) { +func foreignNames(p *state.Plugin) (serviceID, routeID, consumerID, consumerGroupID string) { if p == nil { return } @@ -194,5 +199,8 @@ func foreignNames(p *state.Plugin) (serviceID, routeID, consumerID string) { if p.Consumer != nil && p.Consumer.ID != nil { consumerID = *p.Consumer.ID } + if p.ConsumerGroup != nil && p.ConsumerGroup.ID != nil { + consumerGroupID = *p.ConsumerGroup.ID + } return } diff --git a/utils/utils.go b/utils/utils.go index e00f37c6b..bdb81864d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -148,6 +148,16 @@ func GetConsumerReference(c kong.Consumer) *kong.Consumer { return consumer } +// GetConsumerGroupReference returns a name+ID only copy of the input consumer-group, +// for use in references from other objects +func GetConsumerGroupReference(c kong.ConsumerGroup) *kong.ConsumerGroup { + consumerGroup := &kong.ConsumerGroup{ID: kong.String(*c.ID)} + if c.Name != nil { + consumerGroup.Name = kong.String(*c.Name) + } + return consumerGroup +} + // GetServiceReference returns a name+ID only copy of the input service, // for use in references from other objects func GetServiceReference(s kong.Service) *kong.Service {