From 438d3c2416a0ed6abe3407e57b262b20193bab83 Mon Sep 17 00:00:00 2001 From: kakcy Date: Tue, 28 Jan 2025 18:43:23 +0900 Subject: [PATCH 1/6] refactor: changed transaction handling related to Experiment package --- pkg/experiment/api/api.go | 31 +-- pkg/experiment/api/api_test.go | 12 +- pkg/experiment/api/experiment.go | 81 ++------ pkg/experiment/api/experiment_test.go | 188 +++++++++++-------- pkg/experiment/storage/v2/experiment.go | 16 +- pkg/experiment/storage/v2/experiment_test.go | 55 ++++-- 6 files changed, 205 insertions(+), 178 deletions(-) diff --git a/pkg/experiment/api/api.go b/pkg/experiment/api/api.go index 8148550f1..bec6fa6fa 100644 --- a/pkg/experiment/api/api.go +++ b/pkg/experiment/api/api.go @@ -25,6 +25,7 @@ import ( accountclient "github.com/bucketeer-io/bucketeer/pkg/account/client" autoopsclient "github.com/bucketeer-io/bucketeer/pkg/autoops/client" + storage "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2" featureclient "github.com/bucketeer-io/bucketeer/pkg/feature/client" "github.com/bucketeer-io/bucketeer/pkg/locale" "github.com/bucketeer-io/bucketeer/pkg/log" @@ -50,13 +51,14 @@ func WithLogger(l *zap.Logger) Option { } type experimentService struct { - featureClient featureclient.Client - accountClient accountclient.Client - autoOpsClient autoopsclient.Client - mysqlClient mysql.Client - publisher publisher.Publisher - opts *options - logger *zap.Logger + featureClient featureclient.Client + accountClient accountclient.Client + autoOpsClient autoopsclient.Client + mysqlClient mysql.Client + experimentStorage storage.ExperimentStorage + publisher publisher.Publisher + opts *options + logger *zap.Logger } func NewExperimentService( @@ -74,13 +76,14 @@ func NewExperimentService( opt(dopts) } return &experimentService{ - featureClient: featureClient, - accountClient: accountClient, - autoOpsClient: autoOpsClient, - mysqlClient: mysqlClient, - publisher: publisher, - opts: dopts, - logger: dopts.logger.Named("api"), + featureClient: featureClient, + accountClient: accountClient, + autoOpsClient: autoOpsClient, + mysqlClient: mysqlClient, + experimentStorage: storage.NewExperimentStorage(mysqlClient), + publisher: publisher, + opts: dopts, + logger: dopts.logger.Named("api"), } } diff --git a/pkg/experiment/api/api_test.go b/pkg/experiment/api/api_test.go index 6f1cdc4a3..3bdf8b218 100644 --- a/pkg/experiment/api/api_test.go +++ b/pkg/experiment/api/api_test.go @@ -26,6 +26,7 @@ import ( accountclientmock "github.com/bucketeer-io/bucketeer/pkg/account/client/mock" autoopsclientmock "github.com/bucketeer-io/bucketeer/pkg/autoops/client/mock" + storagemock "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2/mock" featureclientmock "github.com/bucketeer-io/bucketeer/pkg/feature/client/mock" publishermock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" "github.com/bucketeer-io/bucketeer/pkg/rpc" @@ -111,8 +112,15 @@ func createExperimentService(c *gomock.Controller, specifiedEnvironmentId *strin mysqlClient := mysqlmock.NewMockClient(c) p := publishermock.NewMockPublisher(c) p.EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - es := NewExperimentService(featureClientMock, accountClientMock, autoOpsClientMock, mysqlClient, p) - return es.(*experimentService) + return &experimentService{ + featureClient: featureClientMock, + accountClient: accountClientMock, + autoOpsClient: autoOpsClientMock, + mysqlClient: mysqlClient, + experimentStorage: storagemock.NewMockExperimentStorage(c), + publisher: p, + logger: zap.NewNop().Named("api"), + } } func createContextWithToken() context.Context { diff --git a/pkg/experiment/api/experiment.go b/pkg/experiment/api/experiment.go index 836458fd9..83c22f15d 100644 --- a/pkg/experiment/api/experiment.go +++ b/pkg/experiment/api/experiment.go @@ -58,8 +58,7 @@ func (s *experimentService) GetExperiment( if err := validateGetExperimentRequest(req, localizer); err != nil { return nil, err } - experimentStorage := v2es.NewExperimentStorage(s.mysqlClient) - experiment, err := experimentStorage.GetExperiment(ctx, req.Id, req.EnvironmentId) + experiment, err := s.experimentStorage.GetExperiment(ctx, req.Id, req.EnvironmentId) if err != nil { if errors.Is(err, v2es.ErrExperimentNotFound) { dt, err := statusNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -168,8 +167,7 @@ func (s *experimentService) ListExperiments( } return nil, dt.Err() } - experimentStorage := v2es.NewExperimentStorage(s.mysqlClient) - experiments, nextCursor, totalCount, summary, err := experimentStorage.ListExperiments( + experiments, nextCursor, totalCount, summary, err := s.experimentStorage.ListExperiments( ctx, whereParts, orders, @@ -332,25 +330,7 @@ func (s *experimentService) CreateExperiment( } return nil, dt.Err() } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - experimentStorage := v2es.NewExperimentStorage(tx) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { handler, err := command.NewExperimentCommandHandler( editor, experiment, @@ -363,7 +343,7 @@ func (s *experimentService) CreateExperiment( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return experimentStorage.CreateExperiment(ctx, experiment, req.EnvironmentId) + return s.experimentStorage.CreateExperiment(contextWithTx, experiment, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrExperimentAlreadyExists) { @@ -477,7 +457,6 @@ func (s *experimentService) createExperimentNoCommand( return statusGoalTypeMismatch.Err() } } - experimentStorage := v2es.NewExperimentStorage(s.mysqlClient) prev := &domain.Experiment{} if err = copier.Copy(prev, experiment); err != nil { return err @@ -513,7 +492,7 @@ func (s *experimentService) createExperimentNoCommand( if err != nil { return err } - return experimentStorage.CreateExperiment(ctxWithTx, experiment, req.EnvironmentId) + return s.experimentStorage.CreateExperiment(ctxWithTx, experiment, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalNotFound) { @@ -693,27 +672,9 @@ func (s *experimentService) UpdateExperiment( if err := validateUpdateExperimentRequest(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } var experimentPb *proto.Experiment - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - experimentStorage := v2es.NewExperimentStorage(tx) - experiment, err := experimentStorage.GetExperiment(ctx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + experiment, err := s.experimentStorage.GetExperiment(contextWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -737,7 +698,7 @@ func (s *experimentService) UpdateExperiment( ) return err } - return experimentStorage.UpdateExperiment(ctx, experiment, req.EnvironmentId) + return s.experimentStorage.UpdateExperiment(contextWithTx, experiment, req.EnvironmentId) } if req.ChangeNameCommand != nil { if err = handler.Handle(ctx, req.ChangeNameCommand); err != nil { @@ -764,7 +725,7 @@ func (s *experimentService) UpdateExperiment( } } experimentPb = experiment.Experiment - return experimentStorage.UpdateExperiment(ctx, experiment, req.EnvironmentId) + return s.experimentStorage.UpdateExperiment(contextWithTx, experiment, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrExperimentNotFound) || errors.Is(err, v2es.ErrExperimentUnexpectedAffectedRows) { @@ -1209,26 +1170,8 @@ func (s *experimentService) updateExperiment( id, environmentId string, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - experimentStorage := v2es.NewExperimentStorage(tx) - experiment, err := experimentStorage.GetExperiment(ctx, id, environmentId) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + experiment, err := s.experimentStorage.GetExperiment(contextWithTx, id, environmentId) if err != nil { s.logger.Error( "Failed to get experiment", @@ -1253,7 +1196,7 @@ func (s *experimentService) updateExperiment( ) return err } - return experimentStorage.UpdateExperiment(ctx, experiment, environmentId) + return s.experimentStorage.UpdateExperiment(contextWithTx, experiment, environmentId) }) if err != nil { if errors.Is(err, v2es.ErrExperimentNotFound) || errors.Is(err, v2es.ErrExperimentUnexpectedAffectedRows) { diff --git a/pkg/experiment/api/experiment_test.go b/pkg/experiment/api/experiment_test.go index ca52baf52..46cf7f896 100644 --- a/pkg/experiment/api/experiment_test.go +++ b/pkg/experiment/api/experiment_test.go @@ -27,7 +27,9 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/bucketeer-io/bucketeer/pkg/experiment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" @@ -68,11 +70,7 @@ func TestGetExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) }, id: "id-0", environmentId: "ns0", @@ -80,11 +78,9 @@ func TestGetExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) }, id: "id-1", environmentId: "ns0", @@ -140,18 +136,11 @@ func TestListExperimentsMySQL(t *testing.T) { orgRole: toPtr(accountproto.AccountV2_Role_Organization_MEMBER), envRole: toPtr(accountproto.AccountV2_Role_Environment_VIEWER), setup: func(s *experimentService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().ListExperiments( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*experimentproto.Experiment{ + {Id: "id-1"}, + }, 0, int64(0), nil, nil) }, req: &experimentproto.ListExperimentsRequest{FeatureId: "id-0", EnvironmentId: "ns0"}, expectedErr: nil, @@ -181,15 +170,19 @@ func TestCreateExperimentMySQL(t *testing.T) { }{ { setup: func(s *experimentService) { + // for goal storage row := mysqlmock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().CreateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, input: &experimentproto.CreateExperimentRequest{ Command: &experimentproto.CreateExperimentCommand{ @@ -291,7 +284,10 @@ func TestCreateExperimentNoCommandMySQL(t *testing.T) { setup: func(s *experimentService) { s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), - ).Return(nil) + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().CreateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, input: &experimentproto.CreateExperimentRequest{ FeatureId: "fid", @@ -460,10 +456,12 @@ func TestUpdateExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.UpdateExperimentRequest{ Id: "id-0", @@ -476,10 +474,17 @@ func TestUpdateExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + ).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.UpdateExperimentRequest{ Id: "id-1", @@ -638,10 +643,12 @@ func TestStartExperimentMySQL(t *testing.T) { { desc: "error not found", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.StartExperimentRequest{ Id: "noop", @@ -653,10 +660,15 @@ func TestStartExperimentMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.StartExperimentRequest{ Id: "eid", @@ -722,10 +734,12 @@ func TestFinishExperimentMySQL(t *testing.T) { { desc: "error not found", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.FinishExperimentRequest{ Id: "noop", @@ -737,10 +751,15 @@ func TestFinishExperimentMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.FinishExperimentRequest{ Id: "eid", @@ -802,10 +821,12 @@ func TestStopExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.StopExperimentRequest{ Id: "id-0", @@ -816,10 +837,15 @@ func TestStopExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.StopExperimentRequest{ Id: "id-1", @@ -879,10 +905,12 @@ func TestArchiveExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.ArchiveExperimentRequest{ Id: "id-0", @@ -893,10 +921,15 @@ func TestArchiveExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.ArchiveExperimentRequest{ Id: "id-1", @@ -956,10 +989,12 @@ func TestDeleteExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrExperimentNotFound) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, v2es.ErrExperimentNotFound) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(v2es.ErrExperimentNotFound) }, req: &experimentproto.DeleteExperimentRequest{ Id: "id-0", @@ -970,10 +1005,15 @@ func TestDeleteExperimentMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().GetExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Experiment{ + Experiment: &experimentproto.Experiment{Id: "id-1"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.experimentStorage.(*storagemock.MockExperimentStorage).EXPECT().UpdateExperiment(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) }, req: &experimentproto.DeleteExperimentRequest{ Id: "id-1", diff --git a/pkg/experiment/storage/v2/experiment.go b/pkg/experiment/storage/v2/experiment.go index f9c1e4f2a..5714bf514 100644 --- a/pkg/experiment/storage/v2/experiment.go +++ b/pkg/experiment/storage/v2/experiment.go @@ -57,11 +57,11 @@ type ExperimentStorage interface { } type experimentStorage struct { - qe mysql.QueryExecer + client mysql.Client } -func NewExperimentStorage(qe mysql.QueryExecer) ExperimentStorage { - return &experimentStorage{qe: qe} +func NewExperimentStorage(client mysql.Client) ExperimentStorage { + return &experimentStorage{client: client} } func (s *experimentStorage) CreateExperiment( @@ -69,7 +69,7 @@ func (s *experimentStorage) CreateExperiment( e *domain.Experiment, environmentId string, ) error { - _, err := s.qe.ExecContext( + _, err := s.client.Qe(ctx).ExecContext( ctx, insertExperimentSQL, e.Id, @@ -107,7 +107,7 @@ func (s *experimentStorage) UpdateExperiment( e *domain.Experiment, environmentId string, ) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, updateExperimentSQL, e.GoalId, @@ -150,7 +150,7 @@ func (s *experimentStorage) GetExperiment( ) (*domain.Experiment, error) { experiment := proto.Experiment{} var status int32 - err := s.qe.QueryRowContext( + err := s.client.Qe(ctx).QueryRowContext( ctx, selectExperimentSQL, id, @@ -197,7 +197,7 @@ func (s *experimentStorage) ListExperiments( orderBySQL := mysql.ConstructOrderBySQLString(orders) limitOffsetSQL := mysql.ConstructLimitOffsetSQLString(limit, offset) query := fmt.Sprintf(selectExperimentsSQL, whereSQL, orderBySQL, limitOffsetSQL) - rows, err := s.qe.QueryContext(ctx, query, whereArgs...) + rows, err := s.client.Qe(ctx).QueryContext(ctx, query, whereArgs...) if err != nil { return nil, 0, 0, nil, err } @@ -241,7 +241,7 @@ func (s *experimentStorage) ListExperiments( var totalCount int64 summary := &proto.ListExperimentsResponse_Summary{} countQuery := fmt.Sprintf(countExperimentSQL, whereSQL) - err = s.qe.QueryRowContext(ctx, countQuery, whereArgs...).Scan( + err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, whereArgs...).Scan( &totalCount, &summary.TotalWaitingCount, &summary.TotalRunningCount, diff --git a/pkg/experiment/storage/v2/experiment_test.go b/pkg/experiment/storage/v2/experiment_test.go index f04ea9472..4eab75210 100644 --- a/pkg/experiment/storage/v2/experiment_test.go +++ b/pkg/experiment/storage/v2/experiment_test.go @@ -33,7 +33,7 @@ func TestNewExperimentStorage(t *testing.T) { t.Parallel() mockController := gomock.NewController(t) defer mockController.Finish() - db := NewExperimentStorage(mock.NewMockQueryExecer(mockController)) + db := NewExperimentStorage(mock.NewMockClient(mockController)) assert.IsType(t, &experimentStorage{}, db) } @@ -50,7 +50,11 @@ func TestCreateExperiment(t *testing.T) { }{ { setup: func(s *experimentStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, mysql.ErrDuplicateEntry) }, @@ -62,7 +66,11 @@ func TestCreateExperiment(t *testing.T) { }, { setup: func(s *experimentStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, nil) }, @@ -100,7 +108,12 @@ func TestUpdateExperiment(t *testing.T) { setup: func(s *experimentStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(0), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -114,7 +127,11 @@ func TestUpdateExperiment(t *testing.T) { setup: func(s *experimentStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(1), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -151,7 +168,11 @@ func TestGetExperiment(t *testing.T) { setup: func(s *experimentStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -164,7 +185,11 @@ func TestGetExperiment(t *testing.T) { setup: func(s *experimentStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -202,7 +227,11 @@ func TestListExperiments(t *testing.T) { }{ { setup: func(s *experimentStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe) + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) }, @@ -220,12 +249,16 @@ func TestListExperiments(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe( + gomock.Any(), + ).Return(qe).AnyTimes() + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -262,5 +295,5 @@ func TestListExperiments(t *testing.T) { func newExperimentStorageWithMock(t *testing.T, mockController *gomock.Controller) *experimentStorage { t.Helper() - return &experimentStorage{mock.NewMockQueryExecer(mockController)} + return &experimentStorage{mock.NewMockClient(mockController)} } From ee61df55ad0477e9143a516e1b4f43bc2e3efb1b Mon Sep 17 00:00:00 2001 From: kakcy Date: Wed, 29 Jan 2025 18:34:15 +0900 Subject: [PATCH 2/6] refactor: changed transaction handling related to Experiment package --- pkg/experiment/api/api.go | 2 + pkg/experiment/api/api_test.go | 1 + pkg/experiment/api/experiment_test.go | 9 +- pkg/experiment/api/goal.go | 106 +++------------- pkg/experiment/api/goal_test.go | 162 ++++++++++++++++--------- pkg/experiment/storage/v2/goal.go | 18 +-- pkg/experiment/storage/v2/goal_test.go | 38 ++++-- 7 files changed, 163 insertions(+), 173 deletions(-) diff --git a/pkg/experiment/api/api.go b/pkg/experiment/api/api.go index bec6fa6fa..786127370 100644 --- a/pkg/experiment/api/api.go +++ b/pkg/experiment/api/api.go @@ -56,6 +56,7 @@ type experimentService struct { autoOpsClient autoopsclient.Client mysqlClient mysql.Client experimentStorage storage.ExperimentStorage + goalStorage storage.GoalStorage publisher publisher.Publisher opts *options logger *zap.Logger @@ -81,6 +82,7 @@ func NewExperimentService( autoOpsClient: autoOpsClient, mysqlClient: mysqlClient, experimentStorage: storage.NewExperimentStorage(mysqlClient), + goalStorage: storage.NewGoalStorage(mysqlClient), publisher: publisher, opts: dopts, logger: dopts.logger.Named("api"), diff --git a/pkg/experiment/api/api_test.go b/pkg/experiment/api/api_test.go index 3bdf8b218..ea5bcc294 100644 --- a/pkg/experiment/api/api_test.go +++ b/pkg/experiment/api/api_test.go @@ -118,6 +118,7 @@ func createExperimentService(c *gomock.Controller, specifiedEnvironmentId *strin autoOpsClient: autoOpsClientMock, mysqlClient: mysqlClient, experimentStorage: storagemock.NewMockExperimentStorage(c), + goalStorage: storagemock.NewMockGoalStorage(c), publisher: p, logger: zap.NewNop().Named("api"), } diff --git a/pkg/experiment/api/experiment_test.go b/pkg/experiment/api/experiment_test.go index 46cf7f896..7ccd3fd3e 100644 --- a/pkg/experiment/api/experiment_test.go +++ b/pkg/experiment/api/experiment_test.go @@ -170,13 +170,7 @@ func TestCreateExperimentMySQL(t *testing.T) { }{ { setup: func(s *experimentService) { - // for goal storage - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Goal{}, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { @@ -282,6 +276,7 @@ func TestCreateExperimentNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Goal{}, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { diff --git a/pkg/experiment/api/goal.go b/pkg/experiment/api/goal.go index 6a1a2c03d..2fba01bb4 100644 --- a/pkg/experiment/api/goal.go +++ b/pkg/experiment/api/goal.go @@ -97,8 +97,7 @@ func (s *experimentService) getGoalMySQL( ctx context.Context, goalID, environmentId string, ) (*domain.Goal, error) { - goalStorage := v2es.NewGoalStorage(s.mysqlClient) - goal, err := goalStorage.GetGoal(ctx, goalID, environmentId) + goal, err := s.goalStorage.GetGoal(ctx, goalID, environmentId) if err != nil { s.logger.Error( "Failed to get goal", @@ -164,8 +163,7 @@ func (s *experimentService) ListGoals( if req.IsInUseStatus != nil { isInUseStatus = &req.IsInUseStatus.Value } - goalStorage := v2es.NewGoalStorage(s.mysqlClient) - goals, nextCursor, totalCount, err := goalStorage.ListGoals( + goals, nextCursor, totalCount, err := s.goalStorage.ListGoals( ctx, whereParts, orders, @@ -320,25 +318,7 @@ func (s *experimentService) CreateGoal( } return nil, dt.Err() } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - goalStorage := v2es.NewGoalStorage(tx) + err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, _ mysql.Transaction) error { handler, err := command.NewGoalCommandHandler(editor, goal, s.publisher, req.EnvironmentId) if err != nil { return err @@ -346,7 +326,7 @@ func (s *experimentService) CreateGoal( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return goalStorage.CreateGoal(ctx, goal, req.EnvironmentId) + return s.goalStorage.CreateGoal(ctxWithTx, goal, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalAlreadyExists) { @@ -407,25 +387,7 @@ func (s *experimentService) createGoalNoCommand( } return nil, dt.Err() } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - goalStorage := v2es.NewGoalStorage(tx) + err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, _ mysql.Transaction) error { prev := &domain.Goal{} e, err := domainevent.NewEvent( editor, @@ -451,7 +413,7 @@ func (s *experimentService) createGoalNoCommand( if err := s.publisher.Publish(ctx, e); err != nil { return err } - return goalStorage.CreateGoal(ctx, goal, req.EnvironmentId) + return s.goalStorage.CreateGoal(ctxWithTx, goal, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalAlreadyExists) { @@ -625,28 +587,9 @@ func (s *experimentService) updateGoalNoCommand( if err != nil { return nil, err } - - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } var updatedGoal *proto.Goal - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - goalStorage := v2es.NewGoalStorage(tx) - goal, err := goalStorage.GetGoal(ctx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, _ mysql.Transaction) error { + goal, err := s.goalStorage.GetGoal(ctxWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -686,7 +629,7 @@ func (s *experimentService) updateGoalNoCommand( if err = s.publisher.Publish(ctx, e); err != nil { return err } - return goalStorage.UpdateGoal(ctx, updated, req.EnvironmentId) + return s.goalStorage.UpdateGoal(ctxWithTx, updated, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalNotFound) || errors.Is(err, v2es.ErrGoalUnexpectedAffectedRows) { @@ -813,9 +756,8 @@ func (s *experimentService) DeleteGoal( } return nil, dt.Err() } - err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, tx mysql.Transaction) error { - experimentStorage := v2es.NewGoalStorage(s.mysqlClient) - goal, err := experimentStorage.GetGoal(ctxWithTx, req.Id, req.EnvironmentId) + err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, _ mysql.Transaction) error { + goal, err := s.goalStorage.GetGoal(ctxWithTx, req.Id, req.EnvironmentId) if err != nil { return err } @@ -837,7 +779,7 @@ func (s *experimentService) DeleteGoal( if err := s.publisher.Publish(ctxWithTx, e); err != nil { return err } - return experimentStorage.DeleteGoal(ctxWithTx, req.Id, req.EnvironmentId) + return s.goalStorage.DeleteGoal(ctxWithTx, req.Id, req.EnvironmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalNotFound) || errors.Is(err, v2es.ErrGoalUnexpectedAffectedRows) { @@ -869,26 +811,8 @@ func (s *experimentService) updateGoal( commands []command.Command, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - goalStorage := v2es.NewGoalStorage(tx) - goal, err := goalStorage.GetGoal(ctx, goalID, environmentId) + err := s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, _ mysql.Transaction) error { + goal, err := s.goalStorage.GetGoal(ctxWithTx, goalID, environmentId) if err != nil { return err } @@ -901,7 +825,7 @@ func (s *experimentService) updateGoal( return err } } - return goalStorage.UpdateGoal(ctx, goal, environmentId) + return s.goalStorage.UpdateGoal(ctxWithTx, goal, environmentId) }) if err != nil { if errors.Is(err, v2es.ErrGoalNotFound) || errors.Is(err, v2es.ErrGoalUnexpectedAffectedRows) { diff --git a/pkg/experiment/api/goal_test.go b/pkg/experiment/api/goal_test.go index 5b981f09a..aa8ed2fde 100644 --- a/pkg/experiment/api/goal_test.go +++ b/pkg/experiment/api/goal_test.go @@ -18,22 +18,23 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "google.golang.org/genproto/googleapis/rpc/errdetails" - "google.golang.org/grpc/metadata" - gstatus "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/wrapperspb" - autoopsclientmock "github.com/bucketeer-io/bucketeer/pkg/autoops/client/mock" + "github.com/bucketeer-io/bucketeer/pkg/experiment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" accountproto "github.com/bucketeer-io/bucketeer/proto/account" autoopsproto "github.com/bucketeer-io/bucketeer/proto/autoops" experimentproto "github.com/bucketeer-io/bucketeer/proto/experiment" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/metadata" + gstatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestGetGoalMySQL(t *testing.T) { @@ -74,11 +75,9 @@ func TestGetGoalMySQL(t *testing.T) { { desc: "error: ErrNotFound", setup: func(s *experimentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + ).Return(nil, v2es.ErrGoalNotFound) }, id: "id-0", environmentId: "ns0", @@ -97,16 +96,18 @@ func TestGetGoalMySQL(t *testing.T) { orgRole: toPtr(accountproto.AccountV2_Role_Organization_MEMBER), envRole: toPtr(accountproto.AccountV2_Role_Environment_VIEWER), setup: func(s *experimentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) s.autoOpsClient.(*autoopsclientmock.MockClient).EXPECT().ListAutoOpsRules( gomock.Any(), gomock.Any(), ).Return(&autoopsproto.ListAutoOpsRulesResponse{ AutoOpsRules: []*autoopsproto.AutoOpsRule{}, }, nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) }, id: "id-1", environmentId: "ns0", @@ -165,18 +166,10 @@ func TestListGoalMySQL(t *testing.T) { orgRole: toPtr(accountproto.AccountV2_Role_Organization_MEMBER), envRole: toPtr(accountproto.AccountV2_Role_Environment_VIEWER), setup: func(s *experimentService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().ListGoals( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + ).Return([]*experimentproto.Goal{}, 0, int64(0), nil) s.autoOpsClient.(*autoopsclientmock.MockClient).EXPECT().ListAutoOpsRules( gomock.Any(), gomock.Any(), ).Return(&autoopsproto.ListAutoOpsRulesResponse{ @@ -248,9 +241,8 @@ func TestCreateGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrGoalAlreadyExists) }, req: &experimentproto.CreateGoalRequest{ @@ -261,8 +253,12 @@ func TestCreateGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().CreateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -340,9 +336,8 @@ func TestCreateGoalNoCommandMySQL(t *testing.T) { { desc: "error: ErrGoalAlreadyExists", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrGoalAlreadyExists) }, req: &experimentproto.CreateGoalRequest{ @@ -355,8 +350,12 @@ func TestCreateGoalNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().CreateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -410,9 +409,8 @@ func TestUpdateGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrGoalNotFound) }, req: &experimentproto.UpdateGoalRequest{ @@ -424,8 +422,19 @@ func TestUpdateGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().UpdateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -492,9 +501,8 @@ func TestUpdateGoalNoCommandMySQL(t *testing.T) { { desc: "error: not found", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrGoalNotFound) }, req: &experimentproto.UpdateGoalRequest{ @@ -507,8 +515,19 @@ func TestUpdateGoalNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().UpdateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -523,8 +542,19 @@ func TestUpdateGoalNoCommandMySQL(t *testing.T) { { desc: "success: archived goal", setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().UpdateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -586,9 +616,8 @@ func TestArchiveGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrGoalNotFound) }, req: &experimentproto.ArchiveGoalRequest{ @@ -600,8 +629,19 @@ func TestArchiveGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().UpdateGoal( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, @@ -667,8 +707,20 @@ func TestDeleteGoalMySQL(t *testing.T) { }, { setup: func(s *experimentService) { + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal( + gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Goal{ + Goal: &experimentproto.Goal{ + Id: "id-1", + }, + }, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().DeleteGoal( + gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil) }, req: &experimentproto.DeleteGoalRequest{ diff --git a/pkg/experiment/storage/v2/goal.go b/pkg/experiment/storage/v2/goal.go index ef71fa872..71d07a134 100644 --- a/pkg/experiment/storage/v2/goal.go +++ b/pkg/experiment/storage/v2/goal.go @@ -69,15 +69,15 @@ type GoalStorage interface { } type goalStorage struct { - qe mysql.QueryExecer + client mysql.Client } -func NewGoalStorage(qe mysql.QueryExecer) GoalStorage { - return &goalStorage{qe: qe} +func NewGoalStorage(client mysql.Client) GoalStorage { + return &goalStorage{client: client} } func (s *goalStorage) CreateGoal(ctx context.Context, g *domain.Goal, environmentId string) error { - _, err := s.qe.ExecContext( + _, err := s.client.Qe(ctx).ExecContext( ctx, insertGoalSQL, g.Id, @@ -100,7 +100,7 @@ func (s *goalStorage) CreateGoal(ctx context.Context, g *domain.Goal, environmen } func (s *goalStorage) UpdateGoal(ctx context.Context, g *domain.Goal, environmentId string) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, updateGoalSQL, g.Name, @@ -129,7 +129,7 @@ func (s *goalStorage) GetGoal(ctx context.Context, id, environmentId string) (*d goal := proto.Goal{} var connectionType int32 var experiments []experimentRef - err := s.qe.QueryRowContext( + err := s.client.Qe(ctx).QueryRowContext( ctx, selectGoalSQL, environmentId, // Case query @@ -191,7 +191,7 @@ func (s *goalStorage) ListGoals( } } query := fmt.Sprintf(selectGoalsSQL, whereSQL, isInUseStatusSQL, orderBySQL, limitOffsetSQL) - rows, err := s.qe.QueryContext(ctx, query, prepareArgs...) + rows, err := s.client.Qe(ctx).QueryContext(ctx, query, prepareArgs...) if err != nil { return nil, 0, 0, err } @@ -246,7 +246,7 @@ func (s *goalStorage) ListGoals( prepareCountArgs = append(prepareCountArgs, environmentId) prepareCountArgs = append(prepareCountArgs, whereArgs...) countQuery := fmt.Sprintf(countGoalSQL, countConditionSQL, whereSQL) - err = s.qe.QueryRowContext(ctx, countQuery, prepareCountArgs...).Scan(&totalCount) + err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, prepareArgs...).Scan(&totalCount) if err != nil { return nil, 0, 0, err } @@ -254,7 +254,7 @@ func (s *goalStorage) ListGoals( } func (s *goalStorage) DeleteGoal(ctx context.Context, id, environmentId string) error { - result, err := s.qe.ExecContext( + result, err := s.client.Qe(ctx).ExecContext( ctx, deleteGoalSQL, id, diff --git a/pkg/experiment/storage/v2/goal_test.go b/pkg/experiment/storage/v2/goal_test.go index b5494001b..fbe5a4b1b 100644 --- a/pkg/experiment/storage/v2/goal_test.go +++ b/pkg/experiment/storage/v2/goal_test.go @@ -33,7 +33,7 @@ func TestNewGoalStorage(t *testing.T) { t.Parallel() mockController := gomock.NewController(t) defer mockController.Finish() - db := NewGoalStorage(mock.NewMockQueryExecer(mockController)) + db := NewGoalStorage(mock.NewMockClient(mockController)) assert.IsType(t, &goalStorage{}, db) } @@ -50,7 +50,9 @@ func TestCreateGoal(t *testing.T) { }{ { setup: func(s *goalStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, mysql.ErrDuplicateEntry) }, @@ -62,7 +64,9 @@ func TestCreateGoal(t *testing.T) { }, { setup: func(s *goalStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, nil) }, @@ -100,7 +104,9 @@ func TestUpdateGoal(t *testing.T) { setup: func(s *goalStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(0), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -114,7 +120,9 @@ func TestUpdateGoal(t *testing.T) { setup: func(s *goalStorage) { result := mock.NewMockResult(mockController) result.EXPECT().RowsAffected().Return(int64(1), nil) - s.qe.(*mock.MockQueryExecer).EXPECT().ExecContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().ExecContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(result, nil) }, @@ -151,7 +159,9 @@ func TestGetGoal(t *testing.T) { setup: func(s *goalStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -164,7 +174,9 @@ func TestGetGoal(t *testing.T) { setup: func(s *goalStorage) { row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe) + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -204,7 +216,9 @@ func TestListGoals(t *testing.T) { }{ { setup: func(s *goalStorage) { - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe).AnyTimes() + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(nil, errors.New("error")) }, @@ -224,12 +238,14 @@ func TestListGoals(t *testing.T) { rows.EXPECT().Close().Return(nil) rows.EXPECT().Next().Return(false) rows.EXPECT().Err().Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryContext( + qe := mock.NewMockQueryExecer(mockController) + s.client.(*mock.MockClient).EXPECT().Qe(gomock.Any()).Return(qe).AnyTimes() + qe.EXPECT().QueryContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(rows, nil) row := mock.NewMockRow(mockController) row.EXPECT().Scan(gomock.Any()).Return(nil) - s.qe.(*mock.MockQueryExecer).EXPECT().QueryRowContext( + qe.EXPECT().QueryRowContext( gomock.Any(), gomock.Any(), gomock.Any(), ).Return(row) }, @@ -270,5 +286,5 @@ func TestListGoals(t *testing.T) { func newGoalStorageWithMock(t *testing.T, mockController *gomock.Controller) *goalStorage { t.Helper() - return &goalStorage{mock.NewMockQueryExecer(mockController)} + return &goalStorage{mock.NewMockClient(mockController)} } From 4d73a7ed8e712ada630522da8a97ab615b59c8c2 Mon Sep 17 00:00:00 2001 From: kakcy Date: Wed, 29 Jan 2025 18:36:35 +0900 Subject: [PATCH 3/6] fix: add GoalID check on request with NoCommand in CreateExperimentAPI --- pkg/experiment/api/experiment.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pkg/experiment/api/experiment.go b/pkg/experiment/api/experiment.go index 83c22f15d..325436b92 100644 --- a/pkg/experiment/api/experiment.go +++ b/pkg/experiment/api/experiment.go @@ -418,6 +418,29 @@ func (s *experimentService) createExperimentNoCommand( } return nil, dt.Err() } + for _, gid := range req.GoalIds { + _, err := s.getGoalMySQL(ctx, gid, req.EnvironmentId) + if err != nil { + if errors.Is(err, v2es.ErrGoalNotFound) { + dt, err := statusGoalNotFound.WithDetails(&errdetails.LocalizedMessage{ + Locale: localizer.GetLocale(), + Message: localizer.MustLocalize(locale.NotFoundError), + }) + if err != nil { + return nil, statusInternal.Err() + } + return nil, dt.Err() + } + dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ + Locale: localizer.GetLocale(), + Message: localizer.MustLocalize(locale.InternalServerError), + }) + if err != nil { + return nil, statusInternal.Err() + } + return nil, dt.Err() + } + } experiment, err := domain.NewExperiment( req.FeatureId, getFeatureResp.Feature.Version, From ce0966b016d0dd625e4c9cd1fd89d7b66872614a Mon Sep 17 00:00:00 2001 From: kakcy Date: Fri, 31 Jan 2025 11:48:55 +0900 Subject: [PATCH 4/6] refactor: make mockgen --- pkg/experiment/api/goal_test.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pkg/experiment/api/goal_test.go b/pkg/experiment/api/goal_test.go index aa8ed2fde..72eea3cc5 100644 --- a/pkg/experiment/api/goal_test.go +++ b/pkg/experiment/api/goal_test.go @@ -18,6 +18,14 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/metadata" + gstatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/wrapperspb" + autoopsclientmock "github.com/bucketeer-io/bucketeer/pkg/autoops/client/mock" "github.com/bucketeer-io/bucketeer/pkg/experiment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/experiment/storage/v2" @@ -28,13 +36,6 @@ import ( accountproto "github.com/bucketeer-io/bucketeer/proto/account" autoopsproto "github.com/bucketeer-io/bucketeer/proto/autoops" experimentproto "github.com/bucketeer-io/bucketeer/proto/experiment" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "google.golang.org/genproto/googleapis/rpc/errdetails" - "google.golang.org/grpc/metadata" - gstatus "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestGetGoalMySQL(t *testing.T) { From 1d1071a583283ff32289199e20918d15da19f76e Mon Sep 17 00:00:00 2001 From: kakcy Date: Thu, 6 Feb 2025 15:21:07 +0900 Subject: [PATCH 5/6] refactor: revert GoalIDs check on request with NoCommand --- pkg/experiment/api/experiment.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/pkg/experiment/api/experiment.go b/pkg/experiment/api/experiment.go index 325436b92..83c22f15d 100644 --- a/pkg/experiment/api/experiment.go +++ b/pkg/experiment/api/experiment.go @@ -418,29 +418,6 @@ func (s *experimentService) createExperimentNoCommand( } return nil, dt.Err() } - for _, gid := range req.GoalIds { - _, err := s.getGoalMySQL(ctx, gid, req.EnvironmentId) - if err != nil { - if errors.Is(err, v2es.ErrGoalNotFound) { - dt, err := statusGoalNotFound.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.NotFoundError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - } experiment, err := domain.NewExperiment( req.FeatureId, getFeatureResp.Feature.Version, From 7f915f4aaeac7e90fce809ea82764dab17aa93ca Mon Sep 17 00:00:00 2001 From: kakcy Date: Thu, 6 Feb 2025 18:47:33 +0900 Subject: [PATCH 6/6] test: fix test --- pkg/experiment/api/experiment_test.go | 4 +++- pkg/experiment/storage/v2/goal.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/experiment/api/experiment_test.go b/pkg/experiment/api/experiment_test.go index 7ccd3fd3e..56d07de20 100644 --- a/pkg/experiment/api/experiment_test.go +++ b/pkg/experiment/api/experiment_test.go @@ -276,7 +276,9 @@ func TestCreateExperimentNoCommandMySQL(t *testing.T) { { desc: "success", setup: func(s *experimentService) { - s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Goal{}, nil) + s.goalStorage.(*storagemock.MockGoalStorage).EXPECT().GetGoal(gomock.Any(), gomock.Any(), gomock.Any()).Return(&domain.Goal{ + Goal: &experimentproto.Goal{Id: "goalId", ConnectionType: experimentproto.Goal_EXPERIMENT}, + }, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { diff --git a/pkg/experiment/storage/v2/goal.go b/pkg/experiment/storage/v2/goal.go index 71d07a134..46fa22bf2 100644 --- a/pkg/experiment/storage/v2/goal.go +++ b/pkg/experiment/storage/v2/goal.go @@ -246,7 +246,7 @@ func (s *goalStorage) ListGoals( prepareCountArgs = append(prepareCountArgs, environmentId) prepareCountArgs = append(prepareCountArgs, whereArgs...) countQuery := fmt.Sprintf(countGoalSQL, countConditionSQL, whereSQL) - err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, prepareArgs...).Scan(&totalCount) + err = s.client.Qe(ctx).QueryRowContext(ctx, countQuery, prepareCountArgs...).Scan(&totalCount) if err != nil { return nil, 0, 0, err }