diff --git a/cmd/serve.go b/cmd/serve.go index 2e5edb7e3..6e5415b1d 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -116,6 +116,8 @@ func writePidFile(pidFile string) error { func serve(ctx context.Context) error { var resolverOpts []graphapi.Option + config.AppConfig.LoadBalancerLimit = viper.GetInt("load-balancer-limit") + if serveDevMode { enablePlayground = true config.AppConfig.Logging.Debug = true diff --git a/internal/config/config.go b/internal/config/config.go index a5569cee6..4c7c9f1d8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,17 +17,18 @@ import ( // AppConfig stores all the config values for our application var AppConfig struct { - OIDC echojwtx.AuthConfig `mapstructure:"oidc"` - OIDCClient OIDCClientConfig `mapstructure:"oidc"` - CRDB crdbx.Config - Logging loggingx.Config - Server echox.Config - Tracing otelx.Config - Events events.Config - Permissions permissions.Config - Metadata MetadataConfig - RestrictedPorts []int - Supergraph SupergraphConfig + OIDC echojwtx.AuthConfig `mapstructure:"oidc"` + OIDCClient OIDCClientConfig `mapstructure:"oidc"` + CRDB crdbx.Config + Logging loggingx.Config + Server echox.Config + Tracing otelx.Config + Events events.Config + Permissions permissions.Config + LoadBalancerLimit int + Metadata MetadataConfig + RestrictedPorts []int + Supergraph SupergraphConfig } // MetadataConfig stores the configuration for metadata diff --git a/internal/graphapi/errors.go b/internal/graphapi/errors.go index a6f33a4a9..0e9b1e55b 100644 --- a/internal/graphapi/errors.go +++ b/internal/graphapi/errors.go @@ -17,4 +17,7 @@ var ( // ErrInternalServerError is returned when an internal error occurs. ErrInternalServerError = errors.New("internal server error") + + // ErrLoadBalancerLimitReached is returned when the load balancer limit has been reached for an owner. + ErrLoadBalancerLimitReached = errors.New("load balancer limit reached") ) diff --git a/internal/graphapi/loadbalancer.resolvers.go b/internal/graphapi/loadbalancer.resolvers.go index 3dfa6d9bd..57e8cfc02 100644 --- a/internal/graphapi/loadbalancer.resolvers.go +++ b/internal/graphapi/loadbalancer.resolvers.go @@ -14,7 +14,9 @@ import ( "go.infratographer.com/load-balancer-api/pkg/metadata" + "go.infratographer.com/load-balancer-api/internal/config" "go.infratographer.com/load-balancer-api/internal/ent/generated" + "go.infratographer.com/load-balancer-api/internal/ent/generated/loadbalancer" "go.infratographer.com/load-balancer-api/internal/ent/generated/port" "go.infratographer.com/load-balancer-api/internal/ent/generated/predicate" ) @@ -25,6 +27,18 @@ func (r *mutationResolver) LoadBalancerCreate(ctx context.Context, input generat return nil, err } + if config.AppConfig.LoadBalancerLimit > 0 { + count, err := r.client.LoadBalancer.Query().Where(predicate.LoadBalancer(loadbalancer.OwnerIDEQ(input.OwnerID))).Count(ctx) + + if err != nil { + r.logger.Errorw("failed to query loadbalancer count", "error", err) + } + + if count >= config.AppConfig.LoadBalancerLimit { + return nil, ErrLoadBalancerLimitReached + } + } + lb, err := r.client.LoadBalancer.Create().SetInput(input).Save(ctx) if err != nil { if generated.IsValidationError(err) { diff --git a/internal/graphapi/loadbalancer_test.go b/internal/graphapi/loadbalancer_test.go index abd91dbcb..f4249724c 100644 --- a/internal/graphapi/loadbalancer_test.go +++ b/internal/graphapi/loadbalancer_test.go @@ -12,6 +12,7 @@ import ( "go.infratographer.com/permissions-api/pkg/permissions/mockpermissions" "go.infratographer.com/x/gidx" + "go.infratographer.com/load-balancer-api/internal/config" ent "go.infratographer.com/load-balancer-api/internal/ent/generated" "go.infratographer.com/load-balancer-api/internal/graphclient" "go.infratographer.com/load-balancer-api/internal/testutils" @@ -158,6 +159,65 @@ func TestCreate_loadBalancer(t *testing.T) { } } +func TestCreate_loadBalancer_limit(t *testing.T) { + ctx := context.Background() + perms := new(mockpermissions.MockPermissions) + perms.On("CreateAuthRelationships", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + ctx = perms.ContextWithHandler(ctx) + + // Permit request + ctx = context.WithValue(ctx, permissions.CheckerCtxKey, permissions.DefaultAllowChecker) + + prov := (&testutils.ProviderBuilder{}).MustNew(ctx) + locationID := gidx.MustNewID(locationPrefix) + name := gofakeit.DomainName() + + config.AppConfig.LoadBalancerLimit = 3 + + testCases := []struct { + TestName string + lbCount int + Input graphclient.CreateLoadBalancerInput + errorMsg string + }{ + { + TestName: "creates loadbalancers - under limit", + Input: graphclient.CreateLoadBalancerInput{Name: name, ProviderID: prov.ID, OwnerID: gidx.MustNewID(ownerPrefix), LocationID: locationID}, + lbCount: 2, + }, + { + TestName: "fails to create loadbalancers - over limit", + Input: graphclient.CreateLoadBalancerInput{Name: name, ProviderID: prov.ID, OwnerID: gidx.MustNewID(ownerPrefix), LocationID: locationID}, + lbCount: 5, + errorMsg: "load balancer limit reached", + }, + } + + for _, tt := range testCases { + t.Run(tt.TestName, func(t *testing.T) { + tt := tt + t.Parallel() + var err error + + for i := 1; i < tt.lbCount; i++ { + _, err = graphTestClient().LoadBalancerCreate(ctx, tt.Input) + if err != nil { + return + } + } + + if tt.errorMsg != "" { + require.Error(t, err) + assert.ErrorContains(t, err, tt.errorMsg) + return + } + + require.NoError(t, err) + }) + } +} + func TestUpdate_loadBalancer(t *testing.T) { ctx := context.Background() perms := new(mockpermissions.MockPermissions)