From 74ca90e45e1621aa77950b44d5c3a4a94f71f760 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 12:45:06 -0400 Subject: [PATCH 1/9] Add namespace to subscription database queries Signed-off-by: Andrew Richardson --- .../apiserver/route_delete_subscription.go | 2 +- .../route_delete_subscription_test.go | 2 +- .../apiserver/route_get_subscription_by_id.go | 2 +- .../route_get_subscription_by_id_test.go | 2 +- internal/apiserver/route_get_subscriptions.go | 2 +- .../apiserver/route_get_subscriptions_test.go | 2 +- .../apiserver/route_post_new_subscription.go | 2 +- .../route_post_new_subscription_test.go | 2 +- internal/apiserver/route_put_subscription.go | 2 +- .../apiserver/route_put_subscription_test.go | 2 +- .../database/sqlcommon/subscription_sql.go | 18 ++-- .../sqlcommon/subscription_sql_test.go | 20 ++--- internal/events/event_manager.go | 2 +- internal/events/event_manager_test.go | 6 +- internal/events/subscription_manager.go | 4 +- internal/events/subscription_manager_test.go | 30 +++---- internal/orchestrator/data_query.go | 4 - internal/orchestrator/orchestrator.go | 10 +-- internal/orchestrator/subscriptions.go | 27 +++--- internal/orchestrator/subscriptions_test.go | 86 +++++++++++-------- mocks/databasemocks/plugin.go | 42 ++++----- mocks/orchestratormocks/orchestrator.go | 70 +++++++-------- pkg/database/plugin.go | 6 +- 23 files changed, 179 insertions(+), 166 deletions(-) diff --git a/internal/apiserver/route_delete_subscription.go b/internal/apiserver/route_delete_subscription.go index 85bb1d77c..cc21ab61c 100644 --- a/internal/apiserver/route_delete_subscription.go +++ b/internal/apiserver/route_delete_subscription.go @@ -37,7 +37,7 @@ var deleteSubscription = &ffapi.Route{ JSONOutputCodes: []int{http.StatusNoContent}, // Sync operation, no output Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - err = cr.or.DeleteSubscription(cr.ctx, extractNamespace(r.PP), r.PP["subid"]) + err = cr.or.DeleteSubscription(cr.ctx, r.PP["subid"]) return nil, err }, }, diff --git a/internal/apiserver/route_delete_subscription_test.go b/internal/apiserver/route_delete_subscription_test.go index 6f5a66cff..7b1cb602d 100644 --- a/internal/apiserver/route_delete_subscription_test.go +++ b/internal/apiserver/route_delete_subscription_test.go @@ -39,7 +39,7 @@ func TestDeleteSubscription(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - o.On("DeleteSubscription", mock.Anything, "ns1", u.String()). + o.On("DeleteSubscription", mock.Anything, u.String()). Return(nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_subscription_by_id.go b/internal/apiserver/route_get_subscription_by_id.go index 549b04f27..e099ce7b4 100644 --- a/internal/apiserver/route_get_subscription_by_id.go +++ b/internal/apiserver/route_get_subscription_by_id.go @@ -38,7 +38,7 @@ var getSubscriptionByID = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - output, err = cr.or.GetSubscriptionByID(cr.ctx, extractNamespace(r.PP), r.PP["subid"]) + output, err = cr.or.GetSubscriptionByID(cr.ctx, r.PP["subid"]) return output, err }, }, diff --git a/internal/apiserver/route_get_subscription_by_id_test.go b/internal/apiserver/route_get_subscription_by_id_test.go index cf695d174..0a84a35e6 100644 --- a/internal/apiserver/route_get_subscription_by_id_test.go +++ b/internal/apiserver/route_get_subscription_by_id_test.go @@ -31,7 +31,7 @@ func TestGetSubscriptionByID(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - o.On("GetSubscriptionByID", mock.Anything, "mynamespace", "abcd12345"). + o.On("GetSubscriptionByID", mock.Anything, "abcd12345"). Return(&core.Subscription{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_subscriptions.go b/internal/apiserver/route_get_subscriptions.go index bc1118fde..ce72f5dbf 100644 --- a/internal/apiserver/route_get_subscriptions.go +++ b/internal/apiserver/route_get_subscriptions.go @@ -38,7 +38,7 @@ var getSubscriptions = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.SubscriptionQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.GetSubscriptions(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.GetSubscriptions(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_subscriptions_test.go b/internal/apiserver/route_get_subscriptions_test.go index 64a7c6a91..68801187b 100644 --- a/internal/apiserver/route_get_subscriptions_test.go +++ b/internal/apiserver/route_get_subscriptions_test.go @@ -31,7 +31,7 @@ func TestGetSubscriptions(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - o.On("GetSubscriptions", mock.Anything, "mynamespace", mock.Anything). + o.On("GetSubscriptions", mock.Anything, mock.Anything). Return([]*core.Subscription{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_new_subscription.go b/internal/apiserver/route_post_new_subscription.go index f165ce761..f32767911 100644 --- a/internal/apiserver/route_post_new_subscription.go +++ b/internal/apiserver/route_post_new_subscription.go @@ -36,7 +36,7 @@ var postNewSubscription = &ffapi.Route{ JSONOutputCodes: []int{http.StatusCreated}, // Sync operation Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - output, err = cr.or.CreateSubscription(cr.ctx, extractNamespace(r.PP), r.Input.(*core.Subscription)) + output, err = cr.or.CreateSubscription(cr.ctx, r.Input.(*core.Subscription)) return output, err }, }, diff --git a/internal/apiserver/route_post_new_subscription_test.go b/internal/apiserver/route_post_new_subscription_test.go index 0b4dfea2b..00ce58e61 100644 --- a/internal/apiserver/route_post_new_subscription_test.go +++ b/internal/apiserver/route_post_new_subscription_test.go @@ -36,7 +36,7 @@ func TestPostNewSubscription(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - o.On("CreateSubscription", mock.Anything, "ns1", mock.AnythingOfType("*core.Subscription")). + o.On("CreateSubscription", mock.Anything, mock.AnythingOfType("*core.Subscription")). Return(&core.Subscription{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_put_subscription.go b/internal/apiserver/route_put_subscription.go index 047b832e0..ff7e661e0 100644 --- a/internal/apiserver/route_put_subscription.go +++ b/internal/apiserver/route_put_subscription.go @@ -36,7 +36,7 @@ var putSubscription = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, // Sync operation Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - output, err = cr.or.CreateUpdateSubscription(cr.ctx, extractNamespace(r.PP), r.Input.(*core.Subscription)) + output, err = cr.or.CreateUpdateSubscription(cr.ctx, r.Input.(*core.Subscription)) return output, err }, }, diff --git a/internal/apiserver/route_put_subscription_test.go b/internal/apiserver/route_put_subscription_test.go index bdba24e07..0cb858b01 100644 --- a/internal/apiserver/route_put_subscription_test.go +++ b/internal/apiserver/route_put_subscription_test.go @@ -36,7 +36,7 @@ func TestPutSubscription(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - o.On("CreateUpdateSubscription", mock.Anything, "ns1", mock.AnythingOfType("*core.Subscription")). + o.On("CreateUpdateSubscription", mock.Anything, mock.AnythingOfType("*core.Subscription")). Return(&core.Subscription{}, nil) r.ServeHTTP(res, req) diff --git a/internal/database/sqlcommon/subscription_sql.go b/internal/database/sqlcommon/subscription_sql.go index ffac7cd6d..3c2c12e58 100644 --- a/internal/database/sqlcommon/subscription_sql.go +++ b/internal/database/sqlcommon/subscription_sql.go @@ -178,17 +178,19 @@ func (s *SQLCommon) getSubscriptionEq(ctx context.Context, eq sq.Eq, textName st return subscription, nil } -func (s *SQLCommon) GetSubscriptionByID(ctx context.Context, id *fftypes.UUID) (message *core.Subscription, err error) { - return s.getSubscriptionEq(ctx, sq.Eq{"id": id}, id.String()) +func (s *SQLCommon) GetSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) (message *core.Subscription, err error) { + return s.getSubscriptionEq(ctx, sq.Eq{"id": id, "namespace": namespace}, id.String()) } -func (s *SQLCommon) GetSubscriptionByName(ctx context.Context, ns, name string) (message *core.Subscription, err error) { - return s.getSubscriptionEq(ctx, sq.Eq{"namespace": ns, "name": name}, fmt.Sprintf("%s:%s", ns, name)) +func (s *SQLCommon) GetSubscriptionByName(ctx context.Context, namespace, name string) (message *core.Subscription, err error) { + return s.getSubscriptionEq(ctx, sq.Eq{"namespace": namespace, "name": name}, fmt.Sprintf("%s:%s", namespace, name)) } -func (s *SQLCommon) GetSubscriptions(ctx context.Context, filter database.Filter) (message []*core.Subscription, fr *database.FilterResult, err error) { +func (s *SQLCommon) GetSubscriptions(ctx context.Context, namespace string, filter database.Filter) (message []*core.Subscription, fr *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(subscriptionColumns...).From(subscriptionsTable), filter, subscriptionFilterFieldMap, []interface{}{"sequence"}) + query, fop, fi, err := s.filterSelect( + ctx, "", sq.Select(subscriptionColumns...).From(subscriptionsTable), + filter, subscriptionFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -245,7 +247,7 @@ func (s *SQLCommon) UpdateSubscription(ctx context.Context, namespace, name stri return s.commitTx(ctx, tx, autoCommit) } -func (s *SQLCommon) DeleteSubscriptionByID(ctx context.Context, id *fftypes.UUID) (err error) { +func (s *SQLCommon) DeleteSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) (err error) { ctx, tx, autoCommit, err := s.beginOrUseTx(ctx) if err != nil { @@ -253,7 +255,7 @@ func (s *SQLCommon) DeleteSubscriptionByID(ctx context.Context, id *fftypes.UUID } defer s.rollbackTx(ctx, tx, autoCommit) - subscription, err := s.GetSubscriptionByID(ctx, id) + subscription, err := s.GetSubscriptionByID(ctx, namespace, id) if err == nil && subscription != nil { err = s.deleteTx(ctx, subscriptionsTable, tx, sq.Delete(subscriptionsTable).Where(sq.Eq{ "id": id, diff --git a/internal/database/sqlcommon/subscription_sql_test.go b/internal/database/sqlcommon/subscription_sql_test.go index a4b448135..c8e12dc69 100644 --- a/internal/database/sqlcommon/subscription_sql_test.go +++ b/internal/database/sqlcommon/subscription_sql_test.go @@ -101,7 +101,7 @@ func TestSubscriptionsE2EWithDB(t *testing.T) { assert.NoError(t, err) // Check we get the exact same data back - note the removal of one of the subscription elements - subscriptionRead, err = s.GetSubscriptionByID(ctx, subscription.ID) + subscriptionRead, err = s.GetSubscriptionByID(ctx, "ns1", subscription.ID) assert.NoError(t, err) subscriptionJson, _ = json.Marshal(&subscriptionUpdated) subscriptionReadJson, _ = json.Marshal(&subscriptionRead) @@ -114,7 +114,7 @@ func TestSubscriptionsE2EWithDB(t *testing.T) { fb.Eq("namespace", subscriptionUpdated.Namespace), fb.Eq("name", subscriptionUpdated.Name), ) - subscriptionRes, res, err := s.GetSubscriptions(ctx, filter.Count(true)) + subscriptionRes, res, err := s.GetSubscriptions(ctx, "ns1", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(subscriptionRes)) assert.Equal(t, int64(1), *res.TotalCount) @@ -132,15 +132,15 @@ func TestSubscriptionsE2EWithDB(t *testing.T) { fb.Eq("name", subscriptionUpdated.Name), fb.Eq("created", updateTime.String()), ) - subscriptions, _, err := s.GetSubscriptions(ctx, filter) + subscriptions, _, err := s.GetSubscriptions(ctx, "ns1", filter) assert.NoError(t, err) assert.Equal(t, 1, len(subscriptions)) // Test delete, and refind no return s.callbacks.On("UUIDCollectionNSEvent", database.CollectionSubscriptions, core.ChangeEventTypeDeleted, "ns1", subscription.ID).Return() - err = s.DeleteSubscriptionByID(ctx, subscriptionUpdated.ID) + err = s.DeleteSubscriptionByID(ctx, "ns1", subscriptionUpdated.ID) assert.NoError(t, err) - subscriptions, _, err = s.GetSubscriptions(ctx, filter) + subscriptions, _, err = s.GetSubscriptions(ctx, "ns1", filter) assert.NoError(t, err) assert.Equal(t, 0, len(subscriptions)) @@ -228,7 +228,7 @@ func TestGetSubscriptionQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.SubscriptionQueryFactory.NewFilter(context.Background()).Eq("name", "") - _, _, err := s.GetSubscriptions(context.Background(), f) + _, _, err := s.GetSubscriptions(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -236,7 +236,7 @@ func TestGetSubscriptionQueryFail(t *testing.T) { func TestGetSubscriptionBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.SubscriptionQueryFactory.NewFilter(context.Background()).Eq("name", map[bool]bool{true: false}) - _, _, err := s.GetSubscriptions(context.Background(), f) + _, _, err := s.GetSubscriptions(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*type", err) } @@ -244,7 +244,7 @@ func TestGetSubscriptionReadMessageFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"ntype"}).AddRow("only one")) f := database.SubscriptionQueryFactory.NewFilter(context.Background()).Eq("name", "") - _, _, err := s.GetSubscriptions(context.Background(), f) + _, _, err := s.GetSubscriptions(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -304,7 +304,7 @@ func TestSubscriptionUpdateFail(t *testing.T) { func TestSubscriptionDeleteBeginFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectBegin().WillReturnError(fmt.Errorf("pop")) - err := s.DeleteSubscriptionByID(context.Background(), fftypes.NewUUID()) + err := s.DeleteSubscriptionByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10114", err) } @@ -315,6 +315,6 @@ func TestSubscriptionDeleteFail(t *testing.T) { fftypes.NewUUID(), "ns1", "sub1", "websockets", `{}`, `{}`, fftypes.Now(), fftypes.Now()), ) mock.ExpectExec("DELETE .*").WillReturnError(fmt.Errorf("pop")) - err := s.DeleteSubscriptionByID(context.Background(), fftypes.NewUUID()) + err := s.DeleteSubscriptionByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10118", err) } diff --git a/internal/events/event_manager.go b/internal/events/event_manager.go index d1b5325fc..65be588df 100644 --- a/internal/events/event_manager.go +++ b/internal/events/event_manager.go @@ -245,7 +245,7 @@ func (em *eventManager) CreateUpdateDurableSubscription(ctx context.Context, sub func (em *eventManager) DeleteDurableSubscription(ctx context.Context, subDef *core.Subscription) (err error) { // The event in the database for the deletion of the susbscription, will asynchronously update the submanager - return em.database.DeleteSubscriptionByID(ctx, subDef.ID) + return em.database.DeleteSubscriptionByID(ctx, em.namespace, subDef.ID) } func (em *eventManager) AddSystemEventListener(ns string, el system.EventListener) error { diff --git a/internal/events/event_manager_test.go b/internal/events/event_manager_test.go index 218aed94c..c06d68471 100644 --- a/internal/events/event_manager_test.go +++ b/internal/events/event_manager_test.go @@ -176,7 +176,7 @@ func TestEmitSubscriptionEventsNoops(t *testing.T) { getSubCallReady := make(chan bool, 1) getSubCalled := make(chan bool) - getSub := mdi.On("GetSubscriptionByID", mock.Anything, mock.Anything).Return(nil, nil) + getSub := mdi.On("GetSubscriptionByID", mock.Anything, "ns1", mock.Anything).Return(nil, nil) getSub.RunFn = func(a mock.Arguments) { <-getSubCallReady getSubCalled <- true @@ -393,8 +393,8 @@ func TestCreateDeleteDurableSubscriptionOk(t *testing.T) { mdi := em.database.(*databasemocks.Plugin) subId := fftypes.NewUUID() sub := &core.Subscription{SubscriptionRef: core.SubscriptionRef{ID: subId, Namespace: "ns1"}} - mdi.On("GetSubscriptionByID", mock.Anything, subId).Return(sub, nil) - mdi.On("DeleteSubscriptionByID", mock.Anything, subId).Return(nil) + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subId).Return(sub, nil) + mdi.On("DeleteSubscriptionByID", mock.Anything, "ns1", subId).Return(nil) err := em.DeleteDurableSubscription(em.ctx, sub) assert.NoError(t, err) } diff --git a/internal/events/subscription_manager.go b/internal/events/subscription_manager.go index 15b6f462d..50da22db0 100644 --- a/internal/events/subscription_manager.go +++ b/internal/events/subscription_manager.go @@ -128,7 +128,7 @@ func newSubscriptionManager(ctx context.Context, ns string, di database.Plugin, func (sm *subscriptionManager) start() error { fb := database.SubscriptionQueryFactory.NewFilter(sm.ctx) filter := fb.And().Limit(sm.maxSubs) - persistedSubs, _, err := sm.database.GetSubscriptions(sm.ctx, filter) + persistedSubs, _, err := sm.database.GetSubscriptions(sm.ctx, sm.namespace, filter) if err != nil { return err } @@ -167,7 +167,7 @@ func (sm *subscriptionManager) subscriptionEventListener() { func (sm *subscriptionManager) newOrUpdatedDurableSubscription(id *fftypes.UUID) { var subDef *core.Subscription err := sm.retry.Do(sm.ctx, "retrieve subscription", func(attempt int) (retry bool, err error) { - subDef, err = sm.database.GetSubscriptionByID(sm.ctx, id) + subDef, err = sm.database.GetSubscriptionByID(sm.ctx, sm.namespace, id) return err != nil, err // indefinite retry }) if err != nil || subDef == nil { diff --git a/internal/events/subscription_manager_test.go b/internal/events/subscription_manager_test.go index e0213f253..5d05e6662 100644 --- a/internal/events/subscription_manager_test.go +++ b/internal/events/subscription_manager_test.go @@ -76,7 +76,7 @@ func TestRegisterDurableSubscriptions(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{ + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{ {SubscriptionRef: core.SubscriptionRef{ ID: sub1, }, Transport: "ut"}, @@ -135,7 +135,7 @@ func TestReloadDurableSubscription(t *testing.T) { } mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{ + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{ {SubscriptionRef: core.SubscriptionRef{ ID: sub1, Namespace: "ns1", @@ -158,7 +158,7 @@ func TestRegisterEphemeralSubscriptions(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{}, nil, nil) + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{}, nil, nil) mei.On("ValidateOptions", mock.Anything).Return(nil) err := sm.start() @@ -187,7 +187,7 @@ func TestRegisterEphemeralSubscriptionsFail(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{}, nil, nil) + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{}, nil, nil) mei.On("ValidateOptions", mock.Anything).Return(nil) err := sm.start() assert.NoError(t, err) @@ -209,7 +209,7 @@ func TestStartSubRestoreFail(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) err := sm.start() assert.EqualError(t, err, "pop") } @@ -220,7 +220,7 @@ func TestStartSubRestoreOkSubsFail(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{ + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{ {SubscriptionRef: core.SubscriptionRef{ ID: fftypes.NewUUID(), }, @@ -238,7 +238,7 @@ func TestStartSubRestoreOkSubsOK(t *testing.T) { defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{ + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{ {SubscriptionRef: core.SubscriptionRef{ ID: fftypes.NewUUID(), }, @@ -538,7 +538,7 @@ func TestDispatchDeliveryResponseOK(t *testing.T) { sm, cancel := newTestSubManager(t, mei) defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{}, nil, nil) + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{}, nil, nil) mei.On("ValidateOptions", mock.Anything).Return(nil) err := sm.start() assert.NoError(t, err) @@ -568,7 +568,7 @@ func TestDispatchDeliveryResponseInvalidSubscription(t *testing.T) { sm, cancel := newTestSubManager(t, mei) defer cancel() mdi := sm.database.(*databasemocks.Plugin) - mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{}, nil, nil) + mdi.On("GetSubscriptions", mock.Anything, "ns1", mock.Anything).Return([]*core.Subscription{}, nil, nil) err := sm.start() assert.NoError(t, err) be := &boundCallbacks{sm: sm, ei: mei} @@ -618,7 +618,7 @@ func TestNewDurableSubscriptionBadSub(t *testing.T) { mdi := sm.database.(*databasemocks.Plugin) subID := fftypes.NewUUID() - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(&core.Subscription{ + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(&core.Subscription{ Filter: core.SubscriptionFilter{ Events: "![[[[badness", }, @@ -645,7 +645,7 @@ func TestNewDurableSubscriptionUnknownTransport(t *testing.T) { } subID := fftypes.NewUUID() - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(&core.Subscription{ + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(&core.Subscription{ SubscriptionRef: core.SubscriptionRef{ ID: subID, Namespace: "ns1", @@ -677,7 +677,7 @@ func TestNewDurableSubscriptionOK(t *testing.T) { } subID := fftypes.NewUUID() - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(&core.Subscription{ + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(&core.Subscription{ SubscriptionRef: core.SubscriptionRef{ ID: subID, Namespace: "ns1", @@ -726,7 +726,7 @@ func TestUpdatedDurableSubscriptionNoOp(t *testing.T) { }, } - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(sub, nil) + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(sub, nil) sm.newOrUpdatedDurableSubscription(subID) assert.Equal(t, ed, sm.connections["conn1"].dispatchers[*subID]) @@ -771,7 +771,7 @@ func TestUpdatedDurableSubscriptionOK(t *testing.T) { }, } - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(&sub2, nil) + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(&sub2, nil) sm.newOrUpdatedDurableSubscription(subID) assert.NotEqual(t, ed, sm.connections["conn1"].dispatchers[*subID]) @@ -838,7 +838,7 @@ func TestDeleteDurableSubscriptionOk(t *testing.T) { }, } - mdi.On("GetSubscriptionByID", mock.Anything, subID).Return(subDef, nil) + mdi.On("GetSubscriptionByID", mock.Anything, "ns1", subID).Return(subDef, nil) mdi.On("DeleteOffset", mock.Anything, fftypes.FFEnum("subscription"), subID.String()).Return(fmt.Errorf("this error is logged and swallowed")) sm.deletedDurableSubscription(subID) diff --git a/internal/orchestrator/data_query.go b/internal/orchestrator/data_query.go index 767d527b1..bb0c83daf 100644 --- a/internal/orchestrator/data_query.go +++ b/internal/orchestrator/data_query.go @@ -135,10 +135,6 @@ func (or *orchestrator) GetEventByID(ctx context.Context, id string) (*core.Even return or.database().GetEventByID(ctx, or.namespace, u) } -func (or *orchestrator) scopeNS(ns string, filter database.AndFilter) database.AndFilter { - return filter.Condition(filter.Builder().Eq("namespace", ns)) -} - func (or *orchestrator) GetTransactions(ctx context.Context, filter database.AndFilter) ([]*core.Transaction, *database.FilterResult, error) { return or.database().GetTransactions(ctx, or.namespace, filter) } diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index e0c90ca7d..dc2dc6f72 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -70,11 +70,11 @@ type Orchestrator interface { GetStatus(ctx context.Context) (*core.NodeStatus, error) // Subscription management - GetSubscriptions(ctx context.Context, ns string, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) - GetSubscriptionByID(ctx context.Context, ns, id string) (*core.Subscription, error) - CreateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) - CreateUpdateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) - DeleteSubscription(ctx context.Context, ns, id string) error + GetSubscriptions(ctx context.Context, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) + GetSubscriptionByID(ctx context.Context, id string) (*core.Subscription, error) + CreateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) + CreateUpdateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) + DeleteSubscription(ctx context.Context, id string) error // Data Query GetNamespace(ctx context.Context, ns string) (*core.Namespace, error) diff --git a/internal/orchestrator/subscriptions.go b/internal/orchestrator/subscriptions.go index a1477a944..a459da622 100644 --- a/internal/orchestrator/subscriptions.go +++ b/internal/orchestrator/subscriptions.go @@ -27,18 +27,18 @@ import ( "github.com/hyperledger/firefly/pkg/database" ) -func (or *orchestrator) CreateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) { - return or.createUpdateSubscription(ctx, ns, subDef, true) +func (or *orchestrator) CreateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) { + return or.createUpdateSubscription(ctx, subDef, true) } -func (or *orchestrator) CreateUpdateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) { - return or.createUpdateSubscription(ctx, ns, subDef, false) +func (or *orchestrator) CreateUpdateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) { + return or.createUpdateSubscription(ctx, subDef, false) } -func (or *orchestrator) createUpdateSubscription(ctx context.Context, ns string, subDef *core.Subscription, mustNew bool) (*core.Subscription, error) { +func (or *orchestrator) createUpdateSubscription(ctx context.Context, subDef *core.Subscription, mustNew bool) (*core.Subscription, error) { subDef.ID = fftypes.NewUUID() subDef.Created = fftypes.Now() - subDef.Namespace = ns + subDef.Namespace = or.namespace subDef.Ephemeral = false if err := or.data.VerifyNamespaceExists(ctx, subDef.Namespace); err != nil { return nil, err @@ -53,30 +53,29 @@ func (or *orchestrator) createUpdateSubscription(ctx context.Context, ns string, return subDef, or.events.CreateUpdateDurableSubscription(ctx, subDef, mustNew) } -func (or *orchestrator) DeleteSubscription(ctx context.Context, ns, id string) error { +func (or *orchestrator) DeleteSubscription(ctx context.Context, id string) error { u, err := fftypes.ParseUUID(ctx, id) if err != nil { return err } - sub, err := or.database().GetSubscriptionByID(ctx, u) + sub, err := or.database().GetSubscriptionByID(ctx, or.namespace, u) if err != nil { return err } - if sub == nil || sub.Namespace != ns { + if sub == nil { return i18n.NewError(ctx, coremsgs.Msg404NotFound) } return or.events.DeleteDurableSubscription(ctx, sub) } -func (or *orchestrator) GetSubscriptions(ctx context.Context, ns string, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) { - filter = or.scopeNS(ns, filter) - return or.database().GetSubscriptions(ctx, filter) +func (or *orchestrator) GetSubscriptions(ctx context.Context, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) { + return or.database().GetSubscriptions(ctx, or.namespace, filter) } -func (or *orchestrator) GetSubscriptionByID(ctx context.Context, ns, id string) (*core.Subscription, error) { +func (or *orchestrator) GetSubscriptionByID(ctx context.Context, id string) (*core.Subscription, error) { u, err := fftypes.ParseUUID(ctx, id) if err != nil { return nil, err } - return or.database().GetSubscriptionByID(ctx, u) + return or.database().GetSubscriptionByID(ctx, or.namespace, u) } diff --git a/internal/orchestrator/subscriptions_test.go b/internal/orchestrator/subscriptions_test.go index 6a4d60e43..208620156 100644 --- a/internal/orchestrator/subscriptions_test.go +++ b/internal/orchestrator/subscriptions_test.go @@ -31,19 +31,23 @@ import ( func TestCreateSubscriptionBadNamespace(t *testing.T) { or := newTestOrchestrator() - or.mdm.On("VerifyNamespaceExists", mock.Anything, "!wrong").Return(fmt.Errorf("pop")) - _, err := or.CreateSubscription(or.ctx, "!wrong", &core.Subscription{ + defer or.cleanup(t) + + or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns").Return(fmt.Errorf("pop")) + _, err := or.CreateSubscription(or.ctx, &core.Subscription{ SubscriptionRef: core.SubscriptionRef{ - Name: "sub1", + Name: "!sub1", }, }) - assert.Regexp(t, "pop", err) + assert.EqualError(t, err, "pop") } func TestCreateSubscriptionBadName(t *testing.T) { or := newTestOrchestrator() - or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - _, err := or.CreateSubscription(or.ctx, "ns1", &core.Subscription{ + defer or.cleanup(t) + + or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns").Return(nil) + _, err := or.CreateSubscription(or.ctx, &core.Subscription{ SubscriptionRef: core.SubscriptionRef{ Name: "!sub1", }, @@ -53,8 +57,10 @@ func TestCreateSubscriptionBadName(t *testing.T) { func TestCreateSubscriptionSystemTransport(t *testing.T) { or := newTestOrchestrator() - or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - _, err := or.CreateSubscription(or.ctx, "ns1", &core.Subscription{ + defer or.cleanup(t) + + or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns").Return(nil) + _, err := or.CreateSubscription(or.ctx, &core.Subscription{ Transport: system.SystemEventsTransport, SubscriptionRef: core.SubscriptionRef{ Name: "sub1", @@ -65,63 +71,67 @@ func TestCreateSubscriptionSystemTransport(t *testing.T) { func TestCreateSubscriptionOk(t *testing.T) { or := newTestOrchestrator() + defer or.cleanup(t) + sub := &core.Subscription{ SubscriptionRef: core.SubscriptionRef{ Name: "sub1", }, } - or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) + or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns").Return(nil) or.mem.On("CreateUpdateDurableSubscription", mock.Anything, mock.Anything, true).Return(nil) - s1, err := or.CreateSubscription(or.ctx, "ns1", sub) + s1, err := or.CreateSubscription(or.ctx, sub) assert.NoError(t, err) assert.Equal(t, s1, sub) - assert.Equal(t, "ns1", sub.Namespace) + assert.Equal(t, "ns", sub.Namespace) } func TestCreateUpdateSubscriptionOk(t *testing.T) { or := newTestOrchestrator() + defer or.cleanup(t) + sub := &core.Subscription{ SubscriptionRef: core.SubscriptionRef{ Name: "sub1", }, } - or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) + or.mdm.On("VerifyNamespaceExists", mock.Anything, "ns").Return(nil) or.mem.On("CreateUpdateDurableSubscription", mock.Anything, mock.Anything, false).Return(nil) - s1, err := or.CreateUpdateSubscription(or.ctx, "ns1", sub) + s1, err := or.CreateUpdateSubscription(or.ctx, sub) assert.NoError(t, err) assert.Equal(t, s1, sub) - assert.Equal(t, "ns1", sub.Namespace) + assert.Equal(t, "ns", sub.Namespace) } func TestDeleteSubscriptionBadUUID(t *testing.T) { or := newTestOrchestrator() - or.mdi.On("GetSubscriptionByID", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("pop")) - err := or.DeleteSubscription(or.ctx, "ns2", "! a UUID") + defer or.cleanup(t) + + err := or.DeleteSubscription(or.ctx, "! a UUID") assert.Regexp(t, "FF00138", err) } func TestDeleteSubscriptionLookupError(t *testing.T) { or := newTestOrchestrator() - or.mdi.On("GetSubscriptionByID", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("pop")) - err := or.DeleteSubscription(or.ctx, "ns2", fftypes.NewUUID().String()) + defer or.cleanup(t) + + or.mdi.On("GetSubscriptionByID", mock.Anything, "ns", mock.Anything).Return(nil, fmt.Errorf("pop")) + err := or.DeleteSubscription(or.ctx, fftypes.NewUUID().String()) assert.EqualError(t, err, "pop") } -func TestDeleteSubscriptionNSMismatch(t *testing.T) { +func TestDeleteSubscriptionNotFound(t *testing.T) { or := newTestOrchestrator() - sub := &core.Subscription{ - SubscriptionRef: core.SubscriptionRef{ - ID: fftypes.NewUUID(), - Name: "sub1", - Namespace: "ns1", - }, - } - or.mdi.On("GetSubscriptionByID", mock.Anything, sub.ID).Return(sub, nil) - err := or.DeleteSubscription(or.ctx, "ns2", sub.ID.String()) + defer or.cleanup(t) + + or.mdi.On("GetSubscriptionByID", mock.Anything, "ns", mock.Anything).Return(nil, nil) + err := or.DeleteSubscription(or.ctx, fftypes.NewUUID().String()) assert.Regexp(t, "FF10109", err) } func TestDeleteSubscription(t *testing.T) { or := newTestOrchestrator() + defer or.cleanup(t) + sub := &core.Subscription{ SubscriptionRef: core.SubscriptionRef{ ID: fftypes.NewUUID(), @@ -129,32 +139,38 @@ func TestDeleteSubscription(t *testing.T) { Namespace: "ns1", }, } - or.mdi.On("GetSubscriptionByID", mock.Anything, sub.ID).Return(sub, nil) + or.mdi.On("GetSubscriptionByID", mock.Anything, "ns", sub.ID).Return(sub, nil) or.mem.On("DeleteDurableSubscription", mock.Anything, sub).Return(nil) - err := or.DeleteSubscription(or.ctx, "ns1", sub.ID.String()) + err := or.DeleteSubscription(or.ctx, sub.ID.String()) assert.NoError(t, err) } func TestGetSubscriptions(t *testing.T) { or := newTestOrchestrator() + defer or.cleanup(t) + u := fftypes.NewUUID() - or.mdi.On("GetSubscriptions", mock.Anything, mock.Anything).Return([]*core.Subscription{}, nil, nil) + or.mdi.On("GetSubscriptions", mock.Anything, "ns", mock.Anything).Return([]*core.Subscription{}, nil, nil) fb := database.SubscriptionQueryFactory.NewFilter(context.Background()) f := fb.And(fb.Eq("id", u)) - _, _, err := or.GetSubscriptions(context.Background(), "ns1", f) + _, _, err := or.GetSubscriptions(context.Background(), f) assert.NoError(t, err) } func TestGetSGetSubscriptionsByID(t *testing.T) { or := newTestOrchestrator() + defer or.cleanup(t) + u := fftypes.NewUUID() - or.mdi.On("GetSubscriptionByID", mock.Anything, u).Return(nil, nil) - _, err := or.GetSubscriptionByID(context.Background(), "ns1", u.String()) + or.mdi.On("GetSubscriptionByID", mock.Anything, "ns", u).Return(nil, nil) + _, err := or.GetSubscriptionByID(context.Background(), u.String()) assert.NoError(t, err) } func TestGetSubscriptionDefsByIDBadID(t *testing.T) { or := newTestOrchestrator() - _, err := or.GetSubscriptionByID(context.Background(), "", "") + defer or.cleanup(t) + + _, err := or.GetSubscriptionByID(context.Background(), "") assert.Regexp(t, "FF00138", err) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index a70229605..ebe9ff266 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -121,13 +121,13 @@ func (_m *Plugin) DeleteOffset(ctx context.Context, t fftypes.FFEnum, name strin return r0 } -// DeleteSubscriptionByID provides a mock function with given fields: ctx, id -func (_m *Plugin) DeleteSubscriptionByID(ctx context.Context, id *fftypes.UUID) error { - ret := _m.Called(ctx, id) +// DeleteSubscriptionByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) DeleteSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) error { + ret := _m.Called(ctx, namespace, id) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) error); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) error); ok { + r0 = rf(ctx, namespace, id) } else { r0 = ret.Error(0) } @@ -1566,13 +1566,13 @@ func (_m *Plugin) GetPins(ctx context.Context, namespace string, filter database return r0, r1, r2 } -// GetSubscriptionByID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetSubscriptionByID(ctx context.Context, id *fftypes.UUID) (*core.Subscription, error) { - ret := _m.Called(ctx, id) +// GetSubscriptionByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.Subscription, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.Subscription); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.Subscription); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Subscription) @@ -1580,8 +1580,8 @@ func (_m *Plugin) GetSubscriptionByID(ctx context.Context, id *fftypes.UUID) (*c } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } @@ -1612,13 +1612,13 @@ func (_m *Plugin) GetSubscriptionByName(ctx context.Context, namespace string, n return r0, r1 } -// GetSubscriptions provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetSubscriptions(ctx context.Context, filter database.Filter) ([]*core.Subscription, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetSubscriptions provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetSubscriptions(ctx context.Context, namespace string, filter database.Filter) ([]*core.Subscription, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.Subscription); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.Subscription); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.Subscription) @@ -1626,8 +1626,8 @@ func (_m *Plugin) GetSubscriptions(ctx context.Context, filter database.Filter) } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1635,8 +1635,8 @@ func (_m *Plugin) GetSubscriptions(ctx context.Context, filter database.Filter) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/orchestratormocks/orchestrator.go b/mocks/orchestratormocks/orchestrator.go index 6e0e4140e..fce2e6f54 100644 --- a/mocks/orchestratormocks/orchestrator.go +++ b/mocks/orchestratormocks/orchestrator.go @@ -100,13 +100,13 @@ func (_m *Orchestrator) Contracts() contracts.Manager { return r0 } -// CreateSubscription provides a mock function with given fields: ctx, ns, subDef -func (_m *Orchestrator) CreateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) { - ret := _m.Called(ctx, ns, subDef) +// CreateSubscription provides a mock function with given fields: ctx, subDef +func (_m *Orchestrator) CreateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) { + ret := _m.Called(ctx, subDef) var r0 *core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, string, *core.Subscription) *core.Subscription); ok { - r0 = rf(ctx, ns, subDef) + if rf, ok := ret.Get(0).(func(context.Context, *core.Subscription) *core.Subscription); ok { + r0 = rf(ctx, subDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Subscription) @@ -114,8 +114,8 @@ func (_m *Orchestrator) CreateSubscription(ctx context.Context, ns string, subDe } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.Subscription) error); ok { - r1 = rf(ctx, ns, subDef) + if rf, ok := ret.Get(1).(func(context.Context, *core.Subscription) error); ok { + r1 = rf(ctx, subDef) } else { r1 = ret.Error(1) } @@ -123,13 +123,13 @@ func (_m *Orchestrator) CreateSubscription(ctx context.Context, ns string, subDe return r0, r1 } -// CreateUpdateSubscription provides a mock function with given fields: ctx, ns, subDef -func (_m *Orchestrator) CreateUpdateSubscription(ctx context.Context, ns string, subDef *core.Subscription) (*core.Subscription, error) { - ret := _m.Called(ctx, ns, subDef) +// CreateUpdateSubscription provides a mock function with given fields: ctx, subDef +func (_m *Orchestrator) CreateUpdateSubscription(ctx context.Context, subDef *core.Subscription) (*core.Subscription, error) { + ret := _m.Called(ctx, subDef) var r0 *core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, string, *core.Subscription) *core.Subscription); ok { - r0 = rf(ctx, ns, subDef) + if rf, ok := ret.Get(0).(func(context.Context, *core.Subscription) *core.Subscription); ok { + r0 = rf(ctx, subDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Subscription) @@ -137,8 +137,8 @@ func (_m *Orchestrator) CreateUpdateSubscription(ctx context.Context, ns string, } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.Subscription) error); ok { - r1 = rf(ctx, ns, subDef) + if rf, ok := ret.Get(1).(func(context.Context, *core.Subscription) error); ok { + r1 = rf(ctx, subDef) } else { r1 = ret.Error(1) } @@ -162,13 +162,13 @@ func (_m *Orchestrator) Data() data.Manager { return r0 } -// DeleteSubscription provides a mock function with given fields: ctx, ns, id -func (_m *Orchestrator) DeleteSubscription(ctx context.Context, ns string, id string) error { - ret := _m.Called(ctx, ns, id) +// DeleteSubscription provides a mock function with given fields: ctx, id +func (_m *Orchestrator) DeleteSubscription(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, ns, id) + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -898,13 +898,13 @@ func (_m *Orchestrator) GetStatus(ctx context.Context) (*core.NodeStatus, error) return r0, r1 } -// GetSubscriptionByID provides a mock function with given fields: ctx, ns, id -func (_m *Orchestrator) GetSubscriptionByID(ctx context.Context, ns string, id string) (*core.Subscription, error) { - ret := _m.Called(ctx, ns, id) +// GetSubscriptionByID provides a mock function with given fields: ctx, id +func (_m *Orchestrator) GetSubscriptionByID(ctx context.Context, id string) (*core.Subscription, error) { + ret := _m.Called(ctx, id) var r0 *core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.Subscription); ok { - r0 = rf(ctx, ns, id) + if rf, ok := ret.Get(0).(func(context.Context, string) *core.Subscription); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Subscription) @@ -912,8 +912,8 @@ func (_m *Orchestrator) GetSubscriptionByID(ctx context.Context, ns string, id s } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, ns, id) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -921,13 +921,13 @@ func (_m *Orchestrator) GetSubscriptionByID(ctx context.Context, ns string, id s return r0, r1 } -// GetSubscriptions provides a mock function with given fields: ctx, ns, filter -func (_m *Orchestrator) GetSubscriptions(ctx context.Context, ns string, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetSubscriptions provides a mock function with given fields: ctx, filter +func (_m *Orchestrator) GetSubscriptions(ctx context.Context, filter database.AndFilter) ([]*core.Subscription, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.Subscription - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.Subscription); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.Subscription); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.Subscription) @@ -935,8 +935,8 @@ func (_m *Orchestrator) GetSubscriptions(ctx context.Context, ns string, filter } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -944,8 +944,8 @@ func (_m *Orchestrator) GetSubscriptions(ctx context.Context, ns string, filter } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 4da1fc2a1..d38765805 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -237,13 +237,13 @@ type iSubscriptionCollection interface { GetSubscriptionByName(ctx context.Context, namespace, name string) (offset *core.Subscription, err error) // GetSubscriptionByID - Get an subscription by id - GetSubscriptionByID(ctx context.Context, id *fftypes.UUID) (offset *core.Subscription, err error) + GetSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) (offset *core.Subscription, err error) // GetSubscriptions - Get subscriptions - GetSubscriptions(ctx context.Context, filter Filter) (offset []*core.Subscription, res *FilterResult, err error) + GetSubscriptions(ctx context.Context, namespace string, filter Filter) (offset []*core.Subscription, res *FilterResult, err error) // DeleteSubscriptionByID - Delete a subscription - DeleteSubscriptionByID(ctx context.Context, id *fftypes.UUID) (err error) + DeleteSubscriptionByID(ctx context.Context, namespace string, id *fftypes.UUID) (err error) } type iEventCollection interface { From fff8c01b93db38aba1793ddfac76f7185f07db22 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 13:31:05 -0400 Subject: [PATCH 2/9] Add namespace to token pool database queries Signed-off-by: Andrew Richardson --- .../apiserver/route_get_token_connectors.go | 2 +- .../route_get_token_connectors_test.go | 2 +- internal/apiserver/route_get_token_pools.go | 2 +- .../apiserver/route_get_token_pools_test.go | 2 +- internal/apiserver/route_post_token_pool.go | 2 +- .../apiserver/route_post_token_pool_test.go | 2 +- internal/assets/manager.go | 21 ++- internal/assets/manager_test.go | 6 +- internal/assets/operations.go | 6 +- internal/assets/operations_test.go | 18 +-- internal/assets/token_approval_test.go | 7 +- internal/assets/token_pool.go | 27 ++-- internal/assets/token_pool_test.go | 130 ++++-------------- internal/assets/token_transfer_test.go | 15 +- internal/database/sqlcommon/tokenpool_sql.go | 16 ++- .../database/sqlcommon/tokenpool_sql_test.go | 20 +-- .../definition_handler_tokenpool.go | 2 +- .../definition_handler_tokenpool_test.go | 14 +- internal/events/token_pool_created.go | 2 +- internal/events/token_pool_created_test.go | 28 ++-- internal/events/tokens_approved.go | 2 +- internal/events/tokens_approved_test.go | 22 +-- internal/events/tokens_transferred.go | 2 +- internal/events/tokens_transferred_test.go | 26 ++-- internal/orchestrator/orchestrator.go | 2 +- internal/orchestrator/txn_status.go | 2 +- internal/orchestrator/txn_status_test.go | 8 +- internal/syncasync/sync_async_bridge.go | 2 +- internal/syncasync/sync_async_bridge_test.go | 6 +- internal/txcommon/event_enrich.go | 2 +- internal/txcommon/event_enrich_test.go | 4 +- mocks/assetmocks/manager.go | 56 ++++---- mocks/databasemocks/plugin.go | 46 +++---- pkg/database/plugin.go | 8 +- 34 files changed, 206 insertions(+), 306 deletions(-) diff --git a/internal/apiserver/route_get_token_connectors.go b/internal/apiserver/route_get_token_connectors.go index c45a070a2..1f264f892 100644 --- a/internal/apiserver/route_get_token_connectors.go +++ b/internal/apiserver/route_get_token_connectors.go @@ -36,7 +36,7 @@ var getTokenConnectors = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Assets().GetTokenConnectors(cr.ctx, extractNamespace(r.PP)), nil + return cr.or.Assets().GetTokenConnectors(cr.ctx), nil }, }, } diff --git a/internal/apiserver/route_get_token_connectors_test.go b/internal/apiserver/route_get_token_connectors_test.go index ad0931ae3..73c035e50 100644 --- a/internal/apiserver/route_get_token_connectors_test.go +++ b/internal/apiserver/route_get_token_connectors_test.go @@ -34,7 +34,7 @@ func TestGetTokenConnectors(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenConnectors", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenConnectors", mock.Anything, mock.Anything). Return([]*core.TokenConnector{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_token_pools.go b/internal/apiserver/route_get_token_pools.go index 6d8028223..67caf0f27 100644 --- a/internal/apiserver/route_get_token_pools.go +++ b/internal/apiserver/route_get_token_pools.go @@ -38,7 +38,7 @@ var getTokenPools = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.TokenPoolQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Assets().GetTokenPools(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Assets().GetTokenPools(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_token_pools_test.go b/internal/apiserver/route_get_token_pools_test.go index 41ca3ca1a..47699732d 100644 --- a/internal/apiserver/route_get_token_pools_test.go +++ b/internal/apiserver/route_get_token_pools_test.go @@ -34,7 +34,7 @@ func TestGetTokenPools(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenPools", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenPools", mock.Anything, mock.Anything). Return([]*core.TokenPool{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_token_pool.go b/internal/apiserver/route_post_token_pool.go index b17050835..2914b08b3 100644 --- a/internal/apiserver/route_post_token_pool.go +++ b/internal/apiserver/route_post_token_pool.go @@ -41,7 +41,7 @@ var postTokenPool = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { waitConfirm := strings.EqualFold(r.QP["confirm"], "true") r.SuccessStatus = syncRetcode(waitConfirm) - return cr.or.Assets().CreateTokenPool(cr.ctx, extractNamespace(r.PP), r.Input.(*core.TokenPool), waitConfirm) + return cr.or.Assets().CreateTokenPool(cr.ctx, r.Input.(*core.TokenPool), waitConfirm) }, }, } diff --git a/internal/apiserver/route_post_token_pool_test.go b/internal/apiserver/route_post_token_pool_test.go index 1c57c5277..7fcb89a5d 100644 --- a/internal/apiserver/route_post_token_pool_test.go +++ b/internal/apiserver/route_post_token_pool_test.go @@ -39,7 +39,7 @@ func TestPostTokenPool(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("CreateTokenPool", mock.Anything, "ns1", mock.AnythingOfType("*core.TokenPool"), false). + mam.On("CreateTokenPool", mock.Anything, mock.AnythingOfType("*core.TokenPool"), false). Return(&core.TokenPool{}, nil) r.ServeHTTP(res, req) diff --git a/internal/assets/manager.go b/internal/assets/manager.go index bcfa899d5..29e11c619 100644 --- a/internal/assets/manager.go +++ b/internal/assets/manager.go @@ -25,7 +25,6 @@ import ( "github.com/hyperledger/firefly/internal/broadcast" "github.com/hyperledger/firefly/internal/coreconfig" "github.com/hyperledger/firefly/internal/coremsgs" - "github.com/hyperledger/firefly/internal/data" "github.com/hyperledger/firefly/internal/identity" "github.com/hyperledger/firefly/internal/metrics" "github.com/hyperledger/firefly/internal/operations" @@ -41,10 +40,10 @@ import ( type Manager interface { core.Named - CreateTokenPool(ctx context.Context, ns string, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) + CreateTokenPool(ctx context.Context, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) ActivateTokenPool(ctx context.Context, pool *core.TokenPool) error - GetTokenPools(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) - GetTokenPool(ctx context.Context, ns, connector, poolName string) (*core.TokenPool, error) + GetTokenPools(ctx context.Context, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) + GetTokenPool(ctx context.Context, connector, poolName string) (*core.TokenPool, error) GetTokenPoolByNameOrID(ctx context.Context, poolNameOrID string) (*core.TokenPool, error) GetTokenBalances(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) @@ -59,7 +58,7 @@ type Manager interface { BurnTokens(ctx context.Context, transfer *core.TokenTransferInput, waitConfirm bool) (*core.TokenTransfer, error) TransferTokens(ctx context.Context, transfer *core.TokenTransferInput, waitConfirm bool) (*core.TokenTransfer, error) - GetTokenConnectors(ctx context.Context, ns string) []*core.TokenConnector + GetTokenConnectors(ctx context.Context) []*core.TokenConnector NewApproval(approve *core.TokenApprovalInput) sysmessaging.MessageSender TokenApproval(ctx context.Context, approval *core.TokenApprovalInput, waitConfirm bool) (*core.TokenApproval, error) @@ -76,7 +75,6 @@ type assetManager struct { database database.Plugin txHelper txcommon.Helper identity identity.Manager - data data.Manager syncasync syncasync.Bridge broadcast broadcast.Manager messaging privatemessaging.Manager @@ -86,7 +84,7 @@ type assetManager struct { keyNormalization int } -func NewAssetManager(ctx context.Context, ns string, di database.Plugin, im identity.Manager, dm data.Manager, sa syncasync.Bridge, bm broadcast.Manager, pm privatemessaging.Manager, ti map[string]tokens.Plugin, mm metrics.Manager, om operations.Manager, txHelper txcommon.Helper) (Manager, error) { +func NewAssetManager(ctx context.Context, ns string, di database.Plugin, im identity.Manager, sa syncasync.Bridge, bm broadcast.Manager, pm privatemessaging.Manager, ti map[string]tokens.Plugin, mm metrics.Manager, om operations.Manager, txHelper txcommon.Helper) (Manager, error) { if di == nil || im == nil || sa == nil || bm == nil || pm == nil || ti == nil || mm == nil || om == nil { return nil, i18n.NewError(ctx, coremsgs.MsgInitializationNilDepError, "AssetManager") } @@ -96,7 +94,6 @@ func NewAssetManager(ctx context.Context, ns string, di database.Plugin, im iden database: di, txHelper: txHelper, identity: im, - data: dm, syncasync: sa, broadcast: bm, messaging: pm, @@ -143,7 +140,7 @@ func (am *assetManager) GetTokenAccountPools(ctx context.Context, ns, key string return am.database.GetTokenAccountPools(ctx, key, am.scopeNS(ns, filter)) } -func (am *assetManager) GetTokenConnectors(ctx context.Context, ns string) []*core.TokenConnector { +func (am *assetManager) GetTokenConnectors(ctx context.Context) []*core.TokenConnector { connectors := []*core.TokenConnector{} for token := range am.tokens { connectors = append( @@ -156,8 +153,8 @@ func (am *assetManager) GetTokenConnectors(ctx context.Context, ns string) []*co return connectors } -func (am *assetManager) getDefaultTokenConnector(ctx context.Context, ns string) (string, error) { - tokenConnectors := am.GetTokenConnectors(ctx, ns) +func (am *assetManager) getDefaultTokenConnector(ctx context.Context) (string, error) { + tokenConnectors := am.GetTokenConnectors(ctx) if len(tokenConnectors) != 1 { return "", i18n.NewError(ctx, coremsgs.MsgFieldNotSpecified, "connector") } @@ -167,7 +164,7 @@ func (am *assetManager) getDefaultTokenConnector(ctx context.Context, ns string) func (am *assetManager) getDefaultTokenPool(ctx context.Context) (*core.TokenPool, error) { f := database.TokenPoolQueryFactory.NewFilter(ctx).And() f.Limit(1).Count(true) - tokenPools, fr, err := am.GetTokenPools(ctx, am.namespace, f) + tokenPools, fr, err := am.GetTokenPools(ctx, f) if err != nil { return nil, err } diff --git a/internal/assets/manager_test.go b/internal/assets/manager_test.go index d7fb34136..df0d12db7 100644 --- a/internal/assets/manager_test.go +++ b/internal/assets/manager_test.go @@ -63,7 +63,7 @@ func newTestAssetsCommon(t *testing.T, metrics bool) (*assetManager, func()) { mom.On("RegisterHandler", mock.Anything, mock.Anything, mock.Anything) mti.On("Name").Return("ut").Maybe() ctx, cancel := context.WithCancel(context.Background()) - a, err := NewAssetManager(ctx, "ns1", mdi, mim, mdm, msa, mbm, mpm, map[string]tokens.Plugin{"magic-tokens": mti}, mm, mom, txHelper) + a, err := NewAssetManager(ctx, "ns1", mdi, mim, msa, mbm, mpm, map[string]tokens.Plugin{"magic-tokens": mti}, mm, mom, txHelper) rag := mdi.On("RunAsGroup", mock.Anything, mock.Anything).Maybe() rag.RunFn = func(a mock.Arguments) { rag.ReturnArguments = mock.Arguments{a[1].(func(context.Context) error)(a[0].(context.Context))} @@ -75,7 +75,7 @@ func newTestAssetsCommon(t *testing.T, metrics bool) (*assetManager, func()) { } func TestInitFail(t *testing.T) { - _, err := NewAssetManager(context.Background(), "", nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := NewAssetManager(context.Background(), "", nil, nil, nil, nil, nil, nil, nil, nil, nil) assert.Regexp(t, "FF10128", err) } @@ -125,7 +125,7 @@ func TestGetTokenConnectors(t *testing.T) { am, cancel := newTestAssets(t) defer cancel() - connectors := am.GetTokenConnectors(context.Background(), "ns1") + connectors := am.GetTokenConnectors(context.Background()) assert.Equal(t, 1, len(connectors)) assert.Equal(t, "magic-tokens", connectors[0].Name) } diff --git a/internal/assets/operations.go b/internal/assets/operations.go index e36090955..92bb2ec34 100644 --- a/internal/assets/operations.go +++ b/internal/assets/operations.go @@ -61,7 +61,7 @@ func (am *assetManager) PrepareOperation(ctx context.Context, op *core.Operation if err != nil { return nil, err } - pool, err := am.database.GetTokenPoolByID(ctx, poolID) + pool, err := am.database.GetTokenPoolByID(ctx, am.namespace, poolID) if err != nil { return nil, err } else if pool == nil { @@ -74,7 +74,7 @@ func (am *assetManager) PrepareOperation(ctx context.Context, op *core.Operation if err != nil { return nil, err } - pool, err := am.database.GetTokenPoolByID(ctx, transfer.Pool) + pool, err := am.database.GetTokenPoolByID(ctx, am.namespace, transfer.Pool) if err != nil { return nil, err } else if pool == nil { @@ -87,7 +87,7 @@ func (am *assetManager) PrepareOperation(ctx context.Context, op *core.Operation if err != nil { return nil, err } - pool, err := am.database.GetTokenPoolByID(ctx, approval.Pool) + pool, err := am.database.GetTokenPoolByID(ctx, am.namespace, approval.Pool) if err != nil { return nil, err } else if pool == nil { diff --git a/internal/assets/operations_test.go b/internal/assets/operations_test.go index 5a6c18ab9..ef6d0ee9f 100644 --- a/internal/assets/operations_test.go +++ b/internal/assets/operations_test.go @@ -80,7 +80,7 @@ func TestPrepareAndRunActivatePool(t *testing.T) { mti := am.tokens["magic-tokens"].(*tokenmocks.Plugin) mdi := am.database.(*databasemocks.Plugin) mti.On("ActivateTokenPool", context.Background(), "ns1:"+op.ID.String(), pool).Return(true, nil) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(pool, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(pool, nil) po, err := am.PrepareOperation(context.Background(), op) assert.NoError(t, err) @@ -118,7 +118,7 @@ func TestPrepareAndRunTransfer(t *testing.T) { mti := am.tokens["magic-tokens"].(*tokenmocks.Plugin) mdi := am.database.(*databasemocks.Plugin) mti.On("TransferTokens", context.Background(), "ns1:"+op.ID.String(), "F1", transfer).Return(nil) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(pool, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(pool, nil) po, err := am.PrepareOperation(context.Background(), op) assert.NoError(t, err) @@ -157,7 +157,7 @@ func TestPrepareAndRunApproval(t *testing.T) { mti := am.tokens["magic-tokens"].(*tokenmocks.Plugin) mdi := am.database.(*databasemocks.Plugin) mti.On("TokensApproval", context.Background(), "ns1:"+op.ID.String(), "F1", approval).Return(nil) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(pool, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(pool, nil) po, err := am.PrepareOperation(context.Background(), op) assert.NoError(t, err) @@ -220,7 +220,7 @@ func TestPrepareOperationActivatePoolError(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, fmt.Errorf("pop")) _, err := am.PrepareOperation(context.Background(), op) assert.EqualError(t, err, "pop") @@ -239,7 +239,7 @@ func TestPrepareOperationActivatePoolNotFound(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, nil) _, err := am.PrepareOperation(context.Background(), op) assert.Regexp(t, "FF10109", err) @@ -271,7 +271,7 @@ func TestPrepareOperationTransferError(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, fmt.Errorf("pop")) _, err := am.PrepareOperation(context.Background(), op) assert.EqualError(t, err, "pop") @@ -290,7 +290,7 @@ func TestPrepareOperationTransferNotFound(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, nil) _, err := am.PrepareOperation(context.Background(), op) assert.Regexp(t, "FF10109", err) @@ -322,7 +322,7 @@ func TestPrepareOperationApprovalError(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, fmt.Errorf("pop")) _, err := am.PrepareOperation(context.Background(), op) assert.EqualError(t, err, "pop") @@ -341,7 +341,7 @@ func TestPrepareOperationApprovalNotFound(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), poolID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", poolID).Return(nil, nil) _, err := am.PrepareOperation(context.Background(), op) assert.Regexp(t, "FF10109", err) diff --git a/internal/assets/token_approval_test.go b/internal/assets/token_approval_test.go index 82fb45369..7974d3444 100644 --- a/internal/assets/token_approval_test.go +++ b/internal/assets/token_approval_test.go @@ -23,7 +23,6 @@ import ( "github.com/hyperledger/firefly/internal/identity" "github.com/hyperledger/firefly/internal/syncasync" "github.com/hyperledger/firefly/mocks/databasemocks" - "github.com/hyperledger/firefly/mocks/datamocks" "github.com/hyperledger/firefly/mocks/identitymanagermocks" "github.com/hyperledger/firefly/mocks/operationmocks" "github.com/hyperledger/firefly/mocks/syncasyncmocks" @@ -187,7 +186,7 @@ func TestApprovalDefaultPoolSuccess(t *testing.T) { TotalCount: &totalCount, } mim.On("NormalizeSigningKey", context.Background(), "key", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdi.On("GetTokenPools", context.Background(), mock.MatchedBy((func(f database.AndFilter) bool { + mdi.On("GetTokenPools", context.Background(), "ns1", mock.MatchedBy((func(f database.AndFilter) bool { info, _ := f.Finalize() return info.Count && info.Limit == 1 }))).Return(tokenPools, filterResult, nil) @@ -228,7 +227,7 @@ func TestApprovalDefaultPoolNoPool(t *testing.T) { filterResult := &database.FilterResult{ TotalCount: &totalCount, } - mdi.On("GetTokenPools", context.Background(), mock.MatchedBy((func(f database.AndFilter) bool { + mdi.On("GetTokenPools", context.Background(), "ns1", mock.MatchedBy((func(f database.AndFilter) bool { info, _ := f.Finalize() return info.Count && info.Limit == 1 }))).Return(tokenPools, filterResult, nil) @@ -436,7 +435,6 @@ func TestTokenApprovalConfirm(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) mim := am.identity.(*identitymanagermocks.Manager) - mdm := am.data.(*datamocks.Manager) msa := am.syncasync.(*syncasyncmocks.Bridge) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) @@ -461,7 +459,6 @@ func TestTokenApprovalConfirm(t *testing.T) { mdi.AssertExpectations(t) mim.AssertExpectations(t) - mdm.AssertExpectations(t) msa.AssertExpectations(t) mom.AssertExpectations(t) } diff --git a/internal/assets/token_pool.go b/internal/assets/token_pool.go index 58353db70..1a38c6071 100644 --- a/internal/assets/token_pool.go +++ b/internal/assets/token_pool.go @@ -28,23 +28,20 @@ import ( "github.com/hyperledger/firefly/pkg/database" ) -func (am *assetManager) CreateTokenPool(ctx context.Context, ns string, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) { - if err := am.data.VerifyNamespaceExists(ctx, ns); err != nil { - return nil, err - } +func (am *assetManager) CreateTokenPool(ctx context.Context, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) { if err := core.ValidateFFNameFieldNoUUID(ctx, pool.Name, "name"); err != nil { return nil, err } - if existing, err := am.database.GetTokenPool(ctx, ns, pool.Name); err != nil { + if existing, err := am.database.GetTokenPool(ctx, am.namespace, pool.Name); err != nil { return nil, err } else if existing != nil { return nil, i18n.NewError(ctx, coremsgs.MsgTokenPoolDuplicate, pool.Name) } pool.ID = fftypes.NewUUID() - pool.Namespace = ns + pool.Namespace = am.namespace if pool.Connector == "" { - connector, err := am.getDefaultTokenConnector(ctx, ns) + connector, err := am.getDefaultTokenConnector(ctx) if err != nil { return nil, err } @@ -136,24 +133,18 @@ func (am *assetManager) ActivateTokenPool(ctx context.Context, pool *core.TokenP return err } -func (am *assetManager) GetTokenPools(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) { - if err := core.ValidateFFNameField(ctx, ns, "namespace"); err != nil { - return nil, nil, err - } - return am.database.GetTokenPools(ctx, am.scopeNS(ns, filter)) +func (am *assetManager) GetTokenPools(ctx context.Context, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) { + return am.database.GetTokenPools(ctx, am.namespace, filter) } -func (am *assetManager) GetTokenPool(ctx context.Context, ns, connector, poolName string) (*core.TokenPool, error) { +func (am *assetManager) GetTokenPool(ctx context.Context, connector, poolName string) (*core.TokenPool, error) { if _, err := am.selectTokenPlugin(ctx, connector); err != nil { return nil, err } - if err := core.ValidateFFNameField(ctx, ns, "namespace"); err != nil { - return nil, err - } if err := core.ValidateFFNameFieldNoUUID(ctx, poolName, "name"); err != nil { return nil, err } - pool, err := am.database.GetTokenPool(ctx, ns, poolName) + pool, err := am.database.GetTokenPool(ctx, am.namespace, poolName) if err != nil { return nil, err } @@ -174,7 +165,7 @@ func (am *assetManager) GetTokenPoolByNameOrID(ctx context.Context, poolNameOrID if pool, err = am.database.GetTokenPool(ctx, am.namespace, poolNameOrID); err != nil { return nil, err } - } else if pool, err = am.database.GetTokenPoolByID(ctx, poolID); err != nil { + } else if pool, err = am.database.GetTokenPoolByID(ctx, am.namespace, poolID); err != nil { return nil, err } if pool == nil { diff --git a/internal/assets/token_pool_test.go b/internal/assets/token_pool_test.go index 3e0908b4b..d8e0a567e 100644 --- a/internal/assets/token_pool_test.go +++ b/internal/assets/token_pool_test.go @@ -23,7 +23,6 @@ import ( "github.com/hyperledger/firefly/internal/identity" "github.com/hyperledger/firefly/internal/syncasync" "github.com/hyperledger/firefly/mocks/databasemocks" - "github.com/hyperledger/firefly/mocks/datamocks" "github.com/hyperledger/firefly/mocks/identitymanagermocks" "github.com/hyperledger/firefly/mocks/operationmocks" "github.com/hyperledger/firefly/mocks/syncasyncmocks" @@ -41,10 +40,7 @@ func TestCreateTokenPoolBadName(t *testing.T) { pool := &core.TokenPool{} - mdm := am.data.(*datamocks.Manager) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF00140", err) } @@ -57,11 +53,9 @@ func TestCreateTokenPoolGetError(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, fmt.Errorf("pop")) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.EqualError(t, err, "pop") mdi.AssertExpectations(t) @@ -76,11 +70,9 @@ func TestCreateTokenPoolDuplicateName(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(&core.TokenPool{}, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF10275.*testpool", err) mdi.AssertExpectations(t) @@ -95,13 +87,11 @@ func TestCreateTokenPoolDefaultConnectorSuccess(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("resolved-key", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(nil) mom.On("RunOperation", context.Background(), mock.MatchedBy(func(op *core.PreparedOperation) bool { @@ -109,11 +99,10 @@ func TestCreateTokenPoolDefaultConnectorSuccess(t *testing.T) { return op.Type == core.OpTypeTokenCreatePool && data.Pool == pool })).Return(nil, nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) mom.AssertExpectations(t) @@ -130,15 +119,12 @@ func TestCreateTokenPoolDefaultConnectorNoConnectors(t *testing.T) { am.tokens = make(map[string]tokens.Plugin) mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF10292", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) } func TestCreateTokenPoolDefaultConnectorMultipleConnectors(t *testing.T) { @@ -153,32 +139,12 @@ func TestCreateTokenPoolDefaultConnectorMultipleConnectors(t *testing.T) { am.tokens["magic-tokens2"] = nil mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF10292", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) -} - -func TestCreateTokenPoolMissingNamespace(t *testing.T) { - am, cancel := newTestAssets(t) - defer cancel() - - pool := &core.TokenPool{ - Name: "testpool", - } - - mdm := am.data.(*datamocks.Manager) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(fmt.Errorf("pop")) - - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) - assert.EqualError(t, err, "pop") - - mdm.AssertExpectations(t) } func TestCreateTokenPoolNoConnectors(t *testing.T) { @@ -191,15 +157,12 @@ func TestCreateTokenPoolNoConnectors(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF10292", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) } func TestCreateTokenPoolIdentityFail(t *testing.T) { @@ -211,17 +174,14 @@ func TestCreateTokenPoolIdentityFail(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("", fmt.Errorf("pop")) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.EqualError(t, err, "pop") mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) } @@ -235,17 +195,14 @@ func TestCreateTokenPoolWrongConnector(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF10272", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) } @@ -259,13 +216,11 @@ func TestCreateTokenPoolFail(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(nil) mom.On("RunOperation", context.Background(), mock.MatchedBy(func(op *core.PreparedOperation) bool { @@ -273,11 +228,10 @@ func TestCreateTokenPoolFail(t *testing.T) { return op.Type == core.OpTypeTokenCreatePool && data.Pool == pool })).Return(nil, fmt.Errorf("pop")) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "pop", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) mom.AssertExpectations(t) @@ -293,19 +247,16 @@ func TestCreateTokenPoolTransactionFail(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(nil, fmt.Errorf("pop")) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "pop", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) } @@ -320,20 +271,17 @@ func TestCreateTokenPoolOpInsertFail(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(fmt.Errorf("pop")) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.Regexp(t, "pop", err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) } @@ -348,13 +296,11 @@ func TestCreateTokenPoolSyncSuccess(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(nil) mom.On("RunOperation", context.Background(), mock.MatchedBy(func(op *core.PreparedOperation) bool { @@ -362,11 +308,10 @@ func TestCreateTokenPoolSyncSuccess(t *testing.T) { return op.Type == core.OpTypeTokenCreatePool && data.Pool == pool })).Return(nil, nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) mom.AssertExpectations(t) @@ -382,13 +327,11 @@ func TestCreateTokenPoolAsyncSuccess(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(nil) mom.On("RunOperation", context.Background(), mock.MatchedBy(func(op *core.PreparedOperation) bool { @@ -396,11 +339,10 @@ func TestCreateTokenPoolAsyncSuccess(t *testing.T) { return op.Type == core.OpTypeTokenCreatePool && data.Pool == pool })).Return(nil, nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, false) + _, err := am.CreateTokenPool(context.Background(), pool, false) assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) mom.AssertExpectations(t) @@ -416,13 +358,11 @@ func TestCreateTokenPoolConfirm(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) msa := am.syncasync.(*syncasyncmocks.Bridge) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) mom := am.operations.(*operationmocks.Manager) mdi.On("GetTokenPool", context.Background(), "ns1", "testpool").Return(nil, nil) - mdm.On("VerifyNamespaceExists", context.Background(), "ns1").Return(nil) mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) mth.On("SubmitNewTransaction", context.Background(), core.TransactionTypeTokenPool).Return(fftypes.NewUUID(), nil) mdi.On("InsertOperation", context.Background(), mock.Anything).Return(nil) @@ -437,11 +377,10 @@ func TestCreateTokenPoolConfirm(t *testing.T) { return op.Type == core.OpTypeTokenCreatePool && data.Pool == pool })).Return(nil, nil) - _, err := am.CreateTokenPool(context.Background(), "ns1", pool, true) + _, err := am.CreateTokenPool(context.Background(), pool, true) assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) msa.AssertExpectations(t) @@ -608,7 +547,7 @@ func TestGetTokenPool(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) mdi.On("GetTokenPool", context.Background(), "ns1", "abc").Return(&core.TokenPool{}, nil) - _, err := am.GetTokenPool(context.Background(), "ns1", "magic-tokens", "abc") + _, err := am.GetTokenPool(context.Background(), "magic-tokens", "abc") assert.NoError(t, err) mdi.AssertExpectations(t) @@ -620,7 +559,7 @@ func TestGetTokenPoolNotFound(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) mdi.On("GetTokenPool", context.Background(), "ns1", "abc").Return(nil, nil) - _, err := am.GetTokenPool(context.Background(), "ns1", "magic-tokens", "abc") + _, err := am.GetTokenPool(context.Background(), "magic-tokens", "abc") assert.Regexp(t, "FF10109", err) mdi.AssertExpectations(t) @@ -632,7 +571,7 @@ func TestGetTokenPoolFailed(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) mdi.On("GetTokenPool", context.Background(), "ns1", "abc").Return(nil, fmt.Errorf("pop")) - _, err := am.GetTokenPool(context.Background(), "ns1", "magic-tokens", "abc") + _, err := am.GetTokenPool(context.Background(), "magic-tokens", "abc") assert.Regexp(t, "pop", err) mdi.AssertExpectations(t) @@ -642,23 +581,15 @@ func TestGetTokenPoolBadPlugin(t *testing.T) { am, cancel := newTestAssets(t) defer cancel() - _, err := am.GetTokenPool(context.Background(), "", "", "") + _, err := am.GetTokenPool(context.Background(), "", "") assert.Regexp(t, "FF10272", err) } -func TestGetTokenPoolBadNamespace(t *testing.T) { - am, cancel := newTestAssets(t) - defer cancel() - - _, err := am.GetTokenPool(context.Background(), "", "magic-tokens", "") - assert.Regexp(t, "FF00140", err) -} - func TestGetTokenPoolBadName(t *testing.T) { am, cancel := newTestAssets(t) defer cancel() - _, err := am.GetTokenPool(context.Background(), "ns1", "magic-tokens", "") + _, err := am.GetTokenPool(context.Background(), "magic-tokens", "") assert.Regexp(t, "FF00140", err) } @@ -668,7 +599,7 @@ func TestGetTokenPoolByID(t *testing.T) { u := fftypes.NewUUID() mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), u).Return(&core.TokenPool{}, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", u).Return(&core.TokenPool{}, nil) _, err := am.GetTokenPoolByNameOrID(context.Background(), u.String()) assert.NoError(t, err) @@ -689,7 +620,7 @@ func TestGetTokenPoolByIDBadID(t *testing.T) { u := fftypes.NewUUID() mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), u).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", u).Return(nil, fmt.Errorf("pop")) _, err := am.GetTokenPoolByNameOrID(context.Background(), u.String()) assert.EqualError(t, err, "pop") @@ -702,7 +633,7 @@ func TestGetTokenPoolByIDNilPool(t *testing.T) { u := fftypes.NewUUID() mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), u).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", u).Return(nil, nil) _, err := am.GetTokenPoolByNameOrID(context.Background(), u.String()) assert.Regexp(t, "FF10109", err) @@ -749,20 +680,9 @@ func TestGetTokenPools(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenPoolQueryFactory.NewFilter(context.Background()) f := fb.And(fb.Eq("id", u)) - mdi.On("GetTokenPools", context.Background(), f).Return([]*core.TokenPool{}, nil, nil) - _, _, err := am.GetTokenPools(context.Background(), "ns1", f) + mdi.On("GetTokenPools", context.Background(), "ns1", f).Return([]*core.TokenPool{}, nil, nil) + _, _, err := am.GetTokenPools(context.Background(), f) assert.NoError(t, err) mdi.AssertExpectations(t) } - -func TestGetTokenPoolsBadNamespace(t *testing.T) { - am, cancel := newTestAssets(t) - defer cancel() - - u := fftypes.NewUUID() - fb := database.TokenPoolQueryFactory.NewFilter(context.Background()) - f := fb.And(fb.Eq("id", u)) - _, _, err := am.GetTokenPools(context.Background(), "", f) - assert.Regexp(t, "FF00140", err) -} diff --git a/internal/assets/token_transfer_test.go b/internal/assets/token_transfer_test.go index 2fd539ae4..f06462bfb 100644 --- a/internal/assets/token_transfer_test.go +++ b/internal/assets/token_transfer_test.go @@ -24,7 +24,6 @@ import ( "github.com/hyperledger/firefly/internal/syncasync" "github.com/hyperledger/firefly/mocks/broadcastmocks" "github.com/hyperledger/firefly/mocks/databasemocks" - "github.com/hyperledger/firefly/mocks/datamocks" "github.com/hyperledger/firefly/mocks/identitymanagermocks" "github.com/hyperledger/firefly/mocks/operationmocks" "github.com/hyperledger/firefly/mocks/privatemessagingmocks" @@ -165,7 +164,7 @@ func TestMintTokenDefaultPoolSuccess(t *testing.T) { TotalCount: &totalCount, } mim.On("NormalizeSigningKey", context.Background(), "", identity.KeyNormalizationBlockchainPlugin).Return("0x12345", nil) - mdi.On("GetTokenPools", context.Background(), mock.MatchedBy((func(f database.AndFilter) bool { + mdi.On("GetTokenPools", context.Background(), "ns1", mock.MatchedBy((func(f database.AndFilter) bool { info, _ := f.Finalize() return info.Count && info.Limit == 1 }))).Return(tokenPools, filterResult, nil) @@ -204,7 +203,7 @@ func TestMintTokenDefaultPoolNoPools(t *testing.T) { filterResult := &database.FilterResult{ TotalCount: &totalCount, } - mdi.On("GetTokenPools", context.Background(), mock.MatchedBy((func(f database.AndFilter) bool { + mdi.On("GetTokenPools", context.Background(), "ns1", mock.MatchedBy((func(f database.AndFilter) bool { info, _ := f.Finalize() return info.Count && info.Limit == 1 }))).Return(tokenPools, filterResult, nil) @@ -241,7 +240,7 @@ func TestMintTokenDefaultPoolMultiplePools(t *testing.T) { filterResult := &database.FilterResult{ TotalCount: &totalCount, } - mdi.On("GetTokenPools", context.Background(), mock.MatchedBy((func(f database.AndFilter) bool { + mdi.On("GetTokenPools", context.Background(), "ns1", mock.MatchedBy((func(f database.AndFilter) bool { info, _ := f.Finalize() return info.Count && info.Limit == 1 }))).Return(tokenPools, filterResult, nil) @@ -263,7 +262,7 @@ func TestMintTokensGetPoolsError(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenPools", context.Background(), mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdi.On("GetTokenPools", context.Background(), "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := am.MintTokens(context.Background(), mint, false) assert.EqualError(t, err, "pop") @@ -403,7 +402,6 @@ func TestMintTokensConfirm(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) msa := am.syncasync.(*syncasyncmocks.Bridge) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) @@ -427,7 +425,6 @@ func TestMintTokensConfirm(t *testing.T) { assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) msa.AssertExpectations(t) mom.AssertExpectations(t) } @@ -512,7 +509,6 @@ func TestBurnTokensConfirm(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) msa := am.syncasync.(*syncasyncmocks.Bridge) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) @@ -536,7 +532,6 @@ func TestBurnTokensConfirm(t *testing.T) { assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) msa.AssertExpectations(t) mth.AssertExpectations(t) mom.AssertExpectations(t) @@ -959,7 +954,6 @@ func TestTransferTokensConfirm(t *testing.T) { } mdi := am.database.(*databasemocks.Plugin) - mdm := am.data.(*datamocks.Manager) msa := am.syncasync.(*syncasyncmocks.Bridge) mim := am.identity.(*identitymanagermocks.Manager) mth := am.txHelper.(*txcommonmocks.Helper) @@ -983,7 +977,6 @@ func TestTransferTokensConfirm(t *testing.T) { assert.NoError(t, err) mdi.AssertExpectations(t) - mdm.AssertExpectations(t) msa.AssertExpectations(t) mim.AssertExpectations(t) mth.AssertExpectations(t) diff --git a/internal/database/sqlcommon/tokenpool_sql.go b/internal/database/sqlcommon/tokenpool_sql.go index 914506abb..d66a7e735 100644 --- a/internal/database/sqlcommon/tokenpool_sql.go +++ b/internal/database/sqlcommon/tokenpool_sql.go @@ -189,23 +189,25 @@ func (s *SQLCommon) getTokenPoolPred(ctx context.Context, desc string, pred inte return pool, nil } -func (s *SQLCommon) GetTokenPool(ctx context.Context, ns string, name string) (message *core.TokenPool, err error) { - return s.getTokenPoolPred(ctx, ns+":"+name, sq.And{sq.Eq{"namespace": ns}, sq.Eq{"name": name}}) +func (s *SQLCommon) GetTokenPool(ctx context.Context, namespace string, name string) (message *core.TokenPool, err error) { + return s.getTokenPoolPred(ctx, namespace+":"+name, sq.Eq{"namespace": namespace, "name": name}) } -func (s *SQLCommon) GetTokenPoolByID(ctx context.Context, id *fftypes.UUID) (message *core.TokenPool, err error) { - return s.getTokenPoolPred(ctx, id.String(), sq.Eq{"id": id}) +func (s *SQLCommon) GetTokenPoolByID(ctx context.Context, namespace string, id *fftypes.UUID) (message *core.TokenPool, err error) { + return s.getTokenPoolPred(ctx, id.String(), sq.Eq{"id": id, "namespace": namespace}) } -func (s *SQLCommon) GetTokenPoolByLocator(ctx context.Context, connector, locator string) (*core.TokenPool, error) { +func (s *SQLCommon) GetTokenPoolByLocator(ctx context.Context, namespace, connector, locator string) (*core.TokenPool, error) { return s.getTokenPoolPred(ctx, locator, sq.And{ + sq.Eq{"namespace": namespace}, sq.Eq{"connector": connector}, sq.Eq{"locator": locator}, }) } -func (s *SQLCommon) GetTokenPools(ctx context.Context, filter database.Filter) (message []*core.TokenPool, fr *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenPoolColumns...).From("tokenpool"), filter, tokenPoolFilterFieldMap, []interface{}{"seq"}) +func (s *SQLCommon) GetTokenPools(ctx context.Context, namespace string, filter database.Filter) (message []*core.TokenPool, fr *database.FilterResult, err error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenPoolColumns...).From("tokenpool"), + filter, tokenPoolFilterFieldMap, []interface{}{"seq"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } diff --git a/internal/database/sqlcommon/tokenpool_sql_test.go b/internal/database/sqlcommon/tokenpool_sql_test.go index 140f52a6f..6d84f39f3 100644 --- a/internal/database/sqlcommon/tokenpool_sql_test.go +++ b/internal/database/sqlcommon/tokenpool_sql_test.go @@ -71,7 +71,7 @@ func TestTokenPoolE2EWithDB(t *testing.T) { poolJson, _ := json.Marshal(&pool) // Query back the token pool (by ID) - poolRead, err := s.GetTokenPoolByID(ctx, pool.ID) + poolRead, err := s.GetTokenPoolByID(ctx, "ns1", pool.ID) assert.NoError(t, err) assert.NotNil(t, poolRead) poolReadJson, _ := json.Marshal(&poolRead) @@ -85,7 +85,7 @@ func TestTokenPoolE2EWithDB(t *testing.T) { assert.Equal(t, string(poolJson), string(poolReadJson)) // Query back the token pool (by locator) - poolRead, err = s.GetTokenPoolByLocator(ctx, pool.Connector, pool.Locator) + poolRead, err = s.GetTokenPoolByLocator(ctx, "ns1", pool.Connector, pool.Locator) assert.NoError(t, err) assert.NotNil(t, poolRead) poolReadJson, _ = json.Marshal(&poolRead) @@ -101,7 +101,7 @@ func TestTokenPoolE2EWithDB(t *testing.T) { fb.Eq("message", pool.Message), fb.Eq("created", pool.Created), ) - pools, res, err := s.GetTokenPools(ctx, filter.Count(true)) + pools, res, err := s.GetTokenPools(ctx, "ns1", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(pools)) assert.Equal(t, int64(1), *res.TotalCount) @@ -115,7 +115,7 @@ func TestTokenPoolE2EWithDB(t *testing.T) { assert.NoError(t, err) // Query back the token pool (by ID) - poolRead, err = s.GetTokenPoolByID(ctx, pool.ID) + poolRead, err = s.GetTokenPoolByID(ctx, "ns1", pool.ID) assert.NoError(t, err) assert.NotNil(t, poolRead) poolJson, _ = json.Marshal(&pool) @@ -195,7 +195,7 @@ func TestGetTokenPoolByIDSelectFail(t *testing.T) { s, mock := newMockProvider().init() poolID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetTokenPoolByID(context.Background(), poolID) + _, err := s.GetTokenPoolByID(context.Background(), "ns1", poolID) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -204,7 +204,7 @@ func TestGetTokenPoolByIDNotFound(t *testing.T) { s, mock := newMockProvider().init() poolID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"})) - msg, err := s.GetTokenPoolByID(context.Background(), poolID) + msg, err := s.GetTokenPoolByID(context.Background(), "ns1", poolID) assert.NoError(t, err) assert.Nil(t, msg) assert.NoError(t, mock.ExpectationsWereMet()) @@ -214,7 +214,7 @@ func TestGetTokenPoolByIDScanFail(t *testing.T) { s, mock := newMockProvider().init() poolID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("only one")) - _, err := s.GetTokenPoolByID(context.Background(), poolID) + _, err := s.GetTokenPoolByID(context.Background(), "ns1", poolID) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -223,7 +223,7 @@ func TestGetTokenPoolsQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenPoolQueryFactory.NewFilter(context.Background()).Eq("id", "") - _, _, err := s.GetTokenPools(context.Background(), f) + _, _, err := s.GetTokenPools(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -231,7 +231,7 @@ func TestGetTokenPoolsQueryFail(t *testing.T) { func TestGetTokenPoolsBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.TokenPoolQueryFactory.NewFilter(context.Background()).Eq("id", map[bool]bool{true: false}) - _, _, err := s.GetTokenPools(context.Background(), f) + _, _, err := s.GetTokenPools(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*id", err) } @@ -239,7 +239,7 @@ func TestGetTokenPoolsScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("only one")) f := database.TokenPoolQueryFactory.NewFilter(context.Background()).Eq("id", "") - _, _, err := s.GetTokenPools(context.Background(), f) + _, _, err := s.GetTokenPools(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/definitions/definition_handler_tokenpool.go b/internal/definitions/definition_handler_tokenpool.go index c3607c70f..a8e12f36d 100644 --- a/internal/definitions/definition_handler_tokenpool.go +++ b/internal/definitions/definition_handler_tokenpool.go @@ -44,7 +44,7 @@ func (dh *definitionHandlers) handleTokenPoolBroadcast(ctx context.Context, stat } // Check if pool has already been confirmed on chain (and confirm the message if so) - if existingPool, err := dh.database.GetTokenPoolByID(ctx, pool.ID); err != nil { + if existingPool, err := dh.database.GetTokenPoolByID(ctx, dh.namespace, pool.ID); err != nil { return HandlerResult{Action: ActionRetry}, err } else if existingPool != nil && existingPool.State == core.TokenPoolStateConfirmed { return HandlerResult{Action: ActionConfirm, CustomCorrelator: correlator}, nil diff --git a/internal/definitions/definition_handler_tokenpool_test.go b/internal/definitions/definition_handler_tokenpool_test.go index 2628a9afa..d96361aa1 100644 --- a/internal/definitions/definition_handler_tokenpool_test.go +++ b/internal/definitions/definition_handler_tokenpool_test.go @@ -76,7 +76,7 @@ func TestHandleDefinitionBroadcastTokenPoolActivateOK(t *testing.T) { mdi := sh.database.(*databasemocks.Plugin) mam := sh.assets.(*assetmocks.Manager) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(nil, nil) mdi.On("UpsertTokenPool", context.Background(), mock.MatchedBy(func(p *core.TokenPool) bool { return *p.ID == *pool.ID && p.Message == msg.Header.ID })).Return(nil) @@ -101,7 +101,7 @@ func TestHandleDefinitionBroadcastTokenPoolGetPoolFail(t *testing.T) { assert.NoError(t, err) mdi := sh.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(nil, fmt.Errorf("pop")) action, err := sh.HandleDefinitionBroadcast(context.Background(), bs, msg, data, fftypes.NewUUID()) assert.Equal(t, HandlerResult{Action: ActionRetry}, action) @@ -121,7 +121,7 @@ func TestHandleDefinitionBroadcastTokenPoolExisting(t *testing.T) { mdi := sh.database.(*databasemocks.Plugin) mam := sh.assets.(*assetmocks.Manager) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(&core.TokenPool{}, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(&core.TokenPool{}, nil) mdi.On("UpsertTokenPool", context.Background(), mock.MatchedBy(func(p *core.TokenPool) bool { return *p.ID == *pool.ID && p.Message == msg.Header.ID })).Return(nil) @@ -148,7 +148,7 @@ func TestHandleDefinitionBroadcastTokenPoolExistingConfirmed(t *testing.T) { } mdi := sh.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(existing, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(existing, nil) action, err := sh.HandleDefinitionBroadcast(context.Background(), bs, msg, data, fftypes.NewUUID()) assert.Equal(t, HandlerResult{Action: ActionConfirm, CustomCorrelator: pool.ID}, action) @@ -166,7 +166,7 @@ func TestHandleDefinitionBroadcastTokenPoolIDMismatch(t *testing.T) { assert.NoError(t, err) mdi := sh.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(nil, nil) mdi.On("UpsertTokenPool", context.Background(), mock.MatchedBy(func(p *core.TokenPool) bool { return *p.ID == *pool.ID && p.Message == msg.Header.ID })).Return(database.IDMismatch) @@ -188,7 +188,7 @@ func TestHandleDefinitionBroadcastTokenPoolFailUpsert(t *testing.T) { assert.NoError(t, err) mdi := sh.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(nil, nil) mdi.On("UpsertTokenPool", context.Background(), mock.MatchedBy(func(p *core.TokenPool) bool { return *p.ID == *pool.ID && p.Message == msg.Header.ID })).Return(fmt.Errorf("pop")) @@ -211,7 +211,7 @@ func TestHandleDefinitionBroadcastTokenPoolActivateFail(t *testing.T) { mdi := sh.database.(*databasemocks.Plugin) mam := sh.assets.(*assetmocks.Manager) - mdi.On("GetTokenPoolByID", context.Background(), pool.ID).Return(nil, nil) + mdi.On("GetTokenPoolByID", context.Background(), "ns1", pool.ID).Return(nil, nil) mdi.On("UpsertTokenPool", context.Background(), mock.MatchedBy(func(p *core.TokenPool) bool { return *p.ID == *pool.ID && p.Message == msg.Header.ID })).Return(nil) diff --git a/internal/events/token_pool_created.go b/internal/events/token_pool_created.go index f793dfde8..911127d3d 100644 --- a/internal/events/token_pool_created.go +++ b/internal/events/token_pool_created.go @@ -90,7 +90,7 @@ func (em *eventManager) findTXOperation(ctx context.Context, tx *fftypes.UUID, o } func (em *eventManager) shouldConfirm(ctx context.Context, pool *tokens.TokenPool) (existingPool *core.TokenPool, err error) { - if existingPool, err = em.database.GetTokenPoolByLocator(ctx, pool.Connector, pool.PoolLocator); err != nil || existingPool == nil { + if existingPool, err = em.database.GetTokenPoolByLocator(ctx, em.namespace, pool.Connector, pool.PoolLocator); err != nil || existingPool == nil { return existingPool, err } if err = addPoolDetailsFromPlugin(existingPool, pool); err != nil { diff --git a/internal/events/token_pool_created_test.go b/internal/events/token_pool_created_test.go index a3e7bbbd7..d46ff9369 100644 --- a/internal/events/token_pool_created_test.go +++ b/internal/events/token_pool_created_test.go @@ -58,7 +58,7 @@ func TestTokenPoolCreatedIgnore(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil) err := em.TokenPoolCreated(mti, pool) @@ -85,7 +85,7 @@ func TestTokenPoolCreatedIgnoreNoTX(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil, nil) err := em.TokenPoolCreated(mti, pool) assert.NoError(t, err) @@ -131,8 +131,8 @@ func TestTokenPoolCreatedConfirm(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(storedPool, nil).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(storedPool, nil).Once() mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), chainPool.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Name == chainPool.Event.Name @@ -197,8 +197,8 @@ func TestTokenPoolCreatedConfirmWrongNS(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(storedPool, nil).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(storedPool, nil).Once() mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Name == chainPool.Event.Name })).Return(nil).Once() @@ -243,7 +243,7 @@ func TestTokenPoolCreatedAlreadyConfirmed(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(storedPool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(storedPool, nil) err := em.TokenPoolCreated(mti, chainPool) assert.NoError(t, err) @@ -286,7 +286,7 @@ func TestTokenPoolCreatedConfirmFailBadSymbol(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(storedPool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(storedPool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return([]*core.Operation{{ ID: opID, }}, nil, nil) @@ -331,7 +331,7 @@ func TestTokenPoolCreatedMigrate(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "magic-tokens", "123").Return(storedPool, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "magic-tokens", "123").Return(storedPool, nil).Times(2) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), chainPool.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Name == chainPool.Event.Name @@ -502,7 +502,7 @@ func TestTokenPoolCreatedAnnounce(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")).Once() mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil).Once() mbm.On("BroadcastTokenPool", em.ctx, "ns1", mock.MatchedBy(func(pool *core.TokenPoolAnnouncement) bool { @@ -551,7 +551,7 @@ func TestTokenPoolCreatedAnnounceWrongNS(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")).Once() mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil).Once() @@ -592,7 +592,7 @@ func TestTokenPoolCreatedAnnounceBadOpInputID(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil) err := em.TokenPoolCreated(mti, pool) @@ -633,7 +633,7 @@ func TestTokenPoolCreatedAnnounceBadOpInputNS(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil) err := em.TokenPoolCreated(mti, pool) @@ -678,7 +678,7 @@ func TestTokenPoolCreatedAnnounceBadSymbol(t *testing.T) { }, } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")).Once() mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil).Once() diff --git a/internal/events/tokens_approved.go b/internal/events/tokens_approved.go index 848ec79c3..3caa0e9ac 100644 --- a/internal/events/tokens_approved.go +++ b/internal/events/tokens_approved.go @@ -70,7 +70,7 @@ func (em *eventManager) loadApprovalID(ctx context.Context, tx *fftypes.UUID, ap func (em *eventManager) persistTokenApproval(ctx context.Context, approval *tokens.TokenApproval) (valid bool, err error) { // Check that this is from a known pool // TODO: should cache this lookup for efficiency - pool, err := em.database.GetTokenPoolByLocator(ctx, approval.Connector, approval.PoolLocator) + pool, err := em.database.GetTokenPoolByLocator(ctx, em.namespace, approval.Connector, approval.PoolLocator) if err != nil { return false, err } diff --git a/internal/events/tokens_approved_test.go b/internal/events/tokens_approved_test.go index 76e528b05..49d4595c6 100644 --- a/internal/events/tokens_approved_test.go +++ b/internal/events/tokens_approved_test.go @@ -73,8 +73,8 @@ func TestTokensApprovedSucceedWithRetries(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil).Times(4) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(4) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, fmt.Errorf("pop")).Once() mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil).Times(3) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), approval.Event.ProtocolID).Return(nil, nil) @@ -113,7 +113,7 @@ func TestPersistApprovalDuplicate(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(&core.TokenApproval{}, nil) valid, err := em.persistTokenApproval(em.ctx, approval) @@ -135,7 +135,7 @@ func TestPersistApprovalWrongNS(t *testing.T) { Namespace: "ns2", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) valid, err := em.persistTokenApproval(em.ctx, approval) assert.False(t, valid) @@ -156,7 +156,7 @@ func TestPersistApprovalOpFail(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) @@ -186,7 +186,7 @@ func TestPersistApprovalBadOp(t *testing.T) { Transaction: fftypes.NewUUID(), }} - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -220,7 +220,7 @@ func TestPersistApprovalTxFail(t *testing.T) { }, }} - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(nil, nil) @@ -255,7 +255,7 @@ func TestPersistApprovalGetApprovalFail(t *testing.T) { }, }} - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(nil, fmt.Errorf("pop")) @@ -276,7 +276,7 @@ func TestApprovedBadPool(t *testing.T) { mti := &tokenmocks.Plugin{} approval := newApproval() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, nil) err := em.TokensApproved(mti, approval) assert.NoError(t, err) @@ -307,7 +307,7 @@ func TestApprovedWithTransactionRegenerateLocalID(t *testing.T) { }, }} - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(true, nil) @@ -354,7 +354,7 @@ func TestApprovedBlockchainEventFail(t *testing.T) { }, }} - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(true, nil) diff --git a/internal/events/tokens_transferred.go b/internal/events/tokens_transferred.go index 98383c5dd..2166e14f3 100644 --- a/internal/events/tokens_transferred.go +++ b/internal/events/tokens_transferred.go @@ -70,7 +70,7 @@ func (em *eventManager) loadTransferID(ctx context.Context, tx *fftypes.UUID, tr func (em *eventManager) persistTokenTransfer(ctx context.Context, transfer *tokens.TokenTransfer) (valid bool, err error) { // Check that this is from a known pool // TODO: should cache this lookup for efficiency - pool, err := em.database.GetTokenPoolByLocator(ctx, transfer.Connector, transfer.PoolLocator) + pool, err := em.database.GetTokenPoolByLocator(ctx, em.namespace, transfer.Connector, transfer.PoolLocator) if err != nil { return false, err } diff --git a/internal/events/tokens_transferred_test.go b/internal/events/tokens_transferred_test.go index fae21408f..c4d90acf4 100644 --- a/internal/events/tokens_transferred_test.go +++ b/internal/events/tokens_transferred_test.go @@ -73,8 +73,8 @@ func TestTokensTransferredSucceedWithRetries(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil).Times(4) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(4) mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(3) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) @@ -114,7 +114,7 @@ func TestTokensTransferredIgnoreExisting(t *testing.T) { } mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(&core.TokenTransfer{}, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) err := em.TokensTransferred(mti, transfer) assert.NoError(t, err) @@ -135,7 +135,7 @@ func TestPersistTransferWrongNS(t *testing.T) { Namespace: "ns2", } - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) valid, err := em.persistTokenTransfer(em.ctx, transfer) assert.False(t, valid) @@ -157,7 +157,7 @@ func TestPersistTransferOpFail(t *testing.T) { } mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) valid, err := em.persistTokenTransfer(em.ctx, transfer) @@ -187,7 +187,7 @@ func TestPersistTransferBadOp(t *testing.T) { }} mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -219,7 +219,7 @@ func TestPersistTransferTxFail(t *testing.T) { }} mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -252,7 +252,7 @@ func TestPersistTransferGetTransferFail(t *testing.T) { }} mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) mdi.On("GetTokenTransferByID", em.ctx, localID).Return(nil, fmt.Errorf("pop")) @@ -286,7 +286,7 @@ func TestPersistTransferBlockchainEventFail(t *testing.T) { }} mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) mdi.On("GetTokenTransferByID", em.ctx, localID).Return(nil, nil) @@ -325,7 +325,7 @@ func TestTokensTransferredWithTransactionRegenerateLocalID(t *testing.T) { }} mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) mdi.On("GetTokenTransferByID", em.ctx, localID).Return(&core.TokenTransfer{}, nil) @@ -358,7 +358,7 @@ func TestTokensTransferredBadPool(t *testing.T) { transfer := newTransfer() - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(nil, nil) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, nil) err := em.TokensTransferred(mti, transfer) assert.NoError(t, err) @@ -406,7 +406,7 @@ func TestTokensTransferredWithMessageReceived(t *testing.T) { } mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(2) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == transfer.Event.Name @@ -470,7 +470,7 @@ func TestTokensTransferredWithMessageSend(t *testing.T) { } mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) - mdi.On("GetTokenPoolByLocator", em.ctx, "erc1155", "F1").Return(pool, nil).Times(2) + mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(2) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == transfer.Event.Name diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index dc2dc6f72..7c5d20377 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -447,7 +447,7 @@ func (or *orchestrator) initComponents(ctx context.Context) (err error) { } if or.assets == nil { - or.assets, err = assets.NewAssetManager(ctx, or.namespace, or.database(), or.identity, or.data, or.syncasync, or.broadcast, or.messaging, or.tokens(), or.metrics, or.operations, or.txHelper) + or.assets, err = assets.NewAssetManager(ctx, or.namespace, or.database(), or.identity, or.syncasync, or.broadcast, or.messaging, or.tokens(), or.metrics, or.operations, or.txHelper) if err != nil { return err } diff --git a/internal/orchestrator/txn_status.go b/internal/orchestrator/txn_status.go index 9753efc2f..3722199de 100644 --- a/internal/orchestrator/txn_status.go +++ b/internal/orchestrator/txn_status.go @@ -120,7 +120,7 @@ func (or *orchestrator) GetTransactionStatus(ctx context.Context, id string) (*c case core.TransactionTypeTokenPool: // Note: no assumptions about blockchain events here (may or may not contain one) f := database.TokenPoolQueryFactory.NewFilter(ctx) - switch pools, _, err := or.database().GetTokenPools(ctx, f.Eq("tx.id", id)); { + switch pools, _, err := or.database().GetTokenPools(ctx, or.namespace, f.Eq("tx.id", id)); { case err != nil: return nil, err case len(pools) == 0: diff --git a/internal/orchestrator/txn_status_test.go b/internal/orchestrator/txn_status_test.go index c60da439c..4a0cbdc00 100644 --- a/internal/orchestrator/txn_status_test.go +++ b/internal/orchestrator/txn_status_test.go @@ -265,7 +265,7 @@ func TestGetTransactionStatusTokenPoolSuccess(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenPools", mock.Anything, mock.Anything).Return(pools, nil, nil) + or.mdi.On("GetTokenPools", mock.Anything, "ns", mock.Anything).Return(pools, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -327,7 +327,7 @@ func TestGetTransactionStatusTokenPoolPending(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenPools", mock.Anything, mock.Anything).Return(pools, nil, nil) + or.mdi.On("GetTokenPools", mock.Anything, "ns", mock.Anything).Return(pools, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -385,7 +385,7 @@ func TestGetTransactionStatusTokenPoolUnconfirmed(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenPools", mock.Anything, mock.Anything).Return(pools, nil, nil) + or.mdi.On("GetTokenPools", mock.Anything, "ns", mock.Anything).Return(pools, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -874,7 +874,7 @@ func TestGetTransactionStatusPoolError(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) - or.mdi.On("GetTokenPools", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + or.mdi.On("GetTokenPools", mock.Anything, "ns", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.EqualError(t, err, "pop") diff --git a/internal/syncasync/sync_async_bridge.go b/internal/syncasync/sync_async_bridge.go index 7524698fb..9a08b852c 100644 --- a/internal/syncasync/sync_async_bridge.go +++ b/internal/syncasync/sync_async_bridge.go @@ -190,7 +190,7 @@ func (sa *syncAsyncBridge) getIdentityFromEvent(event *core.EventDelivery) (iden } func (sa *syncAsyncBridge) getPoolFromEvent(event *core.EventDelivery) (pool *core.TokenPool, err error) { - if pool, err = sa.database.GetTokenPoolByID(sa.ctx, event.Reference); err != nil { + if pool, err = sa.database.GetTokenPoolByID(sa.ctx, sa.namespace, event.Reference); err != nil { return nil, err } if pool == nil { diff --git a/internal/syncasync/sync_async_bridge_test.go b/internal/syncasync/sync_async_bridge_test.go index 9e6072cb4..5d37b3593 100644 --- a/internal/syncasync/sync_async_bridge_test.go +++ b/internal/syncasync/sync_async_bridge_test.go @@ -341,7 +341,7 @@ func TestEventCallbackTokenPoolLookupFail(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", sa.ctx, mock.Anything).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", sa.ctx, "ns1", mock.Anything).Return(nil, fmt.Errorf("pop")) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -564,7 +564,7 @@ func TestEventCallbackTokenPoolNotFound(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenPoolByID", sa.ctx, mock.Anything).Return(nil, nil) + mdi.On("GetTokenPoolByID", sa.ctx, "ns1", mock.Anything).Return(nil, nil) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -766,7 +766,7 @@ func TestAwaitTokenPoolConfirmation(t *testing.T) { mse.On("AddSystemEventListener", "ns1", mock.Anything).Return(nil) mdi := sa.database.(*databasemocks.Plugin) - gmid := mdi.On("GetTokenPoolByID", sa.ctx, mock.Anything) + gmid := mdi.On("GetTokenPoolByID", sa.ctx, "ns1", mock.Anything) gmid.RunFn = func(a mock.Arguments) { pool := &core.TokenPool{ ID: requestID, diff --git a/internal/txcommon/event_enrich.go b/internal/txcommon/event_enrich.go index c73d665ab..8bead76da 100644 --- a/internal/txcommon/event_enrich.go +++ b/internal/txcommon/event_enrich.go @@ -77,7 +77,7 @@ func (t *transactionHelper) EnrichEvent(ctx context.Context, event *core.Event) } e.NamespaceDetails = ns case core.EventTypePoolConfirmed: - tokenPool, err := t.database.GetTokenPoolByID(ctx, event.Reference) + tokenPool, err := t.database.GetTokenPoolByID(ctx, t.namespace, event.Reference) if err != nil { return nil, err } diff --git a/internal/txcommon/event_enrich_test.go b/internal/txcommon/event_enrich_test.go index 5f2fb475b..29afdf9d7 100644 --- a/internal/txcommon/event_enrich_test.go +++ b/internal/txcommon/event_enrich_test.go @@ -460,7 +460,7 @@ func TestEnrichTokenPoolConfirmed(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenPoolByID", mock.Anything, ref1).Return(&core.TokenPool{ + mdi.On("GetTokenPoolByID", mock.Anything, "ns1", ref1).Return(&core.TokenPool{ ID: ref1, }, nil) @@ -486,7 +486,7 @@ func TestEnrichTokenPoolConfirmedFail(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenPoolByID", mock.Anything, ref1).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenPoolByID", mock.Anything, "ns1", ref1).Return(nil, fmt.Errorf("pop")) event := &core.Event{ ID: ev1, diff --git a/mocks/assetmocks/manager.go b/mocks/assetmocks/manager.go index f0938dfd4..478ea8ea7 100644 --- a/mocks/assetmocks/manager.go +++ b/mocks/assetmocks/manager.go @@ -57,13 +57,13 @@ func (_m *Manager) BurnTokens(ctx context.Context, transfer *core.TokenTransferI return r0, r1 } -// CreateTokenPool provides a mock function with given fields: ctx, ns, pool, waitConfirm -func (_m *Manager) CreateTokenPool(ctx context.Context, ns string, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) { - ret := _m.Called(ctx, ns, pool, waitConfirm) +// CreateTokenPool provides a mock function with given fields: ctx, pool, waitConfirm +func (_m *Manager) CreateTokenPool(ctx context.Context, pool *core.TokenPool, waitConfirm bool) (*core.TokenPool, error) { + ret := _m.Called(ctx, pool, waitConfirm) var r0 *core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, string, *core.TokenPool, bool) *core.TokenPool); ok { - r0 = rf(ctx, ns, pool, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.TokenPool, bool) *core.TokenPool); ok { + r0 = rf(ctx, pool, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenPool) @@ -71,8 +71,8 @@ func (_m *Manager) CreateTokenPool(ctx context.Context, ns string, pool *core.To } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.TokenPool, bool) error); ok { - r1 = rf(ctx, ns, pool, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.TokenPool, bool) error); ok { + r1 = rf(ctx, pool, waitConfirm) } else { r1 = ret.Error(1) } @@ -208,13 +208,13 @@ func (_m *Manager) GetTokenBalances(ctx context.Context, ns string, filter datab return r0, r1, r2 } -// GetTokenConnectors provides a mock function with given fields: ctx, ns -func (_m *Manager) GetTokenConnectors(ctx context.Context, ns string) []*core.TokenConnector { - ret := _m.Called(ctx, ns) +// GetTokenConnectors provides a mock function with given fields: ctx +func (_m *Manager) GetTokenConnectors(ctx context.Context) []*core.TokenConnector { + ret := _m.Called(ctx) var r0 []*core.TokenConnector - if rf, ok := ret.Get(0).(func(context.Context, string) []*core.TokenConnector); ok { - r0 = rf(ctx, ns) + if rf, ok := ret.Get(0).(func(context.Context) []*core.TokenConnector); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenConnector) @@ -224,13 +224,13 @@ func (_m *Manager) GetTokenConnectors(ctx context.Context, ns string) []*core.To return r0 } -// GetTokenPool provides a mock function with given fields: ctx, ns, connector, poolName -func (_m *Manager) GetTokenPool(ctx context.Context, ns string, connector string, poolName string) (*core.TokenPool, error) { - ret := _m.Called(ctx, ns, connector, poolName) +// GetTokenPool provides a mock function with given fields: ctx, connector, poolName +func (_m *Manager) GetTokenPool(ctx context.Context, connector string, poolName string) (*core.TokenPool, error) { + ret := _m.Called(ctx, connector, poolName) var r0 *core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.TokenPool); ok { - r0 = rf(ctx, ns, connector, poolName) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.TokenPool); ok { + r0 = rf(ctx, connector, poolName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenPool) @@ -238,8 +238,8 @@ func (_m *Manager) GetTokenPool(ctx context.Context, ns string, connector string } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { - r1 = rf(ctx, ns, connector, poolName) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, connector, poolName) } else { r1 = ret.Error(1) } @@ -270,13 +270,13 @@ func (_m *Manager) GetTokenPoolByNameOrID(ctx context.Context, poolNameOrID stri return r0, r1 } -// GetTokenPools provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetTokenPools(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetTokenPools provides a mock function with given fields: ctx, filter +func (_m *Manager) GetTokenPools(ctx context.Context, filter database.AndFilter) ([]*core.TokenPool, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenPool); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.TokenPool); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenPool) @@ -284,8 +284,8 @@ func (_m *Manager) GetTokenPools(ctx context.Context, ns string, filter database } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -293,8 +293,8 @@ func (_m *Manager) GetTokenPools(ctx context.Context, ns string, filter database } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index ebe9ff266..e0b7f3235 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -1864,13 +1864,13 @@ func (_m *Plugin) GetTokenPool(ctx context.Context, namespace string, name strin return r0, r1 } -// GetTokenPoolByID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetTokenPoolByID(ctx context.Context, id *fftypes.UUID) (*core.TokenPool, error) { - ret := _m.Called(ctx, id) +// GetTokenPoolByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetTokenPoolByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.TokenPool, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.TokenPool); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.TokenPool); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenPool) @@ -1878,8 +1878,8 @@ func (_m *Plugin) GetTokenPoolByID(ctx context.Context, id *fftypes.UUID) (*core } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } @@ -1887,13 +1887,13 @@ func (_m *Plugin) GetTokenPoolByID(ctx context.Context, id *fftypes.UUID) (*core return r0, r1 } -// GetTokenPoolByLocator provides a mock function with given fields: ctx, connector, locator -func (_m *Plugin) GetTokenPoolByLocator(ctx context.Context, connector string, locator string) (*core.TokenPool, error) { - ret := _m.Called(ctx, connector, locator) +// GetTokenPoolByLocator provides a mock function with given fields: ctx, namespace, connector, locator +func (_m *Plugin) GetTokenPoolByLocator(ctx context.Context, namespace string, connector string, locator string) (*core.TokenPool, error) { + ret := _m.Called(ctx, namespace, connector, locator) var r0 *core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.TokenPool); ok { - r0 = rf(ctx, connector, locator) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.TokenPool); ok { + r0 = rf(ctx, namespace, connector, locator) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenPool) @@ -1901,8 +1901,8 @@ func (_m *Plugin) GetTokenPoolByLocator(ctx context.Context, connector string, l } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, connector, locator) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, connector, locator) } else { r1 = ret.Error(1) } @@ -1910,13 +1910,13 @@ func (_m *Plugin) GetTokenPoolByLocator(ctx context.Context, connector string, l return r0, r1 } -// GetTokenPools provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetTokenPools(ctx context.Context, filter database.Filter) ([]*core.TokenPool, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetTokenPools provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetTokenPools(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenPool, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.TokenPool - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.TokenPool); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenPool); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenPool) @@ -1924,8 +1924,8 @@ func (_m *Plugin) GetTokenPools(ctx context.Context, filter database.Filter) ([] } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1933,8 +1933,8 @@ func (_m *Plugin) GetTokenPools(ctx context.Context, filter database.Filter) ([] } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index d38765805..06b141952 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -364,13 +364,13 @@ type iTokenPoolCollection interface { GetTokenPool(ctx context.Context, namespace, name string) (*core.TokenPool, error) // GetTokenPoolByID - Get a token pool by pool ID - GetTokenPoolByID(ctx context.Context, id *fftypes.UUID) (*core.TokenPool, error) + GetTokenPoolByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.TokenPool, error) - // GetTokenPoolByID - Get a token pool by locator - GetTokenPoolByLocator(ctx context.Context, connector, locator string) (*core.TokenPool, error) + // GetTokenPoolByLocator - Get a token pool by locator + GetTokenPoolByLocator(ctx context.Context, namespace, connector, locator string) (*core.TokenPool, error) // GetTokenPools - Get token pools - GetTokenPools(ctx context.Context, filter Filter) ([]*core.TokenPool, *FilterResult, error) + GetTokenPools(ctx context.Context, namespace string, filter Filter) ([]*core.TokenPool, *FilterResult, error) } type iTokenBalanceCollection interface { From 8ef63c90ede2774108a4b0421b9cce7a1e5c683c Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 13:43:13 -0400 Subject: [PATCH 3/9] Add namespace to token transfer database queries Signed-off-by: Andrew Richardson --- .../route_get_token_transfer_by_id.go | 2 +- .../route_get_token_transfer_by_id_test.go | 2 +- .../apiserver/route_get_token_transfers.go | 2 +- .../route_get_token_transfers_test.go | 4 +- internal/assets/manager.go | 4 +- internal/assets/token_transfer.go | 9 ++-- internal/assets/token_transfer_test.go | 10 ++-- .../database/sqlcommon/tokentransfer_sql.go | 12 +++-- .../sqlcommon/tokentransfer_sql_test.go | 20 ++++---- internal/events/aggregator.go | 2 +- internal/events/aggregator_test.go | 6 +-- internal/events/tokens_transferred.go | 4 +- internal/events/tokens_transferred_test.go | 28 +++++------ internal/orchestrator/txn_status.go | 2 +- internal/orchestrator/txn_status_test.go | 8 ++-- internal/syncasync/sync_async_bridge.go | 2 +- internal/syncasync/sync_async_bridge_test.go | 6 +-- internal/txcommon/event_enrich.go | 2 +- internal/txcommon/event_enrich_test.go | 4 +- mocks/assetmocks/manager.go | 32 ++++++------- mocks/databasemocks/plugin.go | 46 +++++++++---------- pkg/database/plugin.go | 6 +-- 22 files changed, 107 insertions(+), 106 deletions(-) diff --git a/internal/apiserver/route_get_token_transfer_by_id.go b/internal/apiserver/route_get_token_transfer_by_id.go index 675908b48..33b65ca1b 100644 --- a/internal/apiserver/route_get_token_transfer_by_id.go +++ b/internal/apiserver/route_get_token_transfer_by_id.go @@ -38,7 +38,7 @@ var getTokenTransferByID = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - output, err = cr.or.Assets().GetTokenTransferByID(cr.ctx, extractNamespace(r.PP), r.PP["transferId"]) + output, err = cr.or.Assets().GetTokenTransferByID(cr.ctx, r.PP["transferId"]) return output, err }, }, diff --git a/internal/apiserver/route_get_token_transfer_by_id_test.go b/internal/apiserver/route_get_token_transfer_by_id_test.go index 3556ac90a..69f027a9f 100644 --- a/internal/apiserver/route_get_token_transfer_by_id_test.go +++ b/internal/apiserver/route_get_token_transfer_by_id_test.go @@ -34,7 +34,7 @@ func TestGetTokenTransferByID(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenTransferByID", mock.Anything, "ns1", "id1"). + mam.On("GetTokenTransferByID", mock.Anything, "id1"). Return(&core.TokenTransfer{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_token_transfers.go b/internal/apiserver/route_get_token_transfers.go index 106ea7353..fea3048f4 100644 --- a/internal/apiserver/route_get_token_transfers.go +++ b/internal/apiserver/route_get_token_transfers.go @@ -48,7 +48,7 @@ var getTokenTransfers = &ffapi.Route{ Condition(fb.Eq("from", fromOrTo)). Condition(fb.Eq("to", fromOrTo))) } - return filterResult(cr.or.Assets().GetTokenTransfers(cr.ctx, extractNamespace(r.PP), filter)) + return filterResult(cr.or.Assets().GetTokenTransfers(cr.ctx, filter)) }, }, } diff --git a/internal/apiserver/route_get_token_transfers_test.go b/internal/apiserver/route_get_token_transfers_test.go index 89b453cdb..827149efd 100644 --- a/internal/apiserver/route_get_token_transfers_test.go +++ b/internal/apiserver/route_get_token_transfers_test.go @@ -35,7 +35,7 @@ func TestGetTokenTransfers(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenTransfers", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenTransfers", mock.Anything, mock.Anything). Return([]*core.TokenTransfer{}, nil, nil) r.ServeHTTP(res, req) @@ -50,7 +50,7 @@ func TestGetTokenTransfersFromOrTo(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenTransfers", mock.Anything, "ns1", mock.MatchedBy(func(filter database.AndFilter) bool { + mam.On("GetTokenTransfers", mock.Anything, mock.MatchedBy(func(filter database.AndFilter) bool { info, _ := filter.Finalize() return info.String() == "( ( from == '0x1' ) || ( to == '0x1' ) )" })).Return([]*core.TokenTransfer{}, nil, nil) diff --git a/internal/assets/manager.go b/internal/assets/manager.go index 29e11c619..2faa54f0d 100644 --- a/internal/assets/manager.go +++ b/internal/assets/manager.go @@ -50,8 +50,8 @@ type Manager interface { GetTokenAccounts(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) GetTokenAccountPools(ctx context.Context, ns, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) - GetTokenTransfers(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) - GetTokenTransferByID(ctx context.Context, ns, id string) (*core.TokenTransfer, error) + GetTokenTransfers(ctx context.Context, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) + GetTokenTransferByID(ctx context.Context, id string) (*core.TokenTransfer, error) NewTransfer(transfer *core.TokenTransferInput) sysmessaging.MessageSender MintTokens(ctx context.Context, transfer *core.TokenTransferInput, waitConfirm bool) (*core.TokenTransfer, error) diff --git a/internal/assets/token_transfer.go b/internal/assets/token_transfer.go index 34a931800..52ad2e12e 100644 --- a/internal/assets/token_transfer.go +++ b/internal/assets/token_transfer.go @@ -28,17 +28,16 @@ import ( "github.com/hyperledger/firefly/pkg/database" ) -func (am *assetManager) GetTokenTransfers(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) { - return am.database.GetTokenTransfers(ctx, am.scopeNS(ns, filter)) +func (am *assetManager) GetTokenTransfers(ctx context.Context, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) { + return am.database.GetTokenTransfers(ctx, am.namespace, filter) } -func (am *assetManager) GetTokenTransferByID(ctx context.Context, ns, id string) (*core.TokenTransfer, error) { +func (am *assetManager) GetTokenTransferByID(ctx context.Context, id string) (*core.TokenTransfer, error) { transferID, err := fftypes.ParseUUID(ctx, id) if err != nil { return nil, err } - - return am.database.GetTokenTransferByID(ctx, transferID) + return am.database.GetTokenTransferByID(ctx, am.namespace, transferID) } func (am *assetManager) NewTransfer(transfer *core.TokenTransferInput) sysmessaging.MessageSender { diff --git a/internal/assets/token_transfer_test.go b/internal/assets/token_transfer_test.go index f06462bfb..ac250eb4f 100644 --- a/internal/assets/token_transfer_test.go +++ b/internal/assets/token_transfer_test.go @@ -43,8 +43,8 @@ func TestGetTokenTransfers(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenTransferQueryFactory.NewFilter(context.Background()) f := fb.And() - mdi.On("GetTokenTransfers", context.Background(), f).Return([]*core.TokenTransfer{}, nil, nil) - _, _, err := am.GetTokenTransfers(context.Background(), "ns1", f) + mdi.On("GetTokenTransfers", context.Background(), "ns1", f).Return([]*core.TokenTransfer{}, nil, nil) + _, _, err := am.GetTokenTransfers(context.Background(), f) assert.NoError(t, err) mdi.AssertExpectations(t) @@ -56,8 +56,8 @@ func TestGetTokenTransferByID(t *testing.T) { u := fftypes.NewUUID() mdi := am.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransferByID", context.Background(), u).Return(&core.TokenTransfer{}, nil) - _, err := am.GetTokenTransferByID(context.Background(), "ns1", u.String()) + mdi.On("GetTokenTransferByID", context.Background(), "ns1", u).Return(&core.TokenTransfer{}, nil) + _, err := am.GetTokenTransferByID(context.Background(), u.String()) assert.NoError(t, err) mdi.AssertExpectations(t) @@ -67,7 +67,7 @@ func TestGetTokenTransferByIDBadID(t *testing.T) { am, cancel := newTestAssets(t) defer cancel() - _, err := am.GetTokenTransferByID(context.Background(), "ns1", "badUUID") + _, err := am.GetTokenTransferByID(context.Background(), "badUUID") assert.Regexp(t, "FF00138", err) } diff --git a/internal/database/sqlcommon/tokentransfer_sql.go b/internal/database/sqlcommon/tokentransfer_sql.go index c60ff51fb..d66ad13ab 100644 --- a/internal/database/sqlcommon/tokentransfer_sql.go +++ b/internal/database/sqlcommon/tokentransfer_sql.go @@ -200,19 +200,21 @@ func (s *SQLCommon) getTokenTransferPred(ctx context.Context, desc string, pred return transfer, nil } -func (s *SQLCommon) GetTokenTransferByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenTransfer, error) { - return s.getTokenTransferPred(ctx, localID.String(), sq.Eq{"local_id": localID}) +func (s *SQLCommon) GetTokenTransferByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenTransfer, error) { + return s.getTokenTransferPred(ctx, localID.String(), sq.Eq{"local_id": localID, "namespace": namespace}) } -func (s *SQLCommon) GetTokenTransferByProtocolID(ctx context.Context, connector, protocolID string) (*core.TokenTransfer, error) { +func (s *SQLCommon) GetTokenTransferByProtocolID(ctx context.Context, namespace, connector, protocolID string) (*core.TokenTransfer, error) { return s.getTokenTransferPred(ctx, protocolID, sq.And{ + sq.Eq{"namespace": namespace}, sq.Eq{"connector": connector}, sq.Eq{"protocol_id": protocolID}, }) } -func (s *SQLCommon) GetTokenTransfers(ctx context.Context, filter database.Filter) (message []*core.TokenTransfer, fr *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenTransferColumns...).From(tokentransferTable), filter, tokenTransferFilterFieldMap, []interface{}{"seq"}) +func (s *SQLCommon) GetTokenTransfers(ctx context.Context, namespace string, filter database.Filter) (message []*core.TokenTransfer, fr *database.FilterResult, err error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenTransferColumns...).From(tokentransferTable), + filter, tokenTransferFilterFieldMap, []interface{}{"seq"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } diff --git a/internal/database/sqlcommon/tokentransfer_sql_test.go b/internal/database/sqlcommon/tokentransfer_sql_test.go index e30a242d5..352359869 100644 --- a/internal/database/sqlcommon/tokentransfer_sql_test.go +++ b/internal/database/sqlcommon/tokentransfer_sql_test.go @@ -69,14 +69,14 @@ func TestTokenTransferE2EWithDB(t *testing.T) { transferJson, _ := json.Marshal(&transfer) // Query back the token transfer (by ID) - transferRead, err := s.GetTokenTransferByID(ctx, transfer.LocalID) + transferRead, err := s.GetTokenTransferByID(ctx, "ns1", transfer.LocalID) assert.NoError(t, err) assert.NotNil(t, transferRead) transferReadJson, _ := json.Marshal(&transferRead) assert.Equal(t, string(transferJson), string(transferReadJson)) // Query back the token transfer (by protocol ID) - transferRead, err = s.GetTokenTransferByProtocolID(ctx, transfer.Connector, transfer.ProtocolID) + transferRead, err = s.GetTokenTransferByProtocolID(ctx, "ns1", transfer.Connector, transfer.ProtocolID) assert.NoError(t, err) assert.NotNil(t, transferRead) transferReadJson, _ = json.Marshal(&transferRead) @@ -92,7 +92,7 @@ func TestTokenTransferE2EWithDB(t *testing.T) { fb.Eq("protocolid", transfer.ProtocolID), fb.Eq("created", transfer.Created), ) - transfers, res, err := s.GetTokenTransfers(ctx, filter.Count(true)) + transfers, res, err := s.GetTokenTransfers(ctx, "ns1", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(transfers)) assert.Equal(t, int64(1), *res.TotalCount) @@ -107,7 +107,7 @@ func TestTokenTransferE2EWithDB(t *testing.T) { assert.NoError(t, err) // Query back the token transfer (by ID) - transferRead, err = s.GetTokenTransferByID(ctx, transfer.LocalID) + transferRead, err = s.GetTokenTransferByID(ctx, "ns1", transfer.LocalID) assert.NoError(t, err) assert.NotNil(t, transferRead) transferJson, _ = json.Marshal(&transfer) @@ -168,7 +168,7 @@ func TestUpsertTokenTransferFailCommit(t *testing.T) { func TestGetTokenTransferByIDSelectFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetTokenTransferByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetTokenTransferByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -176,7 +176,7 @@ func TestGetTokenTransferByIDSelectFail(t *testing.T) { func TestGetTokenTransferByIDNotFound(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"protocolid"})) - msg, err := s.GetTokenTransferByID(context.Background(), fftypes.NewUUID()) + msg, err := s.GetTokenTransferByID(context.Background(), "ns1", fftypes.NewUUID()) assert.NoError(t, err) assert.Nil(t, msg) assert.NoError(t, mock.ExpectationsWereMet()) @@ -185,7 +185,7 @@ func TestGetTokenTransferByIDNotFound(t *testing.T) { func TestGetTokenTransferByIDScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"protocolid"}).AddRow("only one")) - _, err := s.GetTokenTransferByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetTokenTransferByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -194,7 +194,7 @@ func TestGetTokenTransfersQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenTransferQueryFactory.NewFilter(context.Background()).Eq("protocolid", "") - _, _, err := s.GetTokenTransfers(context.Background(), f) + _, _, err := s.GetTokenTransfers(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -202,7 +202,7 @@ func TestGetTokenTransfersQueryFail(t *testing.T) { func TestGetTokenTransfersBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.TokenTransferQueryFactory.NewFilter(context.Background()).Eq("protocolid", map[bool]bool{true: false}) - _, _, err := s.GetTokenTransfers(context.Background(), f) + _, _, err := s.GetTokenTransfers(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*id", err) } @@ -210,7 +210,7 @@ func TestGetTokenTransfersScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"protocolid"}).AddRow("only one")) f := database.TokenTransferQueryFactory.NewFilter(context.Background()).Eq("protocolid", "") - _, _, err := s.GetTokenTransfers(context.Background(), f) + _, _, err := s.GetTokenTransfers(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/events/aggregator.go b/internal/events/aggregator.go index ae9a68b69..de3a771c9 100644 --- a/internal/events/aggregator.go +++ b/internal/events/aggregator.go @@ -518,7 +518,7 @@ func (ag *aggregator) attemptMessageDispatch(ctx context.Context, msg *core.Mess filter := fb.And( fb.Eq("message", msg.Header.ID), ) - if transfers, _, err := ag.database.GetTokenTransfers(ctx, filter); err != nil || len(transfers) == 0 { + if transfers, _, err := ag.database.GetTokenTransfers(ctx, ag.namespace, filter); err != nil || len(transfers) == 0 { log.L(ctx).Debugf("Transfer for message %s not yet available", msg.Header.ID) return "", false, err } else if !msg.Hash.Equals(transfers[0].MessageHash) { diff --git a/internal/events/aggregator_test.go b/internal/events/aggregator_test.go index aca5a353f..5cc20c1a9 100644 --- a/internal/events/aggregator_test.go +++ b/internal/events/aggregator_test.go @@ -1424,7 +1424,7 @@ func TestAttemptMessageDispatchMissingTransfers(t *testing.T) { org1 := newTestOrg("org1") mim.On("FindIdentityForVerifier", ag.ctx, mock.Anything, mock.Anything).Return(org1, nil) mdi := ag.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransfers", ag.ctx, mock.Anything).Return([]*core.TokenTransfer{}, nil, nil) + mdi.On("GetTokenTransfers", ag.ctx, "ns1", mock.Anything).Return([]*core.TokenTransfer{}, nil, nil) msg := &core.Message{ Header: core.MessageHeader{ @@ -1454,7 +1454,7 @@ func TestAttemptMessageDispatchGetTransfersFail(t *testing.T) { mim.On("FindIdentityForVerifier", ag.ctx, mock.Anything, mock.Anything).Return(org1, nil) mdi := ag.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransfers", ag.ctx, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdi.On("GetTokenTransfers", ag.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) msg := &core.Message{ Header: core.MessageHeader{ @@ -1495,7 +1495,7 @@ func TestAttemptMessageDispatchTransferMismatch(t *testing.T) { mim.On("FindIdentityForVerifier", ag.ctx, mock.Anything, mock.Anything).Return(org1, nil) mdi := ag.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransfers", ag.ctx, mock.Anything).Return(transfers, nil, nil) + mdi.On("GetTokenTransfers", ag.ctx, "ns1", mock.Anything).Return(transfers, nil, nil) _, dispatched, err := ag.attemptMessageDispatch(ag.ctx, msg, core.DataArray{}, nil, &batchState{}, &core.Pin{Signer: "0x12345"}) assert.NoError(t, err) diff --git a/internal/events/tokens_transferred.go b/internal/events/tokens_transferred.go index 2166e14f3..36bc33e8e 100644 --- a/internal/events/tokens_transferred.go +++ b/internal/events/tokens_transferred.go @@ -55,7 +55,7 @@ func (em *eventManager) loadTransferID(ctx context.Context, tx *fftypes.UUID, tr log.L(ctx).Warnf("Failed to read operation inputs for token transfer '%s': %s", transfer.ProtocolID, err) } else if input != nil && input.Connector == transfer.Connector && input.Pool.Equals(transfer.Pool) { // Check if the LocalID has already been used - if existing, err := em.database.GetTokenTransferByID(ctx, input.LocalID); err != nil { + if existing, err := em.database.GetTokenTransferByID(ctx, em.namespace, input.LocalID); err != nil { return nil, err } else if existing == nil { // Everything matches - use the LocalID that was assigned up-front when the operation was submitted @@ -86,7 +86,7 @@ func (em *eventManager) persistTokenTransfer(ctx context.Context, transfer *toke transfer.Pool = pool.ID // Check that transfer has not already been recorded - if existing, err := em.database.GetTokenTransferByProtocolID(ctx, transfer.Connector, transfer.ProtocolID); err != nil { + if existing, err := em.database.GetTokenTransferByProtocolID(ctx, em.namespace, transfer.Connector, transfer.ProtocolID); err != nil { return false, err } else if existing != nil { log.L(ctx).Warnf("Token transfer '%s' has already been recorded - ignoring", transfer.ProtocolID) diff --git a/internal/events/tokens_transferred_test.go b/internal/events/tokens_transferred_test.go index c4d90acf4..f2f817f94 100644 --- a/internal/events/tokens_transferred_test.go +++ b/internal/events/tokens_transferred_test.go @@ -75,8 +75,8 @@ func TestTokensTransferredSucceedWithRetries(t *testing.T) { mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(4) - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(3) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(3) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == transfer.Event.Name @@ -113,7 +113,7 @@ func TestTokensTransferredIgnoreExisting(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(&core.TokenTransfer{}, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(&core.TokenTransfer{}, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) err := em.TokensTransferred(mti, transfer) @@ -156,7 +156,7 @@ func TestPersistTransferOpFail(t *testing.T) { Namespace: "ns1", } - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) @@ -186,7 +186,7 @@ func TestPersistTransferBadOp(t *testing.T) { Transaction: fftypes.NewUUID(), }} - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -218,7 +218,7 @@ func TestPersistTransferTxFail(t *testing.T) { }, }} - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -251,11 +251,11 @@ func TestPersistTransferGetTransferFail(t *testing.T) { }, }} - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) - mdi.On("GetTokenTransferByID", em.ctx, localID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenTransferByID", em.ctx, "ns1", localID).Return(nil, fmt.Errorf("pop")) valid, err := em.persistTokenTransfer(em.ctx, transfer) assert.False(t, valid) @@ -285,11 +285,11 @@ func TestPersistTransferBlockchainEventFail(t *testing.T) { }, }} - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) - mdi.On("GetTokenTransferByID", em.ctx, localID).Return(nil, nil) + mdi.On("GetTokenTransferByID", em.ctx, "ns1", localID).Return(nil, nil) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == transfer.Event.Name @@ -324,11 +324,11 @@ func TestTokensTransferredWithTransactionRegenerateLocalID(t *testing.T) { }, }} - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil) mth.On("PersistTransaction", mock.Anything, transfer.TX.ID, core.TransactionTypeTokenTransfer, "0xffffeeee").Return(true, nil) - mdi.On("GetTokenTransferByID", em.ctx, localID).Return(&core.TokenTransfer{}, nil) + mdi.On("GetTokenTransferByID", em.ctx, "ns1", localID).Return(&core.TokenTransfer{}, nil) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == transfer.Event.Name @@ -405,7 +405,7 @@ func TestTokensTransferredWithMessageReceived(t *testing.T) { BatchID: fftypes.NewUUID(), } - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(2) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { @@ -469,7 +469,7 @@ func TestTokensTransferredWithMessageSend(t *testing.T) { State: core.MessageStateStaged, } - mdi.On("GetTokenTransferByProtocolID", em.ctx, "erc1155", "123").Return(nil, nil).Times(2) + mdi.On("GetTokenTransferByProtocolID", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(2) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), transfer.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { diff --git a/internal/orchestrator/txn_status.go b/internal/orchestrator/txn_status.go index 3722199de..82c9e869c 100644 --- a/internal/orchestrator/txn_status.go +++ b/internal/orchestrator/txn_status.go @@ -150,7 +150,7 @@ func (or *orchestrator) GetTransactionStatus(ctx context.Context, id string) (*c updateStatus(result, core.OpStatusPending) } f := database.TokenTransferQueryFactory.NewFilter(ctx) - switch transfers, _, err := or.database().GetTokenTransfers(ctx, f.Eq("tx.id", id)); { + switch transfers, _, err := or.database().GetTokenTransfers(ctx, or.namespace, f.Eq("tx.id", id)); { case err != nil: return nil, err case len(transfers) == 0: diff --git a/internal/orchestrator/txn_status_test.go b/internal/orchestrator/txn_status_test.go index 4a0cbdc00..27603b65f 100644 --- a/internal/orchestrator/txn_status_test.go +++ b/internal/orchestrator/txn_status_test.go @@ -453,7 +453,7 @@ func TestGetTransactionStatusTokenTransferSuccess(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenTransfers", mock.Anything, mock.Anything).Return(transfers, nil, nil) + or.mdi.On("GetTokenTransfers", mock.Anything, "ns", mock.Anything).Return(transfers, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -589,7 +589,7 @@ func TestGetTransactionStatusTokenTransferPending(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenTransfers", mock.Anything, mock.Anything).Return(transfers, nil, nil) + or.mdi.On("GetTokenTransfers", mock.Anything, "ns", mock.Anything).Return(transfers, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -651,7 +651,7 @@ func TestGetTransactionStatusTokenTransferRetry(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenTransfers", mock.Anything, mock.Anything).Return(transfers, nil, nil) + or.mdi.On("GetTokenTransfers", mock.Anything, "ns", mock.Anything).Return(transfers, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -894,7 +894,7 @@ func TestGetTransactionStatusTransferError(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) - or.mdi.On("GetTokenTransfers", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + or.mdi.On("GetTokenTransfers", mock.Anything, "ns", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.EqualError(t, err, "pop") diff --git a/internal/syncasync/sync_async_bridge.go b/internal/syncasync/sync_async_bridge.go index 9a08b852c..bf2bd370e 100644 --- a/internal/syncasync/sync_async_bridge.go +++ b/internal/syncasync/sync_async_bridge.go @@ -215,7 +215,7 @@ func (sa *syncAsyncBridge) getPoolFromMessage(msg *core.Message) (*core.TokenPoo } func (sa *syncAsyncBridge) getTransferFromEvent(event *core.EventDelivery) (transfer *core.TokenTransfer, err error) { - if transfer, err = sa.database.GetTokenTransferByID(sa.ctx, event.Reference); err != nil { + if transfer, err = sa.database.GetTokenTransferByID(sa.ctx, sa.namespace, event.Reference); err != nil { return nil, err } if transfer == nil { diff --git a/internal/syncasync/sync_async_bridge_test.go b/internal/syncasync/sync_async_bridge_test.go index 5d37b3593..10dd8b8d5 100644 --- a/internal/syncasync/sync_async_bridge_test.go +++ b/internal/syncasync/sync_async_bridge_test.go @@ -434,7 +434,7 @@ func TestEventCallbackTokenTransferLookupFail(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransferByID", sa.ctx, mock.Anything).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenTransferByID", sa.ctx, "ns1", mock.Anything).Return(nil, fmt.Errorf("pop")) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -596,7 +596,7 @@ func TestEventCallbackTokenTransferNotFound(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenTransferByID", sa.ctx, mock.Anything).Return(nil, nil) + mdi.On("GetTokenTransferByID", sa.ctx, "ns1", mock.Anything).Return(nil, nil) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -874,7 +874,7 @@ func TestAwaitTokenTransferConfirmation(t *testing.T) { mse.On("AddSystemEventListener", "ns1", mock.Anything).Return(nil) mdi := sa.database.(*databasemocks.Plugin) - gmid := mdi.On("GetTokenTransferByID", sa.ctx, mock.Anything) + gmid := mdi.On("GetTokenTransferByID", sa.ctx, "ns1", mock.Anything) gmid.RunFn = func(a mock.Arguments) { transfer := &core.TokenTransfer{ LocalID: requestID, diff --git a/internal/txcommon/event_enrich.go b/internal/txcommon/event_enrich.go index 8bead76da..f10d8cdfa 100644 --- a/internal/txcommon/event_enrich.go +++ b/internal/txcommon/event_enrich.go @@ -89,7 +89,7 @@ func (t *transactionHelper) EnrichEvent(ctx context.Context, event *core.Event) } e.TokenApproval = approval case core.EventTypeTransferConfirmed: - transfer, err := t.database.GetTokenTransferByID(ctx, event.Reference) + transfer, err := t.database.GetTokenTransferByID(ctx, t.namespace, event.Reference) if err != nil { return nil, err } diff --git a/internal/txcommon/event_enrich_test.go b/internal/txcommon/event_enrich_test.go index 29afdf9d7..7e62d9f34 100644 --- a/internal/txcommon/event_enrich_test.go +++ b/internal/txcommon/event_enrich_test.go @@ -584,7 +584,7 @@ func TestEnrichTokenTransferConfirmed(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenTransferByID", mock.Anything, ref1).Return(&core.TokenTransfer{ + mdi.On("GetTokenTransferByID", mock.Anything, "ns1", ref1).Return(&core.TokenTransfer{ LocalID: ref1, }, nil) @@ -636,7 +636,7 @@ func TestEnrichTokenTransferConfirmedFail(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenTransferByID", mock.Anything, ref1).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenTransferByID", mock.Anything, "ns1", ref1).Return(nil, fmt.Errorf("pop")) event := &core.Event{ ID: ev1, diff --git a/mocks/assetmocks/manager.go b/mocks/assetmocks/manager.go index 478ea8ea7..ec83b65da 100644 --- a/mocks/assetmocks/manager.go +++ b/mocks/assetmocks/manager.go @@ -302,13 +302,13 @@ func (_m *Manager) GetTokenPools(ctx context.Context, filter database.AndFilter) return r0, r1, r2 } -// GetTokenTransferByID provides a mock function with given fields: ctx, ns, id -func (_m *Manager) GetTokenTransferByID(ctx context.Context, ns string, id string) (*core.TokenTransfer, error) { - ret := _m.Called(ctx, ns, id) +// GetTokenTransferByID provides a mock function with given fields: ctx, id +func (_m *Manager) GetTokenTransferByID(ctx context.Context, id string) (*core.TokenTransfer, error) { + ret := _m.Called(ctx, id) var r0 *core.TokenTransfer - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.TokenTransfer); ok { - r0 = rf(ctx, ns, id) + if rf, ok := ret.Get(0).(func(context.Context, string) *core.TokenTransfer); ok { + r0 = rf(ctx, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenTransfer) @@ -316,8 +316,8 @@ func (_m *Manager) GetTokenTransferByID(ctx context.Context, ns string, id strin } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, ns, id) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -325,13 +325,13 @@ func (_m *Manager) GetTokenTransferByID(ctx context.Context, ns string, id strin return r0, r1 } -// GetTokenTransfers provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetTokenTransfers(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetTokenTransfers provides a mock function with given fields: ctx, filter +func (_m *Manager) GetTokenTransfers(ctx context.Context, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.TokenTransfer - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenTransfer); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.TokenTransfer); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenTransfer) @@ -339,8 +339,8 @@ func (_m *Manager) GetTokenTransfers(ctx context.Context, ns string, filter data } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -348,8 +348,8 @@ func (_m *Manager) GetTokenTransfers(ctx context.Context, ns string, filter data } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index e0b7f3235..5618f44af 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -1942,13 +1942,13 @@ func (_m *Plugin) GetTokenPools(ctx context.Context, namespace string, filter da return r0, r1, r2 } -// GetTokenTransferByID provides a mock function with given fields: ctx, localID -func (_m *Plugin) GetTokenTransferByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenTransfer, error) { - ret := _m.Called(ctx, localID) +// GetTokenTransferByID provides a mock function with given fields: ctx, namespace, localID +func (_m *Plugin) GetTokenTransferByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenTransfer, error) { + ret := _m.Called(ctx, namespace, localID) var r0 *core.TokenTransfer - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.TokenTransfer); ok { - r0 = rf(ctx, localID) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.TokenTransfer); ok { + r0 = rf(ctx, namespace, localID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenTransfer) @@ -1956,8 +1956,8 @@ func (_m *Plugin) GetTokenTransferByID(ctx context.Context, localID *fftypes.UUI } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, localID) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, localID) } else { r1 = ret.Error(1) } @@ -1965,13 +1965,13 @@ func (_m *Plugin) GetTokenTransferByID(ctx context.Context, localID *fftypes.UUI return r0, r1 } -// GetTokenTransferByProtocolID provides a mock function with given fields: ctx, connector, protocolID -func (_m *Plugin) GetTokenTransferByProtocolID(ctx context.Context, connector string, protocolID string) (*core.TokenTransfer, error) { - ret := _m.Called(ctx, connector, protocolID) +// GetTokenTransferByProtocolID provides a mock function with given fields: ctx, namespace, connector, protocolID +func (_m *Plugin) GetTokenTransferByProtocolID(ctx context.Context, namespace string, connector string, protocolID string) (*core.TokenTransfer, error) { + ret := _m.Called(ctx, namespace, connector, protocolID) var r0 *core.TokenTransfer - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.TokenTransfer); ok { - r0 = rf(ctx, connector, protocolID) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.TokenTransfer); ok { + r0 = rf(ctx, namespace, connector, protocolID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenTransfer) @@ -1979,8 +1979,8 @@ func (_m *Plugin) GetTokenTransferByProtocolID(ctx context.Context, connector st } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, connector, protocolID) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, connector, protocolID) } else { r1 = ret.Error(1) } @@ -1988,13 +1988,13 @@ func (_m *Plugin) GetTokenTransferByProtocolID(ctx context.Context, connector st return r0, r1 } -// GetTokenTransfers provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetTokenTransfers(ctx context.Context, filter database.Filter) ([]*core.TokenTransfer, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetTokenTransfers provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetTokenTransfers(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenTransfer, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.TokenTransfer - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.TokenTransfer); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenTransfer); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenTransfer) @@ -2002,8 +2002,8 @@ func (_m *Plugin) GetTokenTransfers(ctx context.Context, filter database.Filter) } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -2011,8 +2011,8 @@ func (_m *Plugin) GetTokenTransfers(ctx context.Context, filter database.Filter) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 06b141952..510e9b11e 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -395,13 +395,13 @@ type iTokenTransferCollection interface { UpsertTokenTransfer(ctx context.Context, transfer *core.TokenTransfer) error // GetTokenTransferByID - Get a token transfer by ID - GetTokenTransferByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenTransfer, error) + GetTokenTransferByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenTransfer, error) // GetTokenTransferByProtocolID - Get a token transfer by protocol ID - GetTokenTransferByProtocolID(ctx context.Context, connector, protocolID string) (*core.TokenTransfer, error) + GetTokenTransferByProtocolID(ctx context.Context, namespace, connector, protocolID string) (*core.TokenTransfer, error) // GetTokenTransfers - Get token transfers - GetTokenTransfers(ctx context.Context, filter Filter) ([]*core.TokenTransfer, *FilterResult, error) + GetTokenTransfers(ctx context.Context, namespace string, filter Filter) ([]*core.TokenTransfer, *FilterResult, error) } type iTokenApprovalCollection interface { From 41e9f5007d16bb64a42fd3f4b2e5b9b17ca4ad31 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 13:49:41 -0400 Subject: [PATCH 4/9] Add namespace to token approval database queries Signed-off-by: Andrew Richardson --- .../apiserver/route_get_token_approvals.go | 2 +- .../route_get_token_approvals_test.go | 2 +- internal/assets/manager.go | 2 +- internal/assets/token_approval.go | 4 +- internal/assets/token_approval_test.go | 4 +- .../database/sqlcommon/tokenapproval_sql.go | 12 +++-- .../sqlcommon/tokenapproval_sql_test.go | 22 ++++----- internal/events/tokens_approved.go | 4 +- internal/events/tokens_approved_test.go | 26 +++++------ internal/orchestrator/txn_status.go | 2 +- internal/orchestrator/txn_status_test.go | 6 +-- internal/syncasync/sync_async_bridge.go | 2 +- internal/syncasync/sync_async_bridge_test.go | 6 +-- internal/txcommon/event_enrich.go | 2 +- internal/txcommon/event_enrich_test.go | 4 +- mocks/assetmocks/manager.go | 18 ++++---- mocks/databasemocks/plugin.go | 46 +++++++++---------- pkg/database/plugin.go | 6 +-- 18 files changed, 86 insertions(+), 84 deletions(-) diff --git a/internal/apiserver/route_get_token_approvals.go b/internal/apiserver/route_get_token_approvals.go index 36306dccc..9cafdd571 100644 --- a/internal/apiserver/route_get_token_approvals.go +++ b/internal/apiserver/route_get_token_approvals.go @@ -38,7 +38,7 @@ var getTokenApprovals = &ffapi.Route{ FilterFactory: database.TokenApprovalQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { filter := cr.filter - return filterResult(cr.or.Assets().GetTokenApprovals(cr.ctx, extractNamespace(r.PP), filter)) + return filterResult(cr.or.Assets().GetTokenApprovals(cr.ctx, filter)) }, }, } diff --git a/internal/apiserver/route_get_token_approvals_test.go b/internal/apiserver/route_get_token_approvals_test.go index c8cfa8e47..d09d40185 100644 --- a/internal/apiserver/route_get_token_approvals_test.go +++ b/internal/apiserver/route_get_token_approvals_test.go @@ -34,7 +34,7 @@ func TestGetTokenApprovals(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenApprovals", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenApprovals", mock.Anything, mock.Anything). Return([]*core.TokenApproval{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/assets/manager.go b/internal/assets/manager.go index 2faa54f0d..db1249cfa 100644 --- a/internal/assets/manager.go +++ b/internal/assets/manager.go @@ -62,7 +62,7 @@ type Manager interface { NewApproval(approve *core.TokenApprovalInput) sysmessaging.MessageSender TokenApproval(ctx context.Context, approval *core.TokenApprovalInput, waitConfirm bool) (*core.TokenApproval, error) - GetTokenApprovals(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) + GetTokenApprovals(ctx context.Context, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) // From operations.OperationHandler PrepareOperation(ctx context.Context, op *core.Operation) (*core.PreparedOperation, error) diff --git a/internal/assets/token_approval.go b/internal/assets/token_approval.go index e0b2bc6ff..5d41a80f9 100644 --- a/internal/assets/token_approval.go +++ b/internal/assets/token_approval.go @@ -28,8 +28,8 @@ import ( "github.com/hyperledger/firefly/pkg/database" ) -func (am *assetManager) GetTokenApprovals(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) { - return am.database.GetTokenApprovals(ctx, am.scopeNS(ns, filter)) +func (am *assetManager) GetTokenApprovals(ctx context.Context, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) { + return am.database.GetTokenApprovals(ctx, am.namespace, filter) } type approveSender struct { diff --git a/internal/assets/token_approval_test.go b/internal/assets/token_approval_test.go index 7974d3444..763b436ba 100644 --- a/internal/assets/token_approval_test.go +++ b/internal/assets/token_approval_test.go @@ -40,8 +40,8 @@ func TestGetTokenApprovals(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenApprovalQueryFactory.NewFilter(context.Background()) f := fb.And() - mdi.On("GetTokenApprovals", context.Background(), f).Return([]*core.TokenApproval{}, nil, nil) - _, _, err := am.GetTokenApprovals(context.Background(), "ns1", f) + mdi.On("GetTokenApprovals", context.Background(), "ns1", f).Return([]*core.TokenApproval{}, nil, nil) + _, _, err := am.GetTokenApprovals(context.Background(), f) assert.NoError(t, err) } diff --git a/internal/database/sqlcommon/tokenapproval_sql.go b/internal/database/sqlcommon/tokenapproval_sql.go index da769d899..33dd09c9c 100644 --- a/internal/database/sqlcommon/tokenapproval_sql.go +++ b/internal/database/sqlcommon/tokenapproval_sql.go @@ -185,19 +185,21 @@ func (s *SQLCommon) getTokenApprovalPred(ctx context.Context, desc string, pred return approval, nil } -func (s *SQLCommon) GetTokenApprovalByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenApproval, error) { - return s.getTokenApprovalPred(ctx, localID.String(), sq.Eq{"local_id": localID}) +func (s *SQLCommon) GetTokenApprovalByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenApproval, error) { + return s.getTokenApprovalPred(ctx, localID.String(), sq.Eq{"local_id": localID, "namespace": namespace}) } -func (s *SQLCommon) GetTokenApprovalByProtocolID(ctx context.Context, connector, protocolID string) (*core.TokenApproval, error) { +func (s *SQLCommon) GetTokenApprovalByProtocolID(ctx context.Context, namespace, connector, protocolID string) (*core.TokenApproval, error) { return s.getTokenApprovalPred(ctx, protocolID, sq.And{ + sq.Eq{"namespace": namespace}, sq.Eq{"connector": connector}, sq.Eq{"protocol_id": protocolID}, }) } -func (s *SQLCommon) GetTokenApprovals(ctx context.Context, filter database.Filter) (approvals []*core.TokenApproval, fr *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenApprovalColumns...).From(tokenapprovalTable), filter, tokenApprovalFilterFieldMap, []interface{}{"seq"}) +func (s *SQLCommon) GetTokenApprovals(ctx context.Context, namespace string, filter database.Filter) (approvals []*core.TokenApproval, fr *database.FilterResult, err error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenApprovalColumns...).From(tokenapprovalTable), + filter, tokenApprovalFilterFieldMap, []interface{}{"seq"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } diff --git a/internal/database/sqlcommon/tokenapproval_sql_test.go b/internal/database/sqlcommon/tokenapproval_sql_test.go index 23d31069b..ba6a2e83f 100644 --- a/internal/database/sqlcommon/tokenapproval_sql_test.go +++ b/internal/database/sqlcommon/tokenapproval_sql_test.go @@ -60,7 +60,7 @@ func TestApprovalE2EWithDB(t *testing.T) { // Initial list is empty fb := database.TokenApprovalQueryFactory.NewFilter(ctx) - approvals, _, err := s.GetTokenApprovals(ctx, fb.And()) + approvals, _, err := s.GetTokenApprovals(ctx, "ns1", fb.And()) assert.NoError(t, err) assert.NotNil(t, approvals) assert.Equal(t, 0, len(approvals)) @@ -72,14 +72,14 @@ func TestApprovalE2EWithDB(t *testing.T) { approvalJson, _ := json.Marshal(&approval) // Query back token approval by ID - approvalRead, err := s.GetTokenApprovalByID(ctx, approval.LocalID) + approvalRead, err := s.GetTokenApprovalByID(ctx, "ns1", approval.LocalID) assert.NoError(t, err) assert.NotNil(t, approvalRead) approvalReadJson, _ := json.Marshal(&approvalRead) assert.Equal(t, string(approvalJson), string(approvalReadJson)) // Query back token approval by protocol ID - approvalRead, err = s.GetTokenApprovalByProtocolID(ctx, approval.Connector, approval.ProtocolID) + approvalRead, err = s.GetTokenApprovalByProtocolID(ctx, "ns1", approval.Connector, approval.ProtocolID) assert.NoError(t, err) assert.NotNil(t, approvalRead) approvalReadJson, _ = json.Marshal(&approvalRead) @@ -93,7 +93,7 @@ func TestApprovalE2EWithDB(t *testing.T) { fb.Eq("subject", approval.Subject), fb.Eq("created", approval.Created), ) - approvals, res, err := s.GetTokenApprovals(ctx, filter.Count(true)) + approvals, res, err := s.GetTokenApprovals(ctx, "ns1", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(approvals)) assert.Equal(t, int64(1), *res.TotalCount) @@ -113,7 +113,7 @@ func TestApprovalE2EWithDB(t *testing.T) { approval.Active = false // Query back token approval by ID - approvalRead, err = s.GetTokenApprovalByID(ctx, approval.LocalID) + approvalRead, err = s.GetTokenApprovalByID(ctx, "ns1", approval.LocalID) assert.NoError(t, err) assert.NotNil(t, approvalRead) approvalJson, _ = json.Marshal(&approval) @@ -174,7 +174,7 @@ func TestUpsertApprovalFailCommit(t *testing.T) { func TestGetApprovalByIDSelectFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetTokenApprovalByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetTokenApprovalByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -182,7 +182,7 @@ func TestGetApprovalByIDSelectFail(t *testing.T) { func TestGetApprovalByIDNotFound(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"subject"})) - a, err := s.GetTokenApprovalByID(context.Background(), fftypes.NewUUID()) + a, err := s.GetTokenApprovalByID(context.Background(), "ns1", fftypes.NewUUID()) assert.NoError(t, err) assert.Nil(t, a) assert.NoError(t, mock.ExpectationsWereMet()) @@ -191,7 +191,7 @@ func TestGetApprovalByIDNotFound(t *testing.T) { func TestGetApprovalByIDScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"subject"}).AddRow("1")) - _, err := s.GetTokenApprovalByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetTokenApprovalByID(context.Background(), "ns1", fftypes.NewUUID()) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -200,7 +200,7 @@ func TestGetApprovalsQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenApprovalQueryFactory.NewFilter(context.Background()).Eq("subject", "") - _, _, err := s.GetTokenApprovals(context.Background(), f) + _, _, err := s.GetTokenApprovals(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -208,7 +208,7 @@ func TestGetApprovalsBuildQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenApprovalQueryFactory.NewFilter(context.Background()).Eq("subject", map[bool]bool{true: false}) - _, _, err := s.GetTokenApprovals(context.Background(), f) + _, _, err := s.GetTokenApprovals(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*subject", err) } @@ -216,7 +216,7 @@ func TestGetApprovalsScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"subject"}).AddRow("1")) f := database.TokenApprovalQueryFactory.NewFilter(context.Background()).Eq("subject", "") - _, _, err := s.GetTokenApprovals(context.Background(), f) + _, _, err := s.GetTokenApprovals(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/events/tokens_approved.go b/internal/events/tokens_approved.go index 3caa0e9ac..059d91c31 100644 --- a/internal/events/tokens_approved.go +++ b/internal/events/tokens_approved.go @@ -55,7 +55,7 @@ func (em *eventManager) loadApprovalID(ctx context.Context, tx *fftypes.UUID, ap log.L(ctx).Warnf("Failed to read operation inputs for token approval '%s': %s", approval.Subject, err) } else if input != nil && input.Connector == approval.Connector && input.Pool.Equals(approval.Pool) { // Check if the LocalID has already been used - if existing, err := em.database.GetTokenApprovalByID(ctx, input.LocalID); err != nil { + if existing, err := em.database.GetTokenApprovalByID(ctx, em.namespace, input.LocalID); err != nil { return nil, err } else if existing == nil { // Everything matches - use the LocalID that was assigned up-front when the operation was submitted @@ -86,7 +86,7 @@ func (em *eventManager) persistTokenApproval(ctx context.Context, approval *toke approval.Pool = pool.ID // Check that approval has not already been recorded - if existing, err := em.database.GetTokenApprovalByProtocolID(ctx, approval.Connector, approval.ProtocolID); err != nil { + if existing, err := em.database.GetTokenApprovalByProtocolID(ctx, em.namespace, approval.Connector, approval.ProtocolID); err != nil { return false, err } else if existing != nil { log.L(ctx).Warnf("Token approval '%s' has already been recorded - ignoring", approval.ProtocolID) diff --git a/internal/events/tokens_approved_test.go b/internal/events/tokens_approved_test.go index 49d4595c6..641e5fdbf 100644 --- a/internal/events/tokens_approved_test.go +++ b/internal/events/tokens_approved_test.go @@ -75,8 +75,8 @@ func TestTokensApprovedSucceedWithRetries(t *testing.T) { mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(nil, fmt.Errorf("pop")).Once() mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil).Times(4) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil).Times(3) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil).Times(3) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), approval.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == approval.Event.Name @@ -114,7 +114,7 @@ func TestPersistApprovalDuplicate(t *testing.T) { } mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(&core.TokenApproval{}, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(&core.TokenApproval{}, nil) valid, err := em.persistTokenApproval(em.ctx, approval) assert.False(t, valid) @@ -157,7 +157,7 @@ func TestPersistApprovalOpFail(t *testing.T) { } mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) valid, err := em.persistTokenApproval(em.ctx, approval) @@ -187,7 +187,7 @@ func TestPersistApprovalBadOp(t *testing.T) { }} mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(false, fmt.Errorf("pop")) @@ -221,9 +221,9 @@ func TestPersistApprovalTxFail(t *testing.T) { }} mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) - mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(nil, nil) + mdi.On("GetTokenApprovalByID", em.ctx, "ns1", localID).Return(nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(false, fmt.Errorf("pop")) valid, err := em.persistTokenApproval(em.ctx, approval) @@ -256,9 +256,9 @@ func TestPersistApprovalGetApprovalFail(t *testing.T) { }} mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) - mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenApprovalByID", em.ctx, "ns1", localID).Return(nil, fmt.Errorf("pop")) valid, err := em.persistTokenApproval(em.ctx, approval) assert.False(t, valid) @@ -308,10 +308,10 @@ func TestApprovedWithTransactionRegenerateLocalID(t *testing.T) { }} mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(true, nil) - mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(&core.TokenApproval{}, nil) + mdi.On("GetTokenApprovalByID", em.ctx, "ns1", localID).Return(&core.TokenApproval{}, nil) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), approval.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == approval.Event.Name @@ -355,10 +355,10 @@ func TestApprovedBlockchainEventFail(t *testing.T) { }} mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "F1").Return(pool, nil) - mdi.On("GetTokenApprovalByProtocolID", em.ctx, approval.Connector, approval.ProtocolID).Return(nil, nil) + mdi.On("GetTokenApprovalByProtocolID", em.ctx, "ns1", approval.Connector, approval.ProtocolID).Return(nil, nil) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(ops, nil, nil) mth.On("PersistTransaction", mock.Anything, approval.TX.ID, core.TransactionTypeTokenApproval, "0xffffeeee").Return(true, nil) - mdi.On("GetTokenApprovalByID", em.ctx, localID).Return(&core.TokenApproval{}, nil) + mdi.On("GetTokenApprovalByID", em.ctx, "ns1", localID).Return(&core.TokenApproval{}, nil) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", (*fftypes.UUID)(nil), approval.Event.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", em.ctx, mock.MatchedBy(func(e *core.BlockchainEvent) bool { return e.Namespace == pool.Namespace && e.Name == approval.Event.Name diff --git a/internal/orchestrator/txn_status.go b/internal/orchestrator/txn_status.go index 82c9e869c..f26068196 100644 --- a/internal/orchestrator/txn_status.go +++ b/internal/orchestrator/txn_status.go @@ -172,7 +172,7 @@ func (or *orchestrator) GetTransactionStatus(ctx context.Context, id string) (*c updateStatus(result, core.OpStatusPending) } f := database.TokenApprovalQueryFactory.NewFilter(ctx) - switch approvals, _, err := or.database().GetTokenApprovals(ctx, f.Eq("tx.id", id)); { + switch approvals, _, err := or.database().GetTokenApprovals(ctx, or.namespace, f.Eq("tx.id", id)); { case err != nil: return nil, err case len(approvals) == 0: diff --git a/internal/orchestrator/txn_status_test.go b/internal/orchestrator/txn_status_test.go index 27603b65f..07fccaca8 100644 --- a/internal/orchestrator/txn_status_test.go +++ b/internal/orchestrator/txn_status_test.go @@ -529,7 +529,7 @@ func TestGetTransactionStatusTokenApprovalSuccess(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenApprovals", mock.Anything, mock.Anything).Return(approvals, nil, nil) + or.mdi.On("GetTokenApprovals", mock.Anything, "ns", mock.Anything).Return(approvals, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -710,7 +710,7 @@ func TestGetTransactionStatusTokenApprovalPending(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(ops, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(events, nil, nil) - or.mdi.On("GetTokenApprovals", mock.Anything, mock.Anything).Return(approvals, nil, nil) + or.mdi.On("GetTokenApprovals", mock.Anything, "ns", mock.Anything).Return(approvals, nil, nil) status, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.NoError(t, err) @@ -914,7 +914,7 @@ func TestGetTransactionStatusApprovalError(t *testing.T) { or.mdi.On("GetTransactionByID", mock.Anything, "ns", txID).Return(tx, nil) or.mdi.On("GetOperations", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) or.mdi.On("GetBlockchainEvents", mock.Anything, "ns", mock.Anything).Return(nil, nil, nil) - or.mdi.On("GetTokenApprovals", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + or.mdi.On("GetTokenApprovals", mock.Anything, "ns", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := or.GetTransactionStatus(context.Background(), txID.String()) assert.EqualError(t, err, "pop") diff --git a/internal/syncasync/sync_async_bridge.go b/internal/syncasync/sync_async_bridge.go index bf2bd370e..8c5b9ecd6 100644 --- a/internal/syncasync/sync_async_bridge.go +++ b/internal/syncasync/sync_async_bridge.go @@ -226,7 +226,7 @@ func (sa *syncAsyncBridge) getTransferFromEvent(event *core.EventDelivery) (tran } func (sa *syncAsyncBridge) getApprovalFromEvent(event *core.EventDelivery) (approval *core.TokenApproval, err error) { - if approval, err = sa.database.GetTokenApprovalByID(sa.ctx, event.Reference); err != nil { + if approval, err = sa.database.GetTokenApprovalByID(sa.ctx, sa.namespace, event.Reference); err != nil { return nil, err } diff --git a/internal/syncasync/sync_async_bridge_test.go b/internal/syncasync/sync_async_bridge_test.go index 10dd8b8d5..cc6dc1954 100644 --- a/internal/syncasync/sync_async_bridge_test.go +++ b/internal/syncasync/sync_async_bridge_test.go @@ -464,7 +464,7 @@ func TestEventCallbackTokenApprovalLookupFail(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenApprovalByID", sa.ctx, mock.Anything).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenApprovalByID", sa.ctx, "ns1", mock.Anything).Return(nil, fmt.Errorf("pop")) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -628,7 +628,7 @@ func TestEventCallbackTokenApprovalNotFound(t *testing.T) { } mdi := sa.database.(*databasemocks.Plugin) - mdi.On("GetTokenApprovalByID", sa.ctx, mock.Anything).Return(nil, nil) + mdi.On("GetTokenApprovalByID", sa.ctx, "ns1", mock.Anything).Return(nil, nil) err := sa.eventCallback(&core.EventDelivery{ EnrichedEvent: core.EnrichedEvent{ @@ -916,7 +916,7 @@ func TestAwaitTokenApprovalConfirmation(t *testing.T) { mse.On("AddSystemEventListener", "ns1", mock.Anything).Return(nil) mdi := sa.database.(*databasemocks.Plugin) - gmid := mdi.On("GetTokenApprovalByID", sa.ctx, mock.Anything) + gmid := mdi.On("GetTokenApprovalByID", sa.ctx, "ns1", mock.Anything) gmid.RunFn = func(a mock.Arguments) { approval := &core.TokenApproval{ LocalID: requestID, diff --git a/internal/txcommon/event_enrich.go b/internal/txcommon/event_enrich.go index f10d8cdfa..e5267d9da 100644 --- a/internal/txcommon/event_enrich.go +++ b/internal/txcommon/event_enrich.go @@ -83,7 +83,7 @@ func (t *transactionHelper) EnrichEvent(ctx context.Context, event *core.Event) } e.TokenPool = tokenPool case core.EventTypeApprovalConfirmed: - approval, err := t.database.GetTokenApprovalByID(ctx, event.Reference) + approval, err := t.database.GetTokenApprovalByID(ctx, t.namespace, event.Reference) if err != nil { return nil, err } diff --git a/internal/txcommon/event_enrich_test.go b/internal/txcommon/event_enrich_test.go index 7e62d9f34..764e3bec8 100644 --- a/internal/txcommon/event_enrich_test.go +++ b/internal/txcommon/event_enrich_test.go @@ -509,7 +509,7 @@ func TestEnrichTokenApprovalConfirmed(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenApprovalByID", mock.Anything, ref1).Return(&core.TokenApproval{ + mdi.On("GetTokenApprovalByID", mock.Anything, "ns1", ref1).Return(&core.TokenApproval{ LocalID: ref1, }, nil) @@ -561,7 +561,7 @@ func TestEnrichTokenApprovalConfirmedFail(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetTokenApprovalByID", mock.Anything, ref1).Return(nil, fmt.Errorf("pop")) + mdi.On("GetTokenApprovalByID", mock.Anything, "ns1", ref1).Return(nil, fmt.Errorf("pop")) event := &core.Event{ ID: ev1, diff --git a/mocks/assetmocks/manager.go b/mocks/assetmocks/manager.go index ec83b65da..2adcc28d2 100644 --- a/mocks/assetmocks/manager.go +++ b/mocks/assetmocks/manager.go @@ -144,13 +144,13 @@ func (_m *Manager) GetTokenAccounts(ctx context.Context, ns string, filter datab return r0, r1, r2 } -// GetTokenApprovals provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetTokenApprovals(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetTokenApprovals provides a mock function with given fields: ctx, filter +func (_m *Manager) GetTokenApprovals(ctx context.Context, filter database.AndFilter) ([]*core.TokenApproval, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.TokenApproval - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenApproval); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.TokenApproval); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenApproval) @@ -158,8 +158,8 @@ func (_m *Manager) GetTokenApprovals(ctx context.Context, ns string, filter data } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -167,8 +167,8 @@ func (_m *Manager) GetTokenApprovals(ctx context.Context, ns string, filter data } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index 5618f44af..d55cdc989 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -1708,13 +1708,13 @@ func (_m *Plugin) GetTokenAccounts(ctx context.Context, filter database.Filter) return r0, r1, r2 } -// GetTokenApprovalByID provides a mock function with given fields: ctx, localID -func (_m *Plugin) GetTokenApprovalByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenApproval, error) { - ret := _m.Called(ctx, localID) +// GetTokenApprovalByID provides a mock function with given fields: ctx, namespace, localID +func (_m *Plugin) GetTokenApprovalByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenApproval, error) { + ret := _m.Called(ctx, namespace, localID) var r0 *core.TokenApproval - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.TokenApproval); ok { - r0 = rf(ctx, localID) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.TokenApproval); ok { + r0 = rf(ctx, namespace, localID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenApproval) @@ -1722,8 +1722,8 @@ func (_m *Plugin) GetTokenApprovalByID(ctx context.Context, localID *fftypes.UUI } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, localID) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, localID) } else { r1 = ret.Error(1) } @@ -1731,13 +1731,13 @@ func (_m *Plugin) GetTokenApprovalByID(ctx context.Context, localID *fftypes.UUI return r0, r1 } -// GetTokenApprovalByProtocolID provides a mock function with given fields: ctx, connector, protocolID -func (_m *Plugin) GetTokenApprovalByProtocolID(ctx context.Context, connector string, protocolID string) (*core.TokenApproval, error) { - ret := _m.Called(ctx, connector, protocolID) +// GetTokenApprovalByProtocolID provides a mock function with given fields: ctx, namespace, connector, protocolID +func (_m *Plugin) GetTokenApprovalByProtocolID(ctx context.Context, namespace string, connector string, protocolID string) (*core.TokenApproval, error) { + ret := _m.Called(ctx, namespace, connector, protocolID) var r0 *core.TokenApproval - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.TokenApproval); ok { - r0 = rf(ctx, connector, protocolID) + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.TokenApproval); ok { + r0 = rf(ctx, namespace, connector, protocolID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenApproval) @@ -1745,8 +1745,8 @@ func (_m *Plugin) GetTokenApprovalByProtocolID(ctx context.Context, connector st } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, connector, protocolID) + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, namespace, connector, protocolID) } else { r1 = ret.Error(1) } @@ -1754,13 +1754,13 @@ func (_m *Plugin) GetTokenApprovalByProtocolID(ctx context.Context, connector st return r0, r1 } -// GetTokenApprovals provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetTokenApprovals(ctx context.Context, filter database.Filter) ([]*core.TokenApproval, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetTokenApprovals provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetTokenApprovals(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenApproval, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.TokenApproval - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.TokenApproval); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenApproval); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenApproval) @@ -1768,8 +1768,8 @@ func (_m *Plugin) GetTokenApprovals(ctx context.Context, filter database.Filter) } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1777,8 +1777,8 @@ func (_m *Plugin) GetTokenApprovals(ctx context.Context, filter database.Filter) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 510e9b11e..f652cc825 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -412,13 +412,13 @@ type iTokenApprovalCollection interface { UpdateTokenApprovals(ctx context.Context, filter Filter, update Update) (err error) // GetTokenApprovalByID - Get a token approval by ID - GetTokenApprovalByID(ctx context.Context, localID *fftypes.UUID) (*core.TokenApproval, error) + GetTokenApprovalByID(ctx context.Context, namespace string, localID *fftypes.UUID) (*core.TokenApproval, error) // GetTokenTransferByProtocolID - Get a token approval by protocol ID - GetTokenApprovalByProtocolID(ctx context.Context, connector, protocolID string) (*core.TokenApproval, error) + GetTokenApprovalByProtocolID(ctx context.Context, namespace, connector, protocolID string) (*core.TokenApproval, error) // GetTokenApprovals - Get token approvals - GetTokenApprovals(ctx context.Context, filter Filter) ([]*core.TokenApproval, *FilterResult, error) + GetTokenApprovals(ctx context.Context, namespace string, filter Filter) ([]*core.TokenApproval, *FilterResult, error) } type iFFICollection interface { From aa449c2dc79cd644af16b3ab2c533ff00adbec36 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 13:55:30 -0400 Subject: [PATCH 5/9] Add namespace to token balance database queries Signed-off-by: Andrew Richardson --- .../route_get_token_account_pools.go | 2 +- .../route_get_token_account_pools_test.go | 2 +- .../apiserver/route_get_token_accounts.go | 2 +- .../route_get_token_accounts_test.go | 2 +- .../apiserver/route_get_token_balances.go | 2 +- .../route_get_token_balances_test.go | 2 +- internal/assets/manager.go | 22 +++--- internal/assets/manager_test.go | 12 ++-- .../database/sqlcommon/tokenbalance_sql.go | 19 +++--- .../sqlcommon/tokenbalance_sql_test.go | 36 +++++----- mocks/assetmocks/manager.go | 54 +++++++-------- mocks/databasemocks/plugin.go | 68 +++++++++---------- pkg/database/plugin.go | 8 +-- 13 files changed, 114 insertions(+), 117 deletions(-) diff --git a/internal/apiserver/route_get_token_account_pools.go b/internal/apiserver/route_get_token_account_pools.go index 0f10ea268..bbf9436c5 100644 --- a/internal/apiserver/route_get_token_account_pools.go +++ b/internal/apiserver/route_get_token_account_pools.go @@ -40,7 +40,7 @@ var getTokenAccountPools = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.TokenAccountPoolQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Assets().GetTokenAccountPools(cr.ctx, extractNamespace(r.PP), r.PP["key"], cr.filter)) + return filterResult(cr.or.Assets().GetTokenAccountPools(cr.ctx, r.PP["key"], cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_token_account_pools_test.go b/internal/apiserver/route_get_token_account_pools_test.go index b1d2f7a36..c5b961eae 100644 --- a/internal/apiserver/route_get_token_account_pools_test.go +++ b/internal/apiserver/route_get_token_account_pools_test.go @@ -34,7 +34,7 @@ func TestGetTokenAccountPools(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenAccountPools", mock.Anything, "ns1", "0x1", mock.Anything). + mam.On("GetTokenAccountPools", mock.Anything, "0x1", mock.Anything). Return([]*core.TokenAccountPool{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_token_accounts.go b/internal/apiserver/route_get_token_accounts.go index 07aa31583..66c32c7a5 100644 --- a/internal/apiserver/route_get_token_accounts.go +++ b/internal/apiserver/route_get_token_accounts.go @@ -38,7 +38,7 @@ var getTokenAccounts = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.TokenAccountQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Assets().GetTokenAccounts(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Assets().GetTokenAccounts(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_token_accounts_test.go b/internal/apiserver/route_get_token_accounts_test.go index 3ddc85dec..f7991db90 100644 --- a/internal/apiserver/route_get_token_accounts_test.go +++ b/internal/apiserver/route_get_token_accounts_test.go @@ -34,7 +34,7 @@ func TestGetTokenAccounts(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenAccounts", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenAccounts", mock.Anything, mock.Anything). Return([]*core.TokenAccount{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_token_balances.go b/internal/apiserver/route_get_token_balances.go index 7122cd64f..ef32b8d80 100644 --- a/internal/apiserver/route_get_token_balances.go +++ b/internal/apiserver/route_get_token_balances.go @@ -38,7 +38,7 @@ var getTokenBalances = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.TokenBalanceQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Assets().GetTokenBalances(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Assets().GetTokenBalances(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_token_balances_test.go b/internal/apiserver/route_get_token_balances_test.go index 0cc6416d5..a8c878723 100644 --- a/internal/apiserver/route_get_token_balances_test.go +++ b/internal/apiserver/route_get_token_balances_test.go @@ -34,7 +34,7 @@ func TestGetTokenBalances(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mam.On("GetTokenBalances", mock.Anything, "ns1", mock.Anything). + mam.On("GetTokenBalances", mock.Anything, mock.Anything). Return([]*core.TokenBalance{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/assets/manager.go b/internal/assets/manager.go index db1249cfa..f6e07ff43 100644 --- a/internal/assets/manager.go +++ b/internal/assets/manager.go @@ -46,9 +46,9 @@ type Manager interface { GetTokenPool(ctx context.Context, connector, poolName string) (*core.TokenPool, error) GetTokenPoolByNameOrID(ctx context.Context, poolNameOrID string) (*core.TokenPool, error) - GetTokenBalances(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) - GetTokenAccounts(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) - GetTokenAccountPools(ctx context.Context, ns, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) + GetTokenBalances(ctx context.Context, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) + GetTokenAccounts(ctx context.Context, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) + GetTokenAccountPools(ctx context.Context, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) GetTokenTransfers(ctx context.Context, filter database.AndFilter) ([]*core.TokenTransfer, *database.FilterResult, error) GetTokenTransferByID(ctx context.Context, id string) (*core.TokenTransfer, error) @@ -124,20 +124,16 @@ func (am *assetManager) selectTokenPlugin(ctx context.Context, name string) (tok return nil, i18n.NewError(ctx, coremsgs.MsgUnknownTokensPlugin, name) } -func (am *assetManager) scopeNS(ns string, filter database.AndFilter) database.AndFilter { - return filter.Condition(filter.Builder().Eq("namespace", ns)) +func (am *assetManager) GetTokenBalances(ctx context.Context, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) { + return am.database.GetTokenBalances(ctx, am.namespace, filter) } -func (am *assetManager) GetTokenBalances(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) { - return am.database.GetTokenBalances(ctx, am.scopeNS(ns, filter)) +func (am *assetManager) GetTokenAccounts(ctx context.Context, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) { + return am.database.GetTokenAccounts(ctx, am.namespace, filter) } -func (am *assetManager) GetTokenAccounts(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) { - return am.database.GetTokenAccounts(ctx, am.scopeNS(ns, filter)) -} - -func (am *assetManager) GetTokenAccountPools(ctx context.Context, ns, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) { - return am.database.GetTokenAccountPools(ctx, key, am.scopeNS(ns, filter)) +func (am *assetManager) GetTokenAccountPools(ctx context.Context, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) { + return am.database.GetTokenAccountPools(ctx, am.namespace, key, filter) } func (am *assetManager) GetTokenConnectors(ctx context.Context) []*core.TokenConnector { diff --git a/internal/assets/manager_test.go b/internal/assets/manager_test.go index df0d12db7..9f2ae349e 100644 --- a/internal/assets/manager_test.go +++ b/internal/assets/manager_test.go @@ -92,8 +92,8 @@ func TestGetTokenBalances(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenBalanceQueryFactory.NewFilter(context.Background()) f := fb.And() - mdi.On("GetTokenBalances", context.Background(), f).Return([]*core.TokenBalance{}, nil, nil) - _, _, err := am.GetTokenBalances(context.Background(), "ns1", f) + mdi.On("GetTokenBalances", context.Background(), "ns1", f).Return([]*core.TokenBalance{}, nil, nil) + _, _, err := am.GetTokenBalances(context.Background(), f) assert.NoError(t, err) } @@ -104,8 +104,8 @@ func TestGetTokenAccounts(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenBalanceQueryFactory.NewFilter(context.Background()) f := fb.And() - mdi.On("GetTokenAccounts", context.Background(), f).Return([]*core.TokenAccount{}, nil, nil) - _, _, err := am.GetTokenAccounts(context.Background(), "ns1", f) + mdi.On("GetTokenAccounts", context.Background(), "ns1", f).Return([]*core.TokenAccount{}, nil, nil) + _, _, err := am.GetTokenAccounts(context.Background(), f) assert.NoError(t, err) } @@ -116,8 +116,8 @@ func TestGetTokenAccountPools(t *testing.T) { mdi := am.database.(*databasemocks.Plugin) fb := database.TokenBalanceQueryFactory.NewFilter(context.Background()) f := fb.And() - mdi.On("GetTokenAccountPools", context.Background(), "0x1", f).Return([]*core.TokenAccountPool{}, nil, nil) - _, _, err := am.GetTokenAccountPools(context.Background(), "ns1", "0x1", f) + mdi.On("GetTokenAccountPools", context.Background(), "ns1", "0x1", f).Return([]*core.TokenAccountPool{}, nil, nil) + _, _, err := am.GetTokenAccountPools(context.Background(), "0x1", f) assert.NoError(t, err) } diff --git a/internal/database/sqlcommon/tokenbalance_sql.go b/internal/database/sqlcommon/tokenbalance_sql.go index 8dcb0dfb3..e12155e6a 100644 --- a/internal/database/sqlcommon/tokenbalance_sql.go +++ b/internal/database/sqlcommon/tokenbalance_sql.go @@ -49,7 +49,7 @@ var ( ) func (s *SQLCommon) addTokenBalance(ctx context.Context, tx *txWrapper, transfer *core.TokenTransfer, key string, negate bool) error { - account, err := s.GetTokenBalance(ctx, transfer.Pool, transfer.TokenIndex, key) + account, err := s.GetTokenBalance(ctx, transfer.Namespace, transfer.Pool, transfer.TokenIndex, key) if err != nil { return err } @@ -167,17 +167,19 @@ func (s *SQLCommon) getTokenBalancePred(ctx context.Context, desc string, pred i return account, nil } -func (s *SQLCommon) GetTokenBalance(ctx context.Context, poolID *fftypes.UUID, tokenIndex, key string) (message *core.TokenBalance, err error) { +func (s *SQLCommon) GetTokenBalance(ctx context.Context, namespace string, poolID *fftypes.UUID, tokenIndex, key string) (message *core.TokenBalance, err error) { desc := core.TokenBalanceIdentifier(poolID, tokenIndex, key) return s.getTokenBalancePred(ctx, desc, sq.And{ + sq.Eq{"namespace": namespace}, sq.Eq{"pool_id": poolID}, sq.Eq{"token_index": tokenIndex}, sq.Eq{"key": key}, }) } -func (s *SQLCommon) GetTokenBalances(ctx context.Context, filter database.Filter) ([]*core.TokenBalance, *database.FilterResult, error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenBalanceColumns...).From(tokenbalanceTable), filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}) +func (s *SQLCommon) GetTokenBalances(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenBalance, *database.FilterResult, error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(tokenBalanceColumns...).From(tokenbalanceTable), + filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -200,10 +202,10 @@ func (s *SQLCommon) GetTokenBalances(ctx context.Context, filter database.Filter return accounts, s.queryRes(ctx, tokenbalanceTable, tx, fop, fi), err } -func (s *SQLCommon) GetTokenAccounts(ctx context.Context, filter database.Filter) ([]*core.TokenAccount, *database.FilterResult, error) { +func (s *SQLCommon) GetTokenAccounts(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenAccount, *database.FilterResult, error) { query, fop, fi, err := s.filterSelect(ctx, "", sq.Select("key", "MAX(updated) AS updated", "MAX(seq) AS seq").From(tokenbalanceTable).GroupBy("key"), - filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}) + filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -228,11 +230,10 @@ func (s *SQLCommon) GetTokenAccounts(ctx context.Context, filter database.Filter return accounts, s.queryRes(ctx, tokenbalanceTable, tx, fop, fi), err } -func (s *SQLCommon) GetTokenAccountPools(ctx context.Context, key string, filter database.Filter) ([]*core.TokenAccountPool, *database.FilterResult, error) { +func (s *SQLCommon) GetTokenAccountPools(ctx context.Context, namespace, key string, filter database.Filter) ([]*core.TokenAccountPool, *database.FilterResult, error) { query, fop, fi, err := s.filterSelect(ctx, "", sq.Select("pool_id", "MAX(updated) AS updated", "MAX(seq) AS seq").From(tokenbalanceTable).GroupBy("pool_id"), - filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}, - sq.Eq{"key": key}) + filter, tokenBalanceFilterFieldMap, []interface{}{"seq"}, sq.Eq{"key": key, "namespace": namespace}) if err != nil { return nil, nil, err } diff --git a/internal/database/sqlcommon/tokenbalance_sql_test.go b/internal/database/sqlcommon/tokenbalance_sql_test.go index bf78c36aa..f0e7f0d02 100644 --- a/internal/database/sqlcommon/tokenbalance_sql_test.go +++ b/internal/database/sqlcommon/tokenbalance_sql_test.go @@ -61,7 +61,7 @@ func TestTokenBalanceE2EWithDB(t *testing.T) { assert.NoError(t, err) // Query back the token balance (by pool ID and identity) - balanceRead, err := s.GetTokenBalance(ctx, transfer.Pool, "1", "0x0") + balanceRead, err := s.GetTokenBalance(ctx, "ns1", transfer.Pool, "1", "0x0") assert.NoError(t, err) assert.NotNil(t, balanceRead) assert.Greater(t, balanceRead.Updated.UnixNano(), int64(0)) @@ -76,7 +76,7 @@ func TestTokenBalanceE2EWithDB(t *testing.T) { fb.Eq("tokenindex", balance.TokenIndex), fb.Eq("key", balance.Key), ) - balances, res, err := s.GetTokenBalances(ctx, filter.Count(true)) + balances, res, err := s.GetTokenBalances(ctx, "ns1", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(balances)) assert.Equal(t, int64(1), *res.TotalCount) @@ -93,7 +93,7 @@ func TestTokenBalanceE2EWithDB(t *testing.T) { assert.NoError(t, err) // Query back the token balance (by pool ID and identity) - balanceRead, err = s.GetTokenBalance(ctx, transfer.Pool, "1", "0x0") + balanceRead, err = s.GetTokenBalance(ctx, "ns1", transfer.Pool, "1", "0x0") assert.NoError(t, err) assert.NotNil(t, balanceRead) assert.Greater(t, balanceRead.Updated.UnixNano(), int64(0)) @@ -104,7 +104,7 @@ func TestTokenBalanceE2EWithDB(t *testing.T) { assert.Equal(t, string(balanceJson), string(balanceReadJson)) // Query back the other token balance (by pool ID and identity) - balanceRead, err = s.GetTokenBalance(ctx, transfer.Pool, "1", "0x1") + balanceRead, err = s.GetTokenBalance(ctx, "ns1", transfer.Pool, "1", "0x1") assert.NoError(t, err) assert.NotNil(t, balanceRead) assert.Greater(t, balanceRead.Updated.UnixNano(), int64(0)) @@ -116,18 +116,18 @@ func TestTokenBalanceE2EWithDB(t *testing.T) { assert.Equal(t, string(balanceJson), string(balanceReadJson)) // Query the list of unique accounts - accounts, _, err := s.GetTokenAccounts(ctx, fb.And()) + accounts, _, err := s.GetTokenAccounts(ctx, "ns1", fb.And()) assert.NoError(t, err) assert.Equal(t, 2, len(accounts)) assert.Equal(t, "0x1", accounts[0].Key) assert.Equal(t, "0x0", accounts[1].Key) // Query the pools for each account - pools, _, err := s.GetTokenAccountPools(ctx, "0x0", fb.And()) + pools, _, err := s.GetTokenAccountPools(ctx, "ns1", "0x0", fb.And()) assert.NoError(t, err) assert.Equal(t, 1, len(pools)) assert.Equal(t, *transfer.Pool, *pools[0].Pool) - pools, _, err = s.GetTokenAccountPools(ctx, "0x1", fb.And()) + pools, _, err = s.GetTokenAccountPools(ctx, "ns1", "0x1", fb.And()) assert.NoError(t, err) assert.Equal(t, 1, len(pools)) assert.Equal(t, *transfer.Pool, *pools[0].Pool) @@ -197,7 +197,7 @@ func TestUpdateTokenBalancesFailCommit(t *testing.T) { func TestGetTokenBalanceNotFound(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"})) - msg, err := s.GetTokenBalance(context.Background(), fftypes.NewUUID(), "1", "0x0") + msg, err := s.GetTokenBalance(context.Background(), "ns1", fftypes.NewUUID(), "1", "0x0") assert.NoError(t, err) assert.Nil(t, msg) assert.NoError(t, mock.ExpectationsWereMet()) @@ -206,7 +206,7 @@ func TestGetTokenBalanceNotFound(t *testing.T) { func TestGetTokenBalanceScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("only one")) - _, err := s.GetTokenBalance(context.Background(), fftypes.NewUUID(), "1", "0x0") + _, err := s.GetTokenBalance(context.Background(), "ns1", fftypes.NewUUID(), "1", "0x0") assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -215,7 +215,7 @@ func TestGetTokenBalancesQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).Eq("pool", "") - _, _, err := s.GetTokenBalances(context.Background(), f) + _, _, err := s.GetTokenBalances(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -223,7 +223,7 @@ func TestGetTokenBalancesQueryFail(t *testing.T) { func TestGetTokenBalancesBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).Eq("pool", map[bool]bool{true: false}) - _, _, err := s.GetTokenBalances(context.Background(), f) + _, _, err := s.GetTokenBalances(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*pool", err) } @@ -231,7 +231,7 @@ func TestGetTokenBalancesScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"pool"}).AddRow("only one")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).Eq("pool", "") - _, _, err := s.GetTokenBalances(context.Background(), f) + _, _, err := s.GetTokenBalances(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -240,7 +240,7 @@ func TestGetTokenAccountsQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).And() - _, _, err := s.GetTokenAccounts(context.Background(), f) + _, _, err := s.GetTokenAccounts(context.Background(), "ns1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -248,7 +248,7 @@ func TestGetTokenAccountsQueryFail(t *testing.T) { func TestGetTokenAccountsBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).Eq("pool", map[bool]bool{true: false}) - _, _, err := s.GetTokenAccounts(context.Background(), f) + _, _, err := s.GetTokenAccounts(context.Background(), "ns1", f) assert.Regexp(t, "FF00143.*pool", err) } @@ -256,7 +256,7 @@ func TestGetTokenAccountsScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"key", "bad"}).AddRow("too many", "columns")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).And() - _, _, err := s.GetTokenAccounts(context.Background(), f) + _, _, err := s.GetTokenAccounts(context.Background(), "ns1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -265,7 +265,7 @@ func TestGetTokenAccountPoolsQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).And() - _, _, err := s.GetTokenAccountPools(context.Background(), "0x1", f) + _, _, err := s.GetTokenAccountPools(context.Background(), "ns1", "0x1", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -273,7 +273,7 @@ func TestGetTokenAccountPoolsQueryFail(t *testing.T) { func TestGetTokenAccountPoolsBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).Eq("pool", map[bool]bool{true: false}) - _, _, err := s.GetTokenAccountPools(context.Background(), "0x1", f) + _, _, err := s.GetTokenAccountPools(context.Background(), "ns1", "0x1", f) assert.Regexp(t, "FF00143.*pool", err) } @@ -281,7 +281,7 @@ func TestGetTokenAccountPoolsScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"key", "bad"}).AddRow("too many", "columns")) f := database.TokenBalanceQueryFactory.NewFilter(context.Background()).And() - _, _, err := s.GetTokenAccountPools(context.Background(), "0x1", f) + _, _, err := s.GetTokenAccountPools(context.Background(), "ns1", "0x1", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/mocks/assetmocks/manager.go b/mocks/assetmocks/manager.go index 2adcc28d2..79dbefcb0 100644 --- a/mocks/assetmocks/manager.go +++ b/mocks/assetmocks/manager.go @@ -80,13 +80,13 @@ func (_m *Manager) CreateTokenPool(ctx context.Context, pool *core.TokenPool, wa return r0, r1 } -// GetTokenAccountPools provides a mock function with given fields: ctx, ns, key, filter -func (_m *Manager) GetTokenAccountPools(ctx context.Context, ns string, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, key, filter) +// GetTokenAccountPools provides a mock function with given fields: ctx, key, filter +func (_m *Manager) GetTokenAccountPools(ctx context.Context, key string, filter database.AndFilter) ([]*core.TokenAccountPool, *database.FilterResult, error) { + ret := _m.Called(ctx, key, filter) var r0 []*core.TokenAccountPool - if rf, ok := ret.Get(0).(func(context.Context, string, string, database.AndFilter) []*core.TokenAccountPool); ok { - r0 = rf(ctx, ns, key, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenAccountPool); ok { + r0 = rf(ctx, key, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenAccountPool) @@ -94,8 +94,8 @@ func (_m *Manager) GetTokenAccountPools(ctx context.Context, ns string, key stri } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, key, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, key, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -103,8 +103,8 @@ func (_m *Manager) GetTokenAccountPools(ctx context.Context, ns string, key stri } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, key, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { + r2 = rf(ctx, key, filter) } else { r2 = ret.Error(2) } @@ -112,13 +112,13 @@ func (_m *Manager) GetTokenAccountPools(ctx context.Context, ns string, key stri return r0, r1, r2 } -// GetTokenAccounts provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetTokenAccounts(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetTokenAccounts provides a mock function with given fields: ctx, filter +func (_m *Manager) GetTokenAccounts(ctx context.Context, filter database.AndFilter) ([]*core.TokenAccount, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.TokenAccount - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenAccount); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.TokenAccount); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenAccount) @@ -126,8 +126,8 @@ func (_m *Manager) GetTokenAccounts(ctx context.Context, ns string, filter datab } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -135,8 +135,8 @@ func (_m *Manager) GetTokenAccounts(ctx context.Context, ns string, filter datab } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } @@ -176,13 +176,13 @@ func (_m *Manager) GetTokenApprovals(ctx context.Context, filter database.AndFil return r0, r1, r2 } -// GetTokenBalances provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetTokenBalances(ctx context.Context, ns string, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetTokenBalances provides a mock function with given fields: ctx, filter +func (_m *Manager) GetTokenBalances(ctx context.Context, filter database.AndFilter) ([]*core.TokenBalance, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.TokenBalance - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.TokenBalance); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.TokenBalance); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenBalance) @@ -190,8 +190,8 @@ func (_m *Manager) GetTokenBalances(ctx context.Context, ns string, filter datab } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -199,8 +199,8 @@ func (_m *Manager) GetTokenBalances(ctx context.Context, ns string, filter datab } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index d55cdc989..078accf4d 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -1644,13 +1644,13 @@ func (_m *Plugin) GetSubscriptions(ctx context.Context, namespace string, filter return r0, r1, r2 } -// GetTokenAccountPools provides a mock function with given fields: ctx, key, filter -func (_m *Plugin) GetTokenAccountPools(ctx context.Context, key string, filter database.Filter) ([]*core.TokenAccountPool, *database.FilterResult, error) { - ret := _m.Called(ctx, key, filter) +// GetTokenAccountPools provides a mock function with given fields: ctx, namespace, key, filter +func (_m *Plugin) GetTokenAccountPools(ctx context.Context, namespace string, key string, filter database.Filter) ([]*core.TokenAccountPool, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, key, filter) var r0 []*core.TokenAccountPool - if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenAccountPool); ok { - r0 = rf(ctx, key, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, string, database.Filter) []*core.TokenAccountPool); ok { + r0 = rf(ctx, namespace, key, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenAccountPool) @@ -1658,8 +1658,8 @@ func (_m *Plugin) GetTokenAccountPools(ctx context.Context, key string, filter d } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, key, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, key, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1667,8 +1667,8 @@ func (_m *Plugin) GetTokenAccountPools(ctx context.Context, key string, filter d } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { - r2 = rf(ctx, key, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, key, filter) } else { r2 = ret.Error(2) } @@ -1676,13 +1676,13 @@ func (_m *Plugin) GetTokenAccountPools(ctx context.Context, key string, filter d return r0, r1, r2 } -// GetTokenAccounts provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetTokenAccounts(ctx context.Context, filter database.Filter) ([]*core.TokenAccount, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetTokenAccounts provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetTokenAccounts(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenAccount, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.TokenAccount - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.TokenAccount); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenAccount); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenAccount) @@ -1690,8 +1690,8 @@ func (_m *Plugin) GetTokenAccounts(ctx context.Context, filter database.Filter) } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1699,8 +1699,8 @@ func (_m *Plugin) GetTokenAccounts(ctx context.Context, filter database.Filter) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } @@ -1786,13 +1786,13 @@ func (_m *Plugin) GetTokenApprovals(ctx context.Context, namespace string, filte return r0, r1, r2 } -// GetTokenBalance provides a mock function with given fields: ctx, poolID, tokenIndex, identity -func (_m *Plugin) GetTokenBalance(ctx context.Context, poolID *fftypes.UUID, tokenIndex string, identity string) (*core.TokenBalance, error) { - ret := _m.Called(ctx, poolID, tokenIndex, identity) +// GetTokenBalance provides a mock function with given fields: ctx, namespace, poolID, tokenIndex, identity +func (_m *Plugin) GetTokenBalance(ctx context.Context, namespace string, poolID *fftypes.UUID, tokenIndex string, identity string) (*core.TokenBalance, error) { + ret := _m.Called(ctx, namespace, poolID, tokenIndex, identity) var r0 *core.TokenBalance - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID, string, string) *core.TokenBalance); ok { - r0 = rf(ctx, poolID, tokenIndex, identity) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID, string, string) *core.TokenBalance); ok { + r0 = rf(ctx, namespace, poolID, tokenIndex, identity) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.TokenBalance) @@ -1800,8 +1800,8 @@ func (_m *Plugin) GetTokenBalance(ctx context.Context, poolID *fftypes.UUID, tok } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID, string, string) error); ok { - r1 = rf(ctx, poolID, tokenIndex, identity) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID, string, string) error); ok { + r1 = rf(ctx, namespace, poolID, tokenIndex, identity) } else { r1 = ret.Error(1) } @@ -1809,13 +1809,13 @@ func (_m *Plugin) GetTokenBalance(ctx context.Context, poolID *fftypes.UUID, tok return r0, r1 } -// GetTokenBalances provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetTokenBalances(ctx context.Context, filter database.Filter) ([]*core.TokenBalance, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetTokenBalances provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetTokenBalances(ctx context.Context, namespace string, filter database.Filter) ([]*core.TokenBalance, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.TokenBalance - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.TokenBalance); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.TokenBalance); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.TokenBalance) @@ -1823,8 +1823,8 @@ func (_m *Plugin) GetTokenBalances(ctx context.Context, filter database.Filter) } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -1832,8 +1832,8 @@ func (_m *Plugin) GetTokenBalances(ctx context.Context, filter database.Filter) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index f652cc825..2955435ee 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -378,16 +378,16 @@ type iTokenBalanceCollection interface { UpdateTokenBalances(ctx context.Context, transfer *core.TokenTransfer) error // GetTokenBalance - Get a token balance by pool and account identity - GetTokenBalance(ctx context.Context, poolID *fftypes.UUID, tokenIndex, identity string) (*core.TokenBalance, error) + GetTokenBalance(ctx context.Context, namespace string, poolID *fftypes.UUID, tokenIndex, identity string) (*core.TokenBalance, error) // GetTokenBalances - Get token balances - GetTokenBalances(ctx context.Context, filter Filter) ([]*core.TokenBalance, *FilterResult, error) + GetTokenBalances(ctx context.Context, namespace string, filter Filter) ([]*core.TokenBalance, *FilterResult, error) // GetTokenAccounts - Get token accounts (all distinct addresses that have a balance) - GetTokenAccounts(ctx context.Context, filter Filter) ([]*core.TokenAccount, *FilterResult, error) + GetTokenAccounts(ctx context.Context, namespace string, filter Filter) ([]*core.TokenAccount, *FilterResult, error) // GetTokenAccountPools - Get the list of pools referenced by a given account - GetTokenAccountPools(ctx context.Context, key string, filter Filter) ([]*core.TokenAccountPool, *FilterResult, error) + GetTokenAccountPools(ctx context.Context, namespace, key string, filter Filter) ([]*core.TokenAccountPool, *FilterResult, error) } type iTokenTransferCollection interface { From d598dee3647ac78eb63c48d6f85f60b510b137bf Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 14:11:00 -0400 Subject: [PATCH 6/9] Add namespace to FFI database queries Signed-off-by: Andrew Richardson --- ...ute_get_contract_interface_name_version.go | 4 +- ...et_contract_interface_name_version_test.go | 4 +- .../route_get_contract_interfaces.go | 2 +- .../route_get_contract_interfaces_test.go | 2 +- .../route_post_new_contract_interface.go | 2 +- .../route_post_new_contract_interface_test.go | 4 +- internal/contracts/manager.go | 47 ++++++------- internal/contracts/manager_test.go | 70 +++++++++---------- internal/database/sqlcommon/ffi_events_sql.go | 9 +-- .../database/sqlcommon/ffi_events_sql_test.go | 10 +-- .../database/sqlcommon/ffi_methods_sql.go | 7 +- .../sqlcommon/ffi_methods_sql_test.go | 10 +-- internal/database/sqlcommon/ffi_sql.go | 13 ++-- internal/database/sqlcommon/ffi_sql_test.go | 10 +-- internal/txcommon/event_enrich.go | 2 +- internal/txcommon/event_enrich_test.go | 4 +- mocks/contractmocks/manager.go | 60 ++++++++-------- mocks/databasemocks/plugin.go | 50 ++++++------- pkg/database/plugin.go | 23 +++++- 19 files changed, 176 insertions(+), 157 deletions(-) diff --git a/internal/apiserver/route_get_contract_interface_name_version.go b/internal/apiserver/route_get_contract_interface_name_version.go index b428929f1..26ad11d25 100644 --- a/internal/apiserver/route_get_contract_interface_name_version.go +++ b/internal/apiserver/route_get_contract_interface_name_version.go @@ -43,9 +43,9 @@ var getContractInterfaceNameVersion = &ffapi.Route{ Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { if strings.EqualFold(r.QP["fetchchildren"], "true") { - return cr.or.Contracts().GetFFIWithChildren(cr.ctx, extractNamespace(r.PP), r.PP["name"], r.PP["version"]) + return cr.or.Contracts().GetFFIWithChildren(cr.ctx, r.PP["name"], r.PP["version"]) } - return cr.or.Contracts().GetFFI(cr.ctx, extractNamespace(r.PP), r.PP["name"], r.PP["version"]) + return cr.or.Contracts().GetFFI(cr.ctx, r.PP["name"], r.PP["version"]) }, }, } diff --git a/internal/apiserver/route_get_contract_interface_name_version_test.go b/internal/apiserver/route_get_contract_interface_name_version_test.go index 0cb274355..fbd4de2af 100644 --- a/internal/apiserver/route_get_contract_interface_name_version_test.go +++ b/internal/apiserver/route_get_contract_interface_name_version_test.go @@ -39,7 +39,7 @@ func TestGetContractInterfaceNameVersion(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetFFI", mock.Anything, "ns1", "banana", "v1.0.0"). + mcm.On("GetFFI", mock.Anything, "banana", "v1.0.0"). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) @@ -57,7 +57,7 @@ func TestGetContractInterfaceNameVersionWithChildren(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetFFIWithChildren", mock.Anything, "ns1", "banana", "v1.0.0"). + mcm.On("GetFFIWithChildren", mock.Anything, "banana", "v1.0.0"). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_interfaces.go b/internal/apiserver/route_get_contract_interfaces.go index e47434644..f0af3da18 100644 --- a/internal/apiserver/route_get_contract_interfaces.go +++ b/internal/apiserver/route_get_contract_interfaces.go @@ -38,7 +38,7 @@ var getContractInterfaces = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.FFIQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Contracts().GetFFIs(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Contracts().GetFFIs(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_contract_interfaces_test.go b/internal/apiserver/route_get_contract_interfaces_test.go index 2228605d7..cb37c93cf 100644 --- a/internal/apiserver/route_get_contract_interfaces_test.go +++ b/internal/apiserver/route_get_contract_interfaces_test.go @@ -39,7 +39,7 @@ func TestGetContractInterfaces(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetFFIs", mock.Anything, "ns1", mock.Anything). + mcm.On("GetFFIs", mock.Anything, mock.Anything). Return([]*core.FFI{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_new_contract_interface.go b/internal/apiserver/route_post_new_contract_interface.go index b16643201..fb761eef4 100644 --- a/internal/apiserver/route_post_new_contract_interface.go +++ b/internal/apiserver/route_post_new_contract_interface.go @@ -41,7 +41,7 @@ var postNewContractInterface = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { waitConfirm := strings.EqualFold(r.QP["confirm"], "true") r.SuccessStatus = syncRetcode(waitConfirm) - return cr.or.Contracts().BroadcastFFI(cr.ctx, extractNamespace(r.PP), r.Input.(*core.FFI), waitConfirm) + return cr.or.Contracts().BroadcastFFI(cr.ctx, r.Input.(*core.FFI), waitConfirm) }, }, } diff --git a/internal/apiserver/route_post_new_contract_interface_test.go b/internal/apiserver/route_post_new_contract_interface_test.go index b982e3b76..1b3dfa48b 100644 --- a/internal/apiserver/route_post_new_contract_interface_test.go +++ b/internal/apiserver/route_post_new_contract_interface_test.go @@ -39,7 +39,7 @@ func TestPostNewContractInterface(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastFFI", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), false). + mcm.On("BroadcastFFI", mock.Anything, mock.AnythingOfType("*core.FFI"), false). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) @@ -57,7 +57,7 @@ func TestPostNewContractInterfaceSync(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastFFI", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), true). + mcm.On("BroadcastFFI", mock.Anything, mock.AnythingOfType("*core.FFI"), true). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/contracts/manager.go b/internal/contracts/manager.go index 83cca3641..1f18fd5d1 100644 --- a/internal/contracts/manager.go +++ b/internal/contracts/manager.go @@ -38,12 +38,12 @@ import ( type Manager interface { core.Named - BroadcastFFI(ctx context.Context, ns string, ffi *core.FFI, waitConfirm bool) (output *core.FFI, err error) - GetFFI(ctx context.Context, ns, name, version string) (*core.FFI, error) - GetFFIWithChildren(ctx context.Context, ns, name, version string) (*core.FFI, error) + BroadcastFFI(ctx context.Context, ffi *core.FFI, waitConfirm bool) (output *core.FFI, err error) + GetFFI(ctx context.Context, name, version string) (*core.FFI, error) + GetFFIWithChildren(ctx context.Context, name, version string) (*core.FFI, error) GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) GetFFIByIDWithChildren(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) - GetFFIs(ctx context.Context, ns string, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) + GetFFIs(ctx context.Context, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) InvokeContract(ctx context.Context, ns string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) InvokeContractAPI(ctx context.Context, ns, apiName, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) @@ -119,11 +119,11 @@ func (cm *contractManager) newFFISchemaCompiler() *jsonschema.Compiler { return c } -func (cm *contractManager) BroadcastFFI(ctx context.Context, ns string, ffi *core.FFI, waitConfirm bool) (output *core.FFI, err error) { +func (cm *contractManager) BroadcastFFI(ctx context.Context, ffi *core.FFI, waitConfirm bool) (output *core.FFI, err error) { ffi.ID = fftypes.NewUUID() - ffi.Namespace = ns + ffi.Namespace = cm.namespace - existing, err := cm.database.GetFFI(ctx, ffi.Namespace, ffi.Name, ffi.Version) + existing, err := cm.database.GetFFI(ctx, cm.namespace, ffi.Name, ffi.Version) if existing != nil && err == nil { return nil, i18n.NewError(ctx, coremsgs.MsgContractInterfaceExists, ffi.Namespace, ffi.Name, ffi.Version) } @@ -139,7 +139,7 @@ func (cm *contractManager) BroadcastFFI(ctx context.Context, ns string, ffi *cor } output = ffi - msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, ns, ffi, core.SystemTagDefineFFI, waitConfirm) + msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, cm.namespace, ffi, core.SystemTagDefineFFI, waitConfirm) if err != nil { return nil, err } @@ -151,12 +151,12 @@ func (cm *contractManager) scopeNS(ns string, filter database.AndFilter) databas return filter.Condition(filter.Builder().Eq("namespace", ns)) } -func (cm *contractManager) GetFFI(ctx context.Context, ns, name, version string) (*core.FFI, error) { - return cm.database.GetFFI(ctx, ns, name, version) +func (cm *contractManager) GetFFI(ctx context.Context, name, version string) (*core.FFI, error) { + return cm.database.GetFFI(ctx, cm.namespace, name, version) } -func (cm *contractManager) GetFFIWithChildren(ctx context.Context, ns, name, version string) (*core.FFI, error) { - ffi, err := cm.GetFFI(ctx, ns, name, version) +func (cm *contractManager) GetFFIWithChildren(ctx context.Context, name, version string) (*core.FFI, error) { + ffi, err := cm.GetFFI(ctx, name, version) if err == nil { err = cm.getFFIChildren(ctx, ffi) } @@ -164,18 +164,18 @@ func (cm *contractManager) GetFFIWithChildren(ctx context.Context, ns, name, ver } func (cm *contractManager) GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) { - return cm.database.GetFFIByID(ctx, id) + return cm.database.GetFFIByID(ctx, cm.namespace, id) } func (cm *contractManager) getFFIChildren(ctx context.Context, ffi *core.FFI) (err error) { mfb := database.FFIMethodQueryFactory.NewFilter(ctx) - ffi.Methods, _, err = cm.database.GetFFIMethods(ctx, mfb.Eq("interface", ffi.ID)) + ffi.Methods, _, err = cm.database.GetFFIMethods(ctx, cm.namespace, mfb.Eq("interface", ffi.ID)) if err != nil { return err } efb := database.FFIEventQueryFactory.NewFilter(ctx) - ffi.Events, _, err = cm.database.GetFFIEvents(ctx, efb.Eq("interface", ffi.ID)) + ffi.Events, _, err = cm.database.GetFFIEvents(ctx, cm.namespace, efb.Eq("interface", ffi.ID)) if err != nil { return err } @@ -188,7 +188,7 @@ func (cm *contractManager) getFFIChildren(ctx context.Context, ffi *core.FFI) (e func (cm *contractManager) GetFFIByIDWithChildren(ctx context.Context, id *fftypes.UUID) (ffi *core.FFI, err error) { err = cm.database.RunAsGroup(ctx, func(ctx context.Context) (err error) { - ffi, err = cm.database.GetFFIByID(ctx, id) + ffi, err = cm.database.GetFFIByID(ctx, cm.namespace, id) if err != nil || ffi == nil { return err } @@ -197,9 +197,8 @@ func (cm *contractManager) GetFFIByIDWithChildren(ctx context.Context, id *fftyp return ffi, err } -func (cm *contractManager) GetFFIs(ctx context.Context, ns string, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) { - filter = cm.scopeNS(ns, filter) - return cm.database.GetFFIs(ctx, ns, filter) +func (cm *contractManager) GetFFIs(ctx context.Context, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) { + return cm.database.GetFFIs(ctx, cm.namespace, filter) } func (cm *contractManager) writeInvokeTransaction(ctx context.Context, req *core.ContractCallRequest) (*core.Operation, error) { @@ -323,13 +322,13 @@ func (cm *contractManager) GetContractAPIs(ctx context.Context, httpServerURL, n return apis, fr, err } -func (cm *contractManager) resolveFFIReference(ctx context.Context, ns string, ref *core.FFIReference) error { +func (cm *contractManager) resolveFFIReference(ctx context.Context, ref *core.FFIReference) error { switch { case ref == nil: return i18n.NewError(ctx, coremsgs.MsgContractInterfaceNotFound, "") case ref.ID != nil: - ffi, err := cm.database.GetFFIByID(ctx, ref.ID) + ffi, err := cm.database.GetFFIByID(ctx, cm.namespace, ref.ID) if err != nil { return err } else if ffi == nil { @@ -338,7 +337,7 @@ func (cm *contractManager) resolveFFIReference(ctx context.Context, ns string, r return nil case ref.Name != "" && ref.Version != "": - ffi, err := cm.database.GetFFI(ctx, ns, ref.Name, ref.Version) + ffi, err := cm.database.GetFFI(ctx, cm.namespace, ref.Name, ref.Version) if err != nil { return err } else if ffi == nil { @@ -370,7 +369,7 @@ func (cm *contractManager) BroadcastContractAPI(ctx context.Context, httpServerU } } - if err := cm.resolveFFIReference(ctx, ns, api.Interface); err != nil { + if err := cm.resolveFFIReference(ctx, api.Interface); err != nil { return err } return nil @@ -486,7 +485,7 @@ func (cm *contractManager) validateInvokeContractRequest(ctx context.Context, re } func (cm *contractManager) resolveEvent(ctx context.Context, ns string, ffi *core.FFIReference, eventPath string) (*core.FFISerializedEvent, error) { - if err := cm.resolveFFIReference(ctx, ns, ffi); err != nil { + if err := cm.resolveFFIReference(ctx, ffi); err != nil { return nil, err } event, err := cm.database.GetFFIEvent(ctx, ns, ffi.ID, eventPath) diff --git a/internal/contracts/manager_test.go b/internal/contracts/manager_test.go index aa3e21e84..b59954255 100644 --- a/internal/contracts/manager_test.go +++ b/internal/contracts/manager_test.go @@ -137,7 +137,7 @@ func TestBroadcastFFI(t *testing.T) { }, }, } - _, err := cm.BroadcastFFI(context.Background(), "ns1", ffi, false) + _, err := cm.BroadcastFFI(context.Background(), ffi, false) assert.NoError(t, err) } @@ -170,7 +170,7 @@ func TestBroadcastFFIInvalid(t *testing.T) { }, }, } - _, err := cm.BroadcastFFI(context.Background(), "ns1", ffi, false) + _, err := cm.BroadcastFFI(context.Background(), ffi, false) assert.Regexp(t, "does not validate", err) } @@ -192,7 +192,7 @@ func TestBroadcastFFIExists(t *testing.T) { Version: "1.0.0", ID: fftypes.NewUUID(), } - _, err := cm.BroadcastFFI(context.Background(), "ns1", ffi, false) + _, err := cm.BroadcastFFI(context.Background(), ffi, false) assert.Regexp(t, "FF10302", err) } @@ -216,7 +216,7 @@ func TestBroadcastFFIFail(t *testing.T) { }, }, } - _, err := cm.BroadcastFFI(context.Background(), "ns1", ffi, false) + _, err := cm.BroadcastFFI(context.Background(), ffi, false) assert.Regexp(t, "pop", err) } @@ -640,7 +640,7 @@ func TestAddContractListenerByEventPath(t *testing.T) { mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), sub).Return(nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(event, nil) mdi.On("InsertContractListener", context.Background(), &sub.ContractListener).Return(nil) @@ -699,7 +699,7 @@ func TestAddContractListenerFFILookupFail(t *testing.T) { } mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(nil, fmt.Errorf("pop")) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(nil, fmt.Errorf("pop")) _, err := cm.AddContractListener(context.Background(), "ns1", sub) assert.EqualError(t, err, "pop") @@ -729,7 +729,7 @@ func TestAddContractListenerEventLookupFail(t *testing.T) { } mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(nil, fmt.Errorf("pop")) _, err := cm.AddContractListener(context.Background(), "ns1", sub) @@ -760,7 +760,7 @@ func TestAddContractListenerEventLookupNotFound(t *testing.T) { } mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(nil, nil) _, err := cm.AddContractListener(context.Background(), "ns1", sub) @@ -1061,7 +1061,7 @@ func TestAddContractAPIListener(t *testing.T) { mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(listener.Location, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(event, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) @@ -1121,7 +1121,7 @@ func TestGetFFI(t *testing.T) { cm := newTestContractManager() mdb := cm.database.(*databasemocks.Plugin) mdb.On("GetFFI", mock.Anything, "ns1", "ffi", "v1.0.0").Return(&core.FFI{}, nil) - _, err := cm.GetFFI(context.Background(), "ns1", "ffi", "v1.0.0") + _, err := cm.GetFFI(context.Background(), "ffi", "v1.0.0") assert.NoError(t, err) } @@ -1132,17 +1132,17 @@ func TestGetFFIWithChildren(t *testing.T) { cid := fftypes.NewUUID() mdb.On("GetFFI", mock.Anything, "ns1", "ffi", "v1.0.0").Return(&core.FFI{ID: cid}, nil) - mdb.On("GetFFIMethods", mock.Anything, mock.Anything).Return([]*core.FFIMethod{ + mdb.On("GetFFIMethods", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIMethod{ {ID: fftypes.NewUUID(), Name: "method1"}, }, nil, nil) - mdb.On("GetFFIEvents", mock.Anything, mock.Anything).Return([]*core.FFIEvent{ + mdb.On("GetFFIEvents", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIEvent{ {ID: fftypes.NewUUID(), FFIEventDefinition: core.FFIEventDefinition{Name: "event1"}}, }, nil, nil) mbi.On("GenerateEventSignature", mock.Anything, mock.MatchedBy(func(ev *core.FFIEventDefinition) bool { return ev.Name == "event1" })).Return("event1Sig") - _, err := cm.GetFFIWithChildren(context.Background(), "ns1", "ffi", "v1.0.0") + _, err := cm.GetFFIWithChildren(context.Background(), "ffi", "v1.0.0") assert.NoError(t, err) mdb.AssertExpectations(t) @@ -1153,7 +1153,7 @@ func TestGetFFIByID(t *testing.T) { cm := newTestContractManager() mdb := cm.database.(*databasemocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(&core.FFI{}, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(&core.FFI{}, nil) _, err := cm.GetFFIByID(context.Background(), cid) assert.NoError(t, err) } @@ -1164,13 +1164,13 @@ func TestGetFFIByIDWithChildren(t *testing.T) { mbi := cm.blockchain.(*blockchainmocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(&core.FFI{ + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(&core.FFI{ ID: cid, }, nil) - mdb.On("GetFFIMethods", mock.Anything, mock.Anything).Return([]*core.FFIMethod{ + mdb.On("GetFFIMethods", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIMethod{ {ID: fftypes.NewUUID(), Name: "method1"}, }, nil, nil) - mdb.On("GetFFIEvents", mock.Anything, mock.Anything).Return([]*core.FFIEvent{ + mdb.On("GetFFIEvents", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIEvent{ {ID: fftypes.NewUUID(), FFIEventDefinition: core.FFIEventDefinition{Name: "event1"}}, }, nil, nil) mbi.On("GenerateEventSignature", mock.Anything, mock.MatchedBy(func(ev *core.FFIEventDefinition) bool { @@ -1192,13 +1192,13 @@ func TestGetFFIByIDWithChildrenEventsFail(t *testing.T) { mdb := cm.database.(*databasemocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(&core.FFI{ + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(&core.FFI{ ID: cid, }, nil) - mdb.On("GetFFIMethods", mock.Anything, mock.Anything).Return([]*core.FFIMethod{ + mdb.On("GetFFIMethods", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIMethod{ {ID: fftypes.NewUUID(), Name: "method1"}, }, nil, nil) - mdb.On("GetFFIEvents", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdb.On("GetFFIEvents", mock.Anything, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := cm.GetFFIByIDWithChildren(context.Background(), cid) @@ -1211,10 +1211,10 @@ func TestGetFFIByIDWithChildrenMethodsFail(t *testing.T) { mdb := cm.database.(*databasemocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(&core.FFI{ + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(&core.FFI{ ID: cid, }, nil) - mdb.On("GetFFIMethods", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdb.On("GetFFIMethods", mock.Anything, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) _, err := cm.GetFFIByIDWithChildren(context.Background(), cid) @@ -1227,7 +1227,7 @@ func TestGetFFIByIDWithChildrenFFILookupFail(t *testing.T) { mdb := cm.database.(*databasemocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(nil, fmt.Errorf("pop")) + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(nil, fmt.Errorf("pop")) _, err := cm.GetFFIByIDWithChildren(context.Background(), cid) @@ -1240,7 +1240,7 @@ func TestGetFFIByIDWithChildrenFFINotFound(t *testing.T) { mdb := cm.database.(*databasemocks.Plugin) cid := fftypes.NewUUID() - mdb.On("GetFFIByID", mock.Anything, cid).Return(nil, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", cid).Return(nil, nil) ffi, err := cm.GetFFIByIDWithChildren(context.Background(), cid) @@ -1254,7 +1254,7 @@ func TestGetFFIs(t *testing.T) { mdb := cm.database.(*databasemocks.Plugin) filter := database.FFIQueryFactory.NewFilter(context.Background()).And() mdb.On("GetFFIs", mock.Anything, "ns1", filter).Return([]*core.FFI{}, &database.FilterResult{}, nil) - _, _, err := cm.GetFFIs(context.Background(), "ns1", filter) + _, _, err := cm.GetFFIs(context.Background(), filter) assert.NoError(t, err) } @@ -1649,7 +1649,7 @@ func TestGetContractAPIListeners(t *testing.T) { } mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(event, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) @@ -1703,7 +1703,7 @@ func TestGetContractAPIListenersEventNotFound(t *testing.T) { } mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) - mdi.On("GetFFIByID", context.Background(), interfaceID).Return(&core.FFI{}, nil) + mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(nil, nil) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) @@ -1897,11 +1897,11 @@ func TestGetContractAPIInterface(t *testing.T) { } mdb.On("GetContractAPIByName", mock.Anything, "ns1", "banana").Return(api, nil) - mdb.On("GetFFIByID", mock.Anything, interfaceID).Return(&core.FFI{}, nil) - mdb.On("GetFFIMethods", mock.Anything, mock.Anything).Return([]*core.FFIMethod{ + mdb.On("GetFFIByID", mock.Anything, "ns1", interfaceID).Return(&core.FFI{}, nil) + mdb.On("GetFFIMethods", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIMethod{ {ID: fftypes.NewUUID(), Name: "method1"}, }, nil, nil) - mdb.On("GetFFIEvents", mock.Anything, mock.Anything).Return([]*core.FFIEvent{ + mdb.On("GetFFIEvents", mock.Anything, "ns1", mock.Anything).Return([]*core.FFIEvent{ {ID: fftypes.NewUUID(), FFIEventDefinition: core.FFIEventDefinition{Name: "event1"}}, }, nil, nil) mbi.On("GenerateEventSignature", mock.Anything, mock.MatchedBy(func(ev *core.FFIEventDefinition) bool { @@ -1953,7 +1953,7 @@ func TestBroadcastContractAPI(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - mdb.On("GetFFIByID", mock.Anything, api.Interface.ID).Return(&core.FFI{}, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) api, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) @@ -2024,7 +2024,7 @@ func TestBroadcastContractAPIExisting(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(existing, nil) - mdb.On("GetFFIByID", mock.Anything, api.Interface.ID).Return(&core.FFI{}, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) @@ -2128,7 +2128,7 @@ func TestBroadcastContractAPIFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - mdb.On("GetFFIByID", mock.Anything, api.Interface.ID).Return(&core.FFI{}, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(nil, fmt.Errorf("pop")) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) @@ -2180,7 +2180,7 @@ func TestBroadcastContractAPIInterfaceIDFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - mdb.On("GetFFIByID", mock.Anything, api.Interface.ID).Return(nil, fmt.Errorf("pop")) + mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(nil, fmt.Errorf("pop")) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) @@ -2207,7 +2207,7 @@ func TestBroadcastContractAPIInterfaceIDNotFound(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - mdb.On("GetFFIByID", mock.Anything, api.Interface.ID).Return(nil, nil) + mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(nil, nil) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) diff --git a/internal/database/sqlcommon/ffi_events_sql.go b/internal/database/sqlcommon/ffi_events_sql.go index 207af2ee0..c9c80747f 100644 --- a/internal/database/sqlcommon/ffi_events_sql.go +++ b/internal/database/sqlcommon/ffi_events_sql.go @@ -143,8 +143,9 @@ func (s *SQLCommon) getFFIEventPred(ctx context.Context, desc string, pred inter return ci, nil } -func (s *SQLCommon) GetFFIEvents(ctx context.Context, filter database.Filter) (events []*core.FFIEvent, res *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiEventsColumns...).From(ffieventsTable), filter, ffiEventFilterFieldMap, []interface{}{"sequence"}) +func (s *SQLCommon) GetFFIEvents(ctx context.Context, namespace string, filter database.Filter) (events []*core.FFIEvent, res *database.FilterResult, err error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiEventsColumns...).From(ffieventsTable), + filter, ffiEventFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -167,6 +168,6 @@ func (s *SQLCommon) GetFFIEvents(ctx context.Context, filter database.Filter) (e } -func (s *SQLCommon) GetFFIEvent(ctx context.Context, ns string, interfaceID *fftypes.UUID, pathName string) (*core.FFIEvent, error) { - return s.getFFIEventPred(ctx, ns+":"+pathName, sq.And{sq.Eq{"namespace": ns}, sq.Eq{"interface_id": interfaceID}, sq.Eq{"pathname": pathName}}) +func (s *SQLCommon) GetFFIEvent(ctx context.Context, namespace string, interfaceID *fftypes.UUID, pathName string) (*core.FFIEvent, error) { + return s.getFFIEventPred(ctx, namespace+":"+pathName, sq.Eq{"namespace": namespace, "interface_id": interfaceID, "pathname": pathName}) } diff --git a/internal/database/sqlcommon/ffi_events_sql_test.go b/internal/database/sqlcommon/ffi_events_sql_test.go index 8d8ce6959..2458bbd07 100644 --- a/internal/database/sqlcommon/ffi_events_sql_test.go +++ b/internal/database/sqlcommon/ffi_events_sql_test.go @@ -75,7 +75,7 @@ func TestFFIEventsE2EWithDB(t *testing.T) { fb.Eq("id", eventRead.ID.String()), fb.Eq("name", eventRead.Name), ) - events, res, err := s.GetFFIEvents(ctx, filter.Count(true)) + events, res, err := s.GetFFIEvents(ctx, "ns", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(events)) assert.Equal(t, int64(1), *res.TotalCount) @@ -172,7 +172,7 @@ func TestGetFFIEvents(t *testing.T) { rows := sqlmock.NewRows(ffiEventsColumns). AddRow(fftypes.NewUUID().String(), fftypes.NewUUID().String(), "ns1", "sum", "sum", "", []byte(`[]`), []byte(`{}`)) mock.ExpectQuery("SELECT .*").WillReturnRows(rows) - _, _, err := s.GetFFIEvents(context.Background(), filter) + _, _, err := s.GetFFIEvents(context.Background(), "ns1", filter) assert.NoError(t, err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -180,7 +180,7 @@ func TestGetFFIEvents(t *testing.T) { func TestGetFFIEventsFilterSelectFail(t *testing.T) { fb := database.FFIEventQueryFactory.NewFilter(context.Background()) s, _ := newMockProvider().init() - _, _, err := s.GetFFIEvents(context.Background(), fb.And(fb.Eq("id", map[bool]bool{true: false}))) + _, _, err := s.GetFFIEvents(context.Background(), "ns1", fb.And(fb.Eq("id", map[bool]bool{true: false}))) assert.Error(t, err) } @@ -191,7 +191,7 @@ func TestGetFFIEventsQueryFail(t *testing.T) { ) s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, _, err := s.GetFFIEvents(context.Background(), filter) + _, _, err := s.GetFFIEvents(context.Background(), "ns1", filter) assert.Regexp(t, "pop", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -206,7 +206,7 @@ func TestGetFFIEventsQueryResultFail(t *testing.T) { AddRow("7e2c001c-e270-4fd7-9e82-9dacee843dc2", "ns1", "math", "v1.0.0"). AddRow("7e2c001c-e270-4fd7-9e82-9dacee843dc2", nil, "math", "v1.0.0") mock.ExpectQuery("SELECT .*").WillReturnRows(rows) - _, _, err := s.GetFFIEvents(context.Background(), filter) + _, _, err := s.GetFFIEvents(context.Background(), "ns1", filter) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/database/sqlcommon/ffi_methods_sql.go b/internal/database/sqlcommon/ffi_methods_sql.go index 7363a61a5..1851c3e44 100644 --- a/internal/database/sqlcommon/ffi_methods_sql.go +++ b/internal/database/sqlcommon/ffi_methods_sql.go @@ -147,8 +147,9 @@ func (s *SQLCommon) getFFIMethodPred(ctx context.Context, desc string, pred inte return ci, nil } -func (s *SQLCommon) GetFFIMethods(ctx context.Context, filter database.Filter) (methods []*core.FFIMethod, res *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiMethodsColumns...).From(ffimethodsTable), filter, ffiMethodFilterFieldMap, []interface{}{"sequence"}) +func (s *SQLCommon) GetFFIMethods(ctx context.Context, namespace string, filter database.Filter) (methods []*core.FFIMethod, res *database.FilterResult, err error) { + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiMethodsColumns...).From(ffimethodsTable), + filter, ffiMethodFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -172,5 +173,5 @@ func (s *SQLCommon) GetFFIMethods(ctx context.Context, filter database.Filter) ( } func (s *SQLCommon) GetFFIMethod(ctx context.Context, ns string, interfaceID *fftypes.UUID, pathName string) (*core.FFIMethod, error) { - return s.getFFIMethodPred(ctx, ns+":"+pathName, sq.And{sq.Eq{"namespace": ns}, sq.Eq{"interface_id": interfaceID}, sq.Eq{"pathname": pathName}}) + return s.getFFIMethodPred(ctx, ns+":"+pathName, sq.Eq{"namespace": ns, "interface_id": interfaceID, "pathname": pathName}) } diff --git a/internal/database/sqlcommon/ffi_methods_sql_test.go b/internal/database/sqlcommon/ffi_methods_sql_test.go index 236b3876f..e8a48b919 100644 --- a/internal/database/sqlcommon/ffi_methods_sql_test.go +++ b/internal/database/sqlcommon/ffi_methods_sql_test.go @@ -79,7 +79,7 @@ func TestFFIMethodsE2EWithDB(t *testing.T) { fb.Eq("id", methodRead.ID.String()), fb.Eq("name", methodRead.Name), ) - methods, res, err := s.GetFFIMethods(ctx, filter.Count(true)) + methods, res, err := s.GetFFIMethods(ctx, "ns", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(methods)) assert.Equal(t, int64(1), *res.TotalCount) @@ -184,7 +184,7 @@ func TestGetFFIMethods(t *testing.T) { rows := sqlmock.NewRows(ffiMethodsColumns). AddRow(fftypes.NewUUID().String(), fftypes.NewUUID().String(), "ns1", "sum", "sum", "", []byte(`[]`), []byte(`[]`), []byte(`{}`)) mock.ExpectQuery("SELECT .*").WillReturnRows(rows) - _, _, err := s.GetFFIMethods(context.Background(), filter) + _, _, err := s.GetFFIMethods(context.Background(), "ns1", filter) assert.NoError(t, err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -192,7 +192,7 @@ func TestGetFFIMethods(t *testing.T) { func TestGetFFIMethodsFilterSelectFail(t *testing.T) { fb := database.FFIMethodQueryFactory.NewFilter(context.Background()) s, _ := newMockProvider().init() - _, _, err := s.GetFFIMethods(context.Background(), fb.And(fb.Eq("id", map[bool]bool{true: false}))) + _, _, err := s.GetFFIMethods(context.Background(), "ns1", fb.And(fb.Eq("id", map[bool]bool{true: false}))) assert.Error(t, err) } @@ -203,7 +203,7 @@ func TestGetFFIMethodsQueryFail(t *testing.T) { ) s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, _, err := s.GetFFIMethods(context.Background(), filter) + _, _, err := s.GetFFIMethods(context.Background(), "ns1", filter) assert.Regexp(t, "pop", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -218,7 +218,7 @@ func TestGetFFIMethodsQueryResultFail(t *testing.T) { AddRow("7e2c001c-e270-4fd7-9e82-9dacee843dc2", "ns1", "math", "v1.0.0"). AddRow("7e2c001c-e270-4fd7-9e82-9dacee843dc2", nil, "math", "v1.0.0") mock.ExpectQuery("SELECT .*").WillReturnRows(rows) - _, _, err := s.GetFFIMethods(context.Background(), filter) + _, _, err := s.GetFFIMethods(context.Background(), "ns1", filter) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/database/sqlcommon/ffi_sql.go b/internal/database/sqlcommon/ffi_sql.go index e96584664..8d45111ed 100644 --- a/internal/database/sqlcommon/ffi_sql.go +++ b/internal/database/sqlcommon/ffi_sql.go @@ -140,9 +140,10 @@ func (s *SQLCommon) getFFIPred(ctx context.Context, desc string, pred interface{ return ffi, nil } -func (s *SQLCommon) GetFFIs(ctx context.Context, ns string, filter database.Filter) (ffis []*core.FFI, res *database.FilterResult, err error) { +func (s *SQLCommon) GetFFIs(ctx context.Context, namespace string, filter database.Filter) (ffis []*core.FFI, res *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiColumns...).From(ffiTable).Where(sq.Eq{"namespace": ns}), filter, ffiFilterFieldMap, []interface{}{"sequence"}) + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(ffiColumns...).From(ffiTable), + filter, ffiFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -166,10 +167,10 @@ func (s *SQLCommon) GetFFIs(ctx context.Context, ns string, filter database.Filt } -func (s *SQLCommon) GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) { - return s.getFFIPred(ctx, id.String(), sq.Eq{"id": id}) +func (s *SQLCommon) GetFFIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.FFI, error) { + return s.getFFIPred(ctx, id.String(), sq.Eq{"id": id, "namespace": namespace}) } -func (s *SQLCommon) GetFFI(ctx context.Context, ns, name, version string) (*core.FFI, error) { - return s.getFFIPred(ctx, ns+":"+name+":"+version, sq.And{sq.Eq{"namespace": ns}, sq.Eq{"name": name}, sq.Eq{"version": version}}) +func (s *SQLCommon) GetFFI(ctx context.Context, namespace, name, version string) (*core.FFI, error) { + return s.getFFIPred(ctx, namespace+":"+name+":"+version, sq.Eq{"namespace": namespace, "name": name, "version": version}) } diff --git a/internal/database/sqlcommon/ffi_sql_test.go b/internal/database/sqlcommon/ffi_sql_test.go index c25d830d9..f6c4d06be 100644 --- a/internal/database/sqlcommon/ffi_sql_test.go +++ b/internal/database/sqlcommon/ffi_sql_test.go @@ -76,7 +76,7 @@ func TestFFIE2EWithDB(t *testing.T) { assert.NoError(t, err) // Check we get the correct fields back - dataRead, err := s.GetFFIByID(ctx, id) + dataRead, err := s.GetFFIByID(ctx, "ns1", id) assert.NoError(t, err) assert.NotNil(t, dataRead) assert.Equal(t, ffi.ID, dataRead.ID) @@ -91,7 +91,7 @@ func TestFFIE2EWithDB(t *testing.T) { assert.NoError(t, err) // Check we get the correct fields back - dataRead, err = s.GetFFIByID(ctx, id) + dataRead, err = s.GetFFIByID(ctx, "ns1", id) assert.NoError(t, err) assert.NotNil(t, dataRead) assert.Equal(t, ffi.ID, dataRead.ID) @@ -149,7 +149,7 @@ func TestFFIDBFailScan(t *testing.T) { s, mock := newMockProvider().init() id := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("only one")) - _, err := s.GetFFIByID(context.Background(), id) + _, err := s.GetFFIByID(context.Background(), "ns1", id) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -158,7 +158,7 @@ func TestFFIDBSelectFail(t *testing.T) { s, mock := newMockProvider().init() id := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetFFIByID(context.Background(), id) + _, err := s.GetFFIByID(context.Background(), "ns1", id) assert.Regexp(t, "pop", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -167,7 +167,7 @@ func TestFFIDBNoRows(t *testing.T) { s, mock := newMockProvider().init() id := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "name", "version"})) - _, err := s.GetFFIByID(context.Background(), id) + _, err := s.GetFFIByID(context.Background(), "ns1", id) assert.NoError(t, err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/txcommon/event_enrich.go b/internal/txcommon/event_enrich.go index e5267d9da..39ce3b351 100644 --- a/internal/txcommon/event_enrich.go +++ b/internal/txcommon/event_enrich.go @@ -53,7 +53,7 @@ func (t *transactionHelper) EnrichEvent(ctx context.Context, event *core.Event) } e.ContractAPI = contractAPI case core.EventTypeContractInterfaceConfirmed: - contractInterface, err := t.database.GetFFIByID(ctx, event.Reference) + contractInterface, err := t.database.GetFFIByID(ctx, t.namespace, event.Reference) if err != nil { return nil, err } diff --git a/internal/txcommon/event_enrich_test.go b/internal/txcommon/event_enrich_test.go index 764e3bec8..ff84389d2 100644 --- a/internal/txcommon/event_enrich_test.go +++ b/internal/txcommon/event_enrich_test.go @@ -262,7 +262,7 @@ func TestEnrichContractInterfaceSubmitted(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetFFIByID", mock.Anything, ref1).Return(&core.FFI{ + mdi.On("GetFFIByID", mock.Anything, "ns1", ref1).Return(&core.FFI{ ID: ref1, }, nil) @@ -288,7 +288,7 @@ func TestEnrichContractInterfacetFail(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetFFIByID", mock.Anything, ref1).Return(nil, fmt.Errorf("pop")) + mdi.On("GetFFIByID", mock.Anything, "ns1", ref1).Return(nil, fmt.Errorf("pop")) event := &core.Event{ ID: ev1, diff --git a/mocks/contractmocks/manager.go b/mocks/contractmocks/manager.go index 1a09709c1..18223121a 100644 --- a/mocks/contractmocks/manager.go +++ b/mocks/contractmocks/manager.go @@ -88,13 +88,13 @@ func (_m *Manager) BroadcastContractAPI(ctx context.Context, httpServerURL strin return r0, r1 } -// BroadcastFFI provides a mock function with given fields: ctx, ns, ffi, waitConfirm -func (_m *Manager) BroadcastFFI(ctx context.Context, ns string, ffi *core.FFI, waitConfirm bool) (*core.FFI, error) { - ret := _m.Called(ctx, ns, ffi, waitConfirm) +// BroadcastFFI provides a mock function with given fields: ctx, ffi, waitConfirm +func (_m *Manager) BroadcastFFI(ctx context.Context, ffi *core.FFI, waitConfirm bool) (*core.FFI, error) { + ret := _m.Called(ctx, ffi, waitConfirm) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, *core.FFI, bool) *core.FFI); ok { - r0 = rf(ctx, ns, ffi, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.FFI, bool) *core.FFI); ok { + r0 = rf(ctx, ffi, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -102,8 +102,8 @@ func (_m *Manager) BroadcastFFI(ctx context.Context, ns string, ffi *core.FFI, w } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.FFI, bool) error); ok { - r1 = rf(ctx, ns, ffi, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.FFI, bool) error); ok { + r1 = rf(ctx, ffi, waitConfirm) } else { r1 = ret.Error(1) } @@ -313,13 +313,13 @@ func (_m *Manager) GetContractListeners(ctx context.Context, ns string, filter d return r0, r1, r2 } -// GetFFI provides a mock function with given fields: ctx, ns, name, version -func (_m *Manager) GetFFI(ctx context.Context, ns string, name string, version string) (*core.FFI, error) { - ret := _m.Called(ctx, ns, name, version) +// GetFFI provides a mock function with given fields: ctx, name, version +func (_m *Manager) GetFFI(ctx context.Context, name string, version string) (*core.FFI, error) { + ret := _m.Called(ctx, name, version) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.FFI); ok { - r0 = rf(ctx, ns, name, version) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.FFI); ok { + r0 = rf(ctx, name, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -327,8 +327,8 @@ func (_m *Manager) GetFFI(ctx context.Context, ns string, name string, version s } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { - r1 = rf(ctx, ns, name, version) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, name, version) } else { r1 = ret.Error(1) } @@ -382,13 +382,13 @@ func (_m *Manager) GetFFIByIDWithChildren(ctx context.Context, id *fftypes.UUID) return r0, r1 } -// GetFFIWithChildren provides a mock function with given fields: ctx, ns, name, version -func (_m *Manager) GetFFIWithChildren(ctx context.Context, ns string, name string, version string) (*core.FFI, error) { - ret := _m.Called(ctx, ns, name, version) +// GetFFIWithChildren provides a mock function with given fields: ctx, name, version +func (_m *Manager) GetFFIWithChildren(ctx context.Context, name string, version string) (*core.FFI, error) { + ret := _m.Called(ctx, name, version) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.FFI); ok { - r0 = rf(ctx, ns, name, version) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.FFI); ok { + r0 = rf(ctx, name, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -396,8 +396,8 @@ func (_m *Manager) GetFFIWithChildren(ctx context.Context, ns string, name strin } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { - r1 = rf(ctx, ns, name, version) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, name, version) } else { r1 = ret.Error(1) } @@ -405,13 +405,13 @@ func (_m *Manager) GetFFIWithChildren(ctx context.Context, ns string, name strin return r0, r1 } -// GetFFIs provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetFFIs(ctx context.Context, ns string, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetFFIs provides a mock function with given fields: ctx, filter +func (_m *Manager) GetFFIs(ctx context.Context, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.FFI); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.FFI); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.FFI) @@ -419,8 +419,8 @@ func (_m *Manager) GetFFIs(ctx context.Context, ns string, filter database.AndFi } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -428,8 +428,8 @@ func (_m *Manager) GetFFIs(ctx context.Context, ns string, filter database.AndFi } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index 078accf4d..062ec930a 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -814,13 +814,13 @@ func (_m *Plugin) GetFFI(ctx context.Context, namespace string, name string, ver return r0, r1 } -// GetFFIByID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) { - ret := _m.Called(ctx, id) +// GetFFIByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetFFIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.FFI, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.FFI); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.FFI); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -828,8 +828,8 @@ func (_m *Plugin) GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } @@ -860,13 +860,13 @@ func (_m *Plugin) GetFFIEvent(ctx context.Context, namespace string, interfaceID return r0, r1 } -// GetFFIEvents provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetFFIEvents(ctx context.Context, filter database.Filter) ([]*core.FFIEvent, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetFFIEvents provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetFFIEvents(ctx context.Context, namespace string, filter database.Filter) ([]*core.FFIEvent, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.FFIEvent - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.FFIEvent); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.FFIEvent); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.FFIEvent) @@ -874,8 +874,8 @@ func (_m *Plugin) GetFFIEvents(ctx context.Context, filter database.Filter) ([]* } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -883,8 +883,8 @@ func (_m *Plugin) GetFFIEvents(ctx context.Context, filter database.Filter) ([]* } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } @@ -915,13 +915,13 @@ func (_m *Plugin) GetFFIMethod(ctx context.Context, namespace string, interfaceI return r0, r1 } -// GetFFIMethods provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetFFIMethods(ctx context.Context, filter database.Filter) ([]*core.FFIMethod, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetFFIMethods provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetFFIMethods(ctx context.Context, namespace string, filter database.Filter) ([]*core.FFIMethod, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.FFIMethod - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.FFIMethod); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.FFIMethod); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.FFIMethod) @@ -929,8 +929,8 @@ func (_m *Plugin) GetFFIMethods(ctx context.Context, filter database.Filter) ([] } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -938,8 +938,8 @@ func (_m *Plugin) GetFFIMethods(ctx context.Context, filter database.Filter) ([] } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 2955435ee..6709eafad 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -422,22 +422,39 @@ type iTokenApprovalCollection interface { } type iFFICollection interface { + // UpsertFFI - Upsert an FFI UpsertFFI(ctx context.Context, cd *core.FFI) error + + // GetFFIs - Get FFIs GetFFIs(ctx context.Context, namespace string, filter Filter) ([]*core.FFI, *FilterResult, error) - GetFFIByID(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) + + // GetFFIByID - Get an FFI by ID + GetFFIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.FFI, error) + + // GetFFI - Get an FFI by name and version GetFFI(ctx context.Context, namespace, name, version string) (*core.FFI, error) } type iFFIMethodCollection interface { + // UpsertFFIMethod - Upsert an FFI method UpsertFFIMethod(ctx context.Context, method *core.FFIMethod) error + + // GetFFIMethod - Get an FFI method by path GetFFIMethod(ctx context.Context, namespace string, interfaceID *fftypes.UUID, pathName string) (*core.FFIMethod, error) - GetFFIMethods(ctx context.Context, filter Filter) (methods []*core.FFIMethod, res *FilterResult, err error) + + // GetFFIMethods - Get FFI methods + GetFFIMethods(ctx context.Context, namespace string, filter Filter) (methods []*core.FFIMethod, res *FilterResult, err error) } type iFFIEventCollection interface { + // UpsertFFIEvent - Upsert an FFI event UpsertFFIEvent(ctx context.Context, method *core.FFIEvent) error + + // GetFFIEvent - Get an FFI event by path GetFFIEvent(ctx context.Context, namespace string, interfaceID *fftypes.UUID, pathName string) (*core.FFIEvent, error) - GetFFIEvents(ctx context.Context, filter Filter) (events []*core.FFIEvent, res *FilterResult, err error) + + // GetFFIEvents - Get FFI events + GetFFIEvents(ctx context.Context, namespace string, filter Filter) (events []*core.FFIEvent, res *FilterResult, err error) } type iContractAPICollection interface { From adc2e48c77f6e50c67e6a8124794972fbb8a950d Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 14:21:55 -0400 Subject: [PATCH 7/9] Add namespace to contract API database queries Signed-off-by: Andrew Richardson --- .../route_get_contract_api_by_name.go | 2 +- .../route_get_contract_api_by_name_test.go | 2 +- .../route_get_contract_api_interface.go | 2 +- .../route_get_contract_api_interface_test.go | 2 +- internal/apiserver/route_get_contract_apis.go | 2 +- .../apiserver/route_get_contract_apis_test.go | 2 +- .../route_post_contract_api_invoke.go | 2 +- .../route_post_contract_api_invoke_test.go | 2 +- .../route_post_contract_api_query.go | 2 +- .../route_post_contract_api_query_test.go | 2 +- .../apiserver/route_post_contract_invoke.go | 2 +- .../route_post_contract_invoke_test.go | 2 +- .../apiserver/route_post_contract_query.go | 2 +- .../route_post_contract_query_test.go | 2 +- .../apiserver/route_post_new_contract_api.go | 2 +- .../route_post_new_contract_api_test.go | 4 +- internal/apiserver/route_put_contract_api.go | 2 +- .../apiserver/route_put_contract_api_test.go | 4 +- internal/apiserver/server.go | 2 +- internal/apiserver/server_test.go | 8 +- internal/contracts/manager.go | 45 +++++----- internal/contracts/manager_test.go | 58 ++++++------ .../database/sqlcommon/contractapis_sql.go | 13 +-- .../sqlcommon/contractapis_sql_test.go | 10 +-- internal/txcommon/event_enrich.go | 2 +- internal/txcommon/event_enrich_test.go | 4 +- mocks/contractmocks/manager.go | 88 +++++++++---------- mocks/databasemocks/plugin.go | 14 +-- pkg/database/plugin.go | 9 +- 29 files changed, 150 insertions(+), 143 deletions(-) diff --git a/internal/apiserver/route_get_contract_api_by_name.go b/internal/apiserver/route_get_contract_api_by_name.go index 760f86c44..dc87e580d 100644 --- a/internal/apiserver/route_get_contract_api_by_name.go +++ b/internal/apiserver/route_get_contract_api_by_name.go @@ -38,7 +38,7 @@ var getContractAPIByName = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Contracts().GetContractAPI(cr.ctx, cr.apiBaseURL, extractNamespace(r.PP), r.PP["apiName"]) + return cr.or.Contracts().GetContractAPI(cr.ctx, cr.apiBaseURL, r.PP["apiName"]) }, }, } diff --git a/internal/apiserver/route_get_contract_api_by_name_test.go b/internal/apiserver/route_get_contract_api_by_name_test.go index f658907bd..1dad43de8 100644 --- a/internal/apiserver/route_get_contract_api_by_name_test.go +++ b/internal/apiserver/route_get_contract_api_by_name_test.go @@ -39,7 +39,7 @@ func TestGetContractAPIByName(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "ns1", "banana"). + mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "banana"). Return(&core.ContractAPI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_api_interface.go b/internal/apiserver/route_get_contract_api_interface.go index 4e997bc46..8ca6720ea 100644 --- a/internal/apiserver/route_get_contract_api_interface.go +++ b/internal/apiserver/route_get_contract_api_interface.go @@ -38,7 +38,7 @@ var getContractAPIInterface = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Contracts().GetContractAPIInterface(cr.ctx, extractNamespace(r.PP), r.PP["apiName"]) + return cr.or.Contracts().GetContractAPIInterface(cr.ctx, r.PP["apiName"]) }, }, } diff --git a/internal/apiserver/route_get_contract_api_interface_test.go b/internal/apiserver/route_get_contract_api_interface_test.go index 139122e23..d906df2b6 100644 --- a/internal/apiserver/route_get_contract_api_interface_test.go +++ b/internal/apiserver/route_get_contract_api_interface_test.go @@ -39,7 +39,7 @@ func TestGetContractAPIInterface(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractAPIInterface", mock.Anything, "ns1", "banana"). + mcm.On("GetContractAPIInterface", mock.Anything, "banana"). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_apis.go b/internal/apiserver/route_get_contract_apis.go index 985896a20..6a739f4fc 100644 --- a/internal/apiserver/route_get_contract_apis.go +++ b/internal/apiserver/route_get_contract_apis.go @@ -38,7 +38,7 @@ var getContractAPIs = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.ContractAPIQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Contracts().GetContractAPIs(cr.ctx, cr.apiBaseURL, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Contracts().GetContractAPIs(cr.ctx, cr.apiBaseURL, cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_contract_apis_test.go b/internal/apiserver/route_get_contract_apis_test.go index 6e5a14fe6..6405c0b06 100644 --- a/internal/apiserver/route_get_contract_apis_test.go +++ b/internal/apiserver/route_get_contract_apis_test.go @@ -39,7 +39,7 @@ func TestGetContractAPIs(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractAPIs", mock.Anything, "http://127.0.0.1:5000/api/v1", "ns1", mock.Anything). + mcm.On("GetContractAPIs", mock.Anything, "http://127.0.0.1:5000/api/v1", mock.Anything). Return([]*core.ContractAPI{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_contract_api_invoke.go b/internal/apiserver/route_post_contract_api_invoke.go index 41a7307a9..dbc4e0807 100644 --- a/internal/apiserver/route_post_contract_api_invoke.go +++ b/internal/apiserver/route_post_contract_api_invoke.go @@ -46,7 +46,7 @@ var postContractAPIInvoke = &ffapi.Route{ r.SuccessStatus = syncRetcode(waitConfirm) req := r.Input.(*core.ContractCallRequest) req.Type = core.CallTypeInvoke - return cr.or.Contracts().InvokeContractAPI(cr.ctx, extractNamespace(r.PP), r.PP["apiName"], r.PP["methodPath"], req, waitConfirm) + return cr.or.Contracts().InvokeContractAPI(cr.ctx, r.PP["apiName"], r.PP["methodPath"], req, waitConfirm) }, }, } diff --git a/internal/apiserver/route_post_contract_api_invoke_test.go b/internal/apiserver/route_post_contract_api_invoke_test.go index 4f8055bee..e38d93c51 100644 --- a/internal/apiserver/route_post_contract_api_invoke_test.go +++ b/internal/apiserver/route_post_contract_api_invoke_test.go @@ -39,7 +39,7 @@ func TestPostContractAPIInvoke(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("InvokeContractAPI", mock.Anything, "ns1", "banana", "peel", mock.MatchedBy(func(req *core.ContractCallRequest) bool { + mcm.On("InvokeContractAPI", mock.Anything, "banana", "peel", mock.MatchedBy(func(req *core.ContractCallRequest) bool { return req.Type == core.CallTypeInvoke }), false).Return("banana", nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_contract_api_query.go b/internal/apiserver/route_post_contract_api_query.go index 938f4a4bf..004b0e9d9 100644 --- a/internal/apiserver/route_post_contract_api_query.go +++ b/internal/apiserver/route_post_contract_api_query.go @@ -41,7 +41,7 @@ var postContractAPIQuery = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { req := r.Input.(*core.ContractCallRequest) req.Type = core.CallTypeQuery - return cr.or.Contracts().InvokeContractAPI(cr.ctx, extractNamespace(r.PP), r.PP["apiName"], r.PP["methodPath"], req, true) + return cr.or.Contracts().InvokeContractAPI(cr.ctx, r.PP["apiName"], r.PP["methodPath"], req, true) }, }, } diff --git a/internal/apiserver/route_post_contract_api_query_test.go b/internal/apiserver/route_post_contract_api_query_test.go index 22f042672..ee616da70 100644 --- a/internal/apiserver/route_post_contract_api_query_test.go +++ b/internal/apiserver/route_post_contract_api_query_test.go @@ -39,7 +39,7 @@ func TestPostContractAPIQuery(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("InvokeContractAPI", mock.Anything, "ns1", "banana", "peel", mock.MatchedBy(func(req *core.ContractCallRequest) bool { + mcm.On("InvokeContractAPI", mock.Anything, "banana", "peel", mock.MatchedBy(func(req *core.ContractCallRequest) bool { return req.Type == core.CallTypeQuery }), true).Return("banana", nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_contract_invoke.go b/internal/apiserver/route_post_contract_invoke.go index 38f294e08..02779e311 100644 --- a/internal/apiserver/route_post_contract_invoke.go +++ b/internal/apiserver/route_post_contract_invoke.go @@ -43,7 +43,7 @@ var postContractInvoke = &ffapi.Route{ r.SuccessStatus = syncRetcode(waitConfirm) req := r.Input.(*core.ContractCallRequest) req.Type = core.CallTypeInvoke - return cr.or.Contracts().InvokeContract(cr.ctx, extractNamespace(r.PP), req, waitConfirm) + return cr.or.Contracts().InvokeContract(cr.ctx, req, waitConfirm) }, }, } diff --git a/internal/apiserver/route_post_contract_invoke_test.go b/internal/apiserver/route_post_contract_invoke_test.go index 4972486d6..be0a5e47b 100644 --- a/internal/apiserver/route_post_contract_invoke_test.go +++ b/internal/apiserver/route_post_contract_invoke_test.go @@ -39,7 +39,7 @@ func TestPostContractInvoke(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("InvokeContract", mock.Anything, "ns1", mock.MatchedBy(func(req *core.ContractCallRequest) bool { + mcm.On("InvokeContract", mock.Anything, mock.MatchedBy(func(req *core.ContractCallRequest) bool { return req.Type == core.CallTypeInvoke }), false).Return("banana", nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_contract_query.go b/internal/apiserver/route_post_contract_query.go index c74dc052f..e7fd49df1 100644 --- a/internal/apiserver/route_post_contract_query.go +++ b/internal/apiserver/route_post_contract_query.go @@ -38,7 +38,7 @@ var postContractQuery = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { req := r.Input.(*core.ContractCallRequest) req.Type = core.CallTypeQuery - return cr.or.Contracts().InvokeContract(cr.ctx, extractNamespace(r.PP), req, true) + return cr.or.Contracts().InvokeContract(cr.ctx, req, true) }, }, } diff --git a/internal/apiserver/route_post_contract_query_test.go b/internal/apiserver/route_post_contract_query_test.go index 1d57976f5..fa3787b22 100644 --- a/internal/apiserver/route_post_contract_query_test.go +++ b/internal/apiserver/route_post_contract_query_test.go @@ -39,7 +39,7 @@ func TestPostContractQuery(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("InvokeContract", mock.Anything, "ns1", mock.MatchedBy(func(req *core.ContractCallRequest) bool { + mcm.On("InvokeContract", mock.Anything, mock.MatchedBy(func(req *core.ContractCallRequest) bool { return req.Type == core.CallTypeQuery }), true).Return("banana", nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_new_contract_api.go b/internal/apiserver/route_post_new_contract_api.go index 5e658ca24..a2d777fa4 100644 --- a/internal/apiserver/route_post_new_contract_api.go +++ b/internal/apiserver/route_post_new_contract_api.go @@ -41,7 +41,7 @@ var postNewContractAPI = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { waitConfirm := strings.EqualFold(r.QP["confirm"], "true") r.SuccessStatus = syncRetcode(waitConfirm) - return cr.or.Contracts().BroadcastContractAPI(cr.ctx, cr.apiBaseURL, extractNamespace(r.PP), r.Input.(*core.ContractAPI), waitConfirm) + return cr.or.Contracts().BroadcastContractAPI(cr.ctx, cr.apiBaseURL, r.Input.(*core.ContractAPI), waitConfirm) }, }, } diff --git a/internal/apiserver/route_post_new_contract_api_test.go b/internal/apiserver/route_post_new_contract_api_test.go index 257bfaa52..127716835 100644 --- a/internal/apiserver/route_post_new_contract_api_test.go +++ b/internal/apiserver/route_post_new_contract_api_test.go @@ -39,7 +39,7 @@ func TestPostNewContractAPI(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), false). + mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, mock.AnythingOfType("*core.ContractAPI"), false). Return(&core.ContractAPI{}, nil) r.ServeHTTP(res, req) @@ -57,7 +57,7 @@ func TestPostNewContractAPISync(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), true). + mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, mock.AnythingOfType("*core.ContractAPI"), true). Return(&core.ContractAPI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_put_contract_api.go b/internal/apiserver/route_put_contract_api.go index 535ec0e68..90a0707ed 100644 --- a/internal/apiserver/route_put_contract_api.go +++ b/internal/apiserver/route_put_contract_api.go @@ -48,7 +48,7 @@ var putContractAPI = &ffapi.Route{ api.ID, err = fftypes.ParseUUID(cr.ctx, r.PP["id"]) var res interface{} if err == nil { - res, err = cr.or.Contracts().BroadcastContractAPI(cr.ctx, cr.apiBaseURL, extractNamespace(r.PP), api, waitConfirm) + res, err = cr.or.Contracts().BroadcastContractAPI(cr.ctx, cr.apiBaseURL, api, waitConfirm) } return res, err }, diff --git a/internal/apiserver/route_put_contract_api_test.go b/internal/apiserver/route_put_contract_api_test.go index d2e1f5491..9b5a495bb 100644 --- a/internal/apiserver/route_put_contract_api_test.go +++ b/internal/apiserver/route_put_contract_api_test.go @@ -39,7 +39,7 @@ func TestPutContractAPI(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), false). + mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, mock.AnythingOfType("*core.ContractAPI"), false). Return(&core.ContractAPI{}, nil) r.ServeHTTP(res, req) @@ -57,7 +57,7 @@ func TestPutContractAPISync(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), true). + mcm.On("BroadcastContractAPI", mock.Anything, mock.Anything, mock.AnythingOfType("*core.ContractAPI"), true). Return(&core.ContractAPI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/server.go b/internal/apiserver/server.go index 787a73b29..1dc914961 100644 --- a/internal/apiserver/server.go +++ b/internal/apiserver/server.go @@ -201,7 +201,7 @@ func (as *apiServer) contractSwaggerGenerator(mgr namespace.Manager, apiBaseURL return func(req *http.Request) (*openapi3.T, error) { vars := mux.Vars(req) cm := mgr.Orchestrator(vars["ns"]).Contracts() - api, err := cm.GetContractAPI(req.Context(), apiBaseURL, vars["ns"], vars["apiName"]) + api, err := cm.GetContractAPI(req.Context(), apiBaseURL, vars["apiName"]) if err != nil { return nil, err } else if api == nil || api.Interface == nil { diff --git a/internal/apiserver/server_test.go b/internal/apiserver/server_test.go index 3821e660d..8600db82a 100644 --- a/internal/apiserver/server_test.go +++ b/internal/apiserver/server_test.go @@ -261,7 +261,7 @@ func TestContractAPISwaggerJSON(t *testing.T) { }, } - mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "default", "my-api").Return(api, nil) + mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "my-api").Return(api, nil) mcm.On("GetFFIByIDWithChildren", mock.Anything, api.Interface.ID).Return(ffi, nil) mffi.On("Generate", mock.Anything, "http://127.0.0.1:5000/api/v1/namespaces/default/apis/my-api", api, ffi).Return(&openapi3.T{}) @@ -278,7 +278,7 @@ func TestContractAPISwaggerJSONGetAPIFail(t *testing.T) { s := httptest.NewServer(r) defer s.Close() - mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "default", "my-api").Return(nil, fmt.Errorf("pop")) + mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "my-api").Return(nil, fmt.Errorf("pop")) res, err := http.Get(fmt.Sprintf("http://%s/api/v1/namespaces/default/apis/my-api/api/swagger.json", s.Listener.Addr())) assert.NoError(t, err) @@ -293,7 +293,7 @@ func TestContractAPISwaggerJSONGetAPINotFound(t *testing.T) { s := httptest.NewServer(r) defer s.Close() - mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "default", "my-api").Return(nil, nil) + mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "my-api").Return(nil, nil) res, err := http.Get(fmt.Sprintf("http://%s/api/v1/namespaces/default/apis/my-api/api/swagger.json", s.Listener.Addr())) assert.NoError(t, err) @@ -314,7 +314,7 @@ func TestContractAPISwaggerJSONGetFFIFail(t *testing.T) { }, } - mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "default", "my-api").Return(api, nil) + mcm.On("GetContractAPI", mock.Anything, "http://127.0.0.1:5000/api/v1", "my-api").Return(api, nil) mcm.On("GetFFIByIDWithChildren", mock.Anything, api.Interface.ID).Return(nil, fmt.Errorf("pop")) res, err := http.Get(fmt.Sprintf("http://%s/api/v1/namespaces/default/apis/my-api/api/swagger.json", s.Listener.Addr())) diff --git a/internal/contracts/manager.go b/internal/contracts/manager.go index 1f18fd5d1..f393285a3 100644 --- a/internal/contracts/manager.go +++ b/internal/contracts/manager.go @@ -45,12 +45,12 @@ type Manager interface { GetFFIByIDWithChildren(ctx context.Context, id *fftypes.UUID) (*core.FFI, error) GetFFIs(ctx context.Context, filter database.AndFilter) ([]*core.FFI, *database.FilterResult, error) - InvokeContract(ctx context.Context, ns string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) - InvokeContractAPI(ctx context.Context, ns, apiName, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) - GetContractAPI(ctx context.Context, httpServerURL, ns, apiName string) (*core.ContractAPI, error) - GetContractAPIInterface(ctx context.Context, ns, apiName string) (*core.FFI, error) - GetContractAPIs(ctx context.Context, httpServerURL, ns string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) - BroadcastContractAPI(ctx context.Context, httpServerURL, ns string, api *core.ContractAPI, waitConfirm bool) (output *core.ContractAPI, err error) + InvokeContract(ctx context.Context, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) + InvokeContractAPI(ctx context.Context, apiName, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) + GetContractAPI(ctx context.Context, httpServerURL, apiName string) (*core.ContractAPI, error) + GetContractAPIInterface(ctx context.Context, apiName string) (*core.FFI, error) + GetContractAPIs(ctx context.Context, httpServerURL string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) + BroadcastContractAPI(ctx context.Context, httpServerURL string, api *core.ContractAPI, waitConfirm bool) (output *core.ContractAPI, err error) ValidateFFIAndSetPathnames(ctx context.Context, ffi *core.FFI) error @@ -218,7 +218,7 @@ func (cm *contractManager) writeInvokeTransaction(ctx context.Context, req *core return op, err } -func (cm *contractManager) InvokeContract(ctx context.Context, ns string, req *core.ContractCallRequest, waitConfirm bool) (res interface{}, err error) { +func (cm *contractManager) InvokeContract(ctx context.Context, req *core.ContractCallRequest, waitConfirm bool) (res interface{}, err error) { req.Key, err = cm.identity.NormalizeSigningKey(ctx, req.Key, identity.KeyNormalizationBlockchainPlugin) if err != nil { return nil, err @@ -226,7 +226,7 @@ func (cm *contractManager) InvokeContract(ctx context.Context, ns string, req *c var op *core.Operation err = cm.database.RunAsGroup(ctx, func(ctx context.Context) (err error) { - if err = cm.resolveInvokeContractRequest(ctx, ns, req); err != nil { + if err = cm.resolveInvokeContractRequest(ctx, req); err != nil { return err } if err := cm.validateInvokeContractRequest(ctx, req); err != nil { @@ -262,8 +262,8 @@ func (cm *contractManager) InvokeContract(ctx context.Context, ns string, req *c } } -func (cm *contractManager) InvokeContractAPI(ctx context.Context, ns, apiName, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { - api, err := cm.database.GetContractAPIByName(ctx, ns, apiName) +func (cm *contractManager) InvokeContractAPI(ctx context.Context, apiName, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { + api, err := cm.database.GetContractAPIByName(ctx, cm.namespace, apiName) if err != nil { return nil, err } else if api == nil || api.Interface == nil { @@ -274,15 +274,15 @@ func (cm *contractManager) InvokeContractAPI(ctx context.Context, ns, apiName, m if api.Location != nil { req.Location = api.Location } - return cm.InvokeContract(ctx, ns, req, waitConfirm) + return cm.InvokeContract(ctx, req, waitConfirm) } -func (cm *contractManager) resolveInvokeContractRequest(ctx context.Context, ns string, req *core.ContractCallRequest) (err error) { +func (cm *contractManager) resolveInvokeContractRequest(ctx context.Context, req *core.ContractCallRequest) (err error) { if req.Method == nil { if req.MethodPath == "" || req.Interface == nil { return i18n.NewError(ctx, coremsgs.MsgContractMethodNotSet) } - req.Method, err = cm.database.GetFFIMethod(ctx, ns, req.Interface, req.MethodPath) + req.Method, err = cm.database.GetFFIMethod(ctx, cm.namespace, req.Interface, req.MethodPath) if err != nil || req.Method == nil { return i18n.NewError(ctx, coremsgs.MsgContractMethodResolveError, err) } @@ -299,23 +299,22 @@ func (cm *contractManager) addContractURLs(httpServerURL string, api *core.Contr } } -func (cm *contractManager) GetContractAPI(ctx context.Context, httpServerURL, ns, apiName string) (*core.ContractAPI, error) { - api, err := cm.database.GetContractAPIByName(ctx, ns, apiName) +func (cm *contractManager) GetContractAPI(ctx context.Context, httpServerURL, apiName string) (*core.ContractAPI, error) { + api, err := cm.database.GetContractAPIByName(ctx, cm.namespace, apiName) cm.addContractURLs(httpServerURL, api) return api, err } -func (cm *contractManager) GetContractAPIInterface(ctx context.Context, ns, apiName string) (*core.FFI, error) { - api, err := cm.GetContractAPI(ctx, "", ns, apiName) +func (cm *contractManager) GetContractAPIInterface(ctx context.Context, apiName string) (*core.FFI, error) { + api, err := cm.GetContractAPI(ctx, "", apiName) if err != nil || api == nil { return nil, err } return cm.GetFFIByIDWithChildren(ctx, api.Interface.ID) } -func (cm *contractManager) GetContractAPIs(ctx context.Context, httpServerURL, ns string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) { - filter = cm.scopeNS(ns, filter) - apis, fr, err := cm.database.GetContractAPIs(ctx, ns, filter) +func (cm *contractManager) GetContractAPIs(ctx context.Context, httpServerURL string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) { + apis, fr, err := cm.database.GetContractAPIs(ctx, cm.namespace, filter) for _, api := range apis { cm.addContractURLs(httpServerURL, api) } @@ -351,9 +350,9 @@ func (cm *contractManager) resolveFFIReference(ctx context.Context, ref *core.FF } } -func (cm *contractManager) BroadcastContractAPI(ctx context.Context, httpServerURL, ns string, api *core.ContractAPI, waitConfirm bool) (output *core.ContractAPI, err error) { +func (cm *contractManager) BroadcastContractAPI(ctx context.Context, httpServerURL string, api *core.ContractAPI, waitConfirm bool) (output *core.ContractAPI, err error) { api.ID = fftypes.NewUUID() - api.Namespace = ns + api.Namespace = cm.namespace if api.Location != nil { if api.Location, err = cm.blockchain.NormalizeContractLocation(ctx, api.Location); err != nil { @@ -378,7 +377,7 @@ func (cm *contractManager) BroadcastContractAPI(ctx context.Context, httpServerU return nil, err } - msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, ns, api, core.SystemTagDefineContractAPI, waitConfirm) + msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, cm.namespace, api, core.SystemTagDefineContractAPI, waitConfirm) if err != nil { return nil, err } diff --git a/internal/contracts/manager_test.go b/internal/contracts/manager_test.go index b59954255..b98a03488 100644 --- a/internal/contracts/manager_test.go +++ b/internal/contracts/manager_test.go @@ -1287,7 +1287,7 @@ func TestInvokeContract(t *testing.T) { return op.Type == core.OpTypeBlockchainInvoke && data.Request == req })).Return(nil, nil) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.NoError(t, err) @@ -1333,7 +1333,7 @@ func TestInvokeContractConfirm(t *testing.T) { }). Return(&core.Operation{}, nil) - _, err := cm.InvokeContract(context.Background(), "ns1", req, true) + _, err := cm.InvokeContract(context.Background(), req, true) assert.NoError(t, err) @@ -1373,7 +1373,7 @@ func TestInvokeContractFail(t *testing.T) { return op.Type == core.OpTypeBlockchainInvoke && data.Request == req })).Return(nil, fmt.Errorf("pop")) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.EqualError(t, err, "pop") @@ -1395,7 +1395,7 @@ func TestInvokeContractFailNormalizeSigningKey(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("", fmt.Errorf("pop")) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.Regexp(t, "pop", err) } @@ -1414,7 +1414,7 @@ func TestInvokeContractFailResolve(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) mbi.On("InvokeContract", mock.Anything, mock.AnythingOfType("*fftypes.UUID"), "key-resolved", req.Location, req.Method, req.Input).Return(nil) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.Regexp(t, "FF10313", err) } @@ -1439,7 +1439,7 @@ func TestInvokeContractTXFail(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) mth.On("SubmitNewTransaction", mock.Anything, core.TransactionTypeContractInvoke).Return(nil, fmt.Errorf("pop")) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.EqualError(t, err, "pop") } @@ -1459,7 +1459,7 @@ func TestInvokeContractMethodNotFound(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) mdb.On("GetFFIMethod", mock.Anything, "ns1", req.Interface, req.MethodPath).Return(nil, fmt.Errorf("pop")) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.Regexp(t, "FF10315", err) } @@ -1495,7 +1495,7 @@ func TestInvokeContractMethodBadInput(t *testing.T) { } mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.Regexp(t, "FF10304", err) } @@ -1525,7 +1525,7 @@ func TestQueryContract(t *testing.T) { })).Return(nil) mbi.On("QueryContract", mock.Anything, req.Location, req.Method, req.Input, req.Options).Return(struct{}{}, nil) - _, err := cm.InvokeContract(context.Background(), "ns1", req, false) + _, err := cm.InvokeContract(context.Background(), req, false) assert.NoError(t, err) } @@ -1554,7 +1554,7 @@ func TestCallContractInvalidType(t *testing.T) { })).Return(nil) assert.PanicsWithValue(t, "unknown call type: ", func() { - cm.InvokeContract(context.Background(), "ns1", req, false) + cm.InvokeContract(context.Background(), req, false) }) } @@ -1793,7 +1793,7 @@ func TestInvokeContractAPI(t *testing.T) { return op.Type == core.OpTypeBlockchainInvoke && data.Request == req })).Return(nil, nil) - _, err := cm.InvokeContractAPI(context.Background(), "ns1", "banana", "peel", req, false) + _, err := cm.InvokeContractAPI(context.Background(), "banana", "peel", req, false) assert.NoError(t, err) @@ -1820,7 +1820,7 @@ func TestInvokeContractAPIFailContractLookup(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) mdb.On("GetContractAPIByName", mock.Anything, "ns1", "banana").Return(nil, fmt.Errorf("pop")) - _, err := cm.InvokeContractAPI(context.Background(), "ns1", "banana", "peel", req, false) + _, err := cm.InvokeContractAPI(context.Background(), "banana", "peel", req, false) assert.Regexp(t, "pop", err) } @@ -1841,7 +1841,7 @@ func TestInvokeContractAPIContractNotFound(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "", identity.KeyNormalizationBlockchainPlugin).Return("key-resolved", nil) mdb.On("GetContractAPIByName", mock.Anything, "ns1", "banana").Return(nil, nil) - _, err := cm.InvokeContractAPI(context.Background(), "ns1", "banana", "peel", req, false) + _, err := cm.InvokeContractAPI(context.Background(), "banana", "peel", req, false) assert.Regexp(t, "FF10109", err) } @@ -1856,7 +1856,7 @@ func TestGetContractAPI(t *testing.T) { } mdb.On("GetContractAPIByName", mock.Anything, "ns1", "banana").Return(api, nil) - result, err := cm.GetContractAPI(context.Background(), "http://localhost/api", "ns1", "banana") + result, err := cm.GetContractAPI(context.Background(), "http://localhost/api", "banana") assert.NoError(t, err) assert.Equal(t, "http://localhost/api/namespaces/ns1/apis/banana/api/swagger.json", result.URLs.OpenAPI) @@ -1876,7 +1876,7 @@ func TestGetContractAPIs(t *testing.T) { filter := database.ContractAPIQueryFactory.NewFilter(context.Background()).And() mdb.On("GetContractAPIs", mock.Anything, "ns1", filter).Return(apis, &database.FilterResult{}, nil) - results, _, err := cm.GetContractAPIs(context.Background(), "http://localhost/api", "ns1", filter) + results, _, err := cm.GetContractAPIs(context.Background(), "http://localhost/api", filter) assert.NoError(t, err) assert.Equal(t, 1, len(results)) @@ -1908,7 +1908,7 @@ func TestGetContractAPIInterface(t *testing.T) { return ev.Name == "event1" })).Return("event1Sig") - result, err := cm.GetContractAPIInterface(context.Background(), "ns1", "banana") + result, err := cm.GetContractAPIInterface(context.Background(), "banana") assert.NoError(t, err) assert.NotNil(t, result) @@ -1923,7 +1923,7 @@ func TestGetContractAPIInterfaceFail(t *testing.T) { mdb.On("GetContractAPIByName", mock.Anything, "ns1", "banana").Return(nil, fmt.Errorf("pop")) - _, err := cm.GetContractAPIInterface(context.Background(), "ns1", "banana") + _, err := cm.GetContractAPIInterface(context.Background(), "banana") assert.EqualError(t, err, "pop") @@ -1956,7 +1956,7 @@ func TestBroadcastContractAPI(t *testing.T) { mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) - api, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + api, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.NoError(t, err) assert.NotNil(t, api) @@ -1984,7 +1984,7 @@ func TestBroadcastContractAPIBadLocation(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(nil, fmt.Errorf("pop")) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.EqualError(t, err, "pop") @@ -2027,7 +2027,7 @@ func TestBroadcastContractAPIExisting(t *testing.T) { mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.NoError(t, err) @@ -2064,7 +2064,7 @@ func TestBroadcastContractAPICannotChangeLocation(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(existing, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "FF10316", err) @@ -2100,7 +2100,7 @@ func TestBroadcastContractAPIInterfaceName(t *testing.T) { mdb.On("GetFFI", mock.Anything, "ns1", "my-ffi", "1").Return(&core.FFI{ID: interfaceID}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.NoError(t, err) assert.Equal(t, *interfaceID, *api.Interface.ID) @@ -2131,7 +2131,7 @@ func TestBroadcastContractAPIFail(t *testing.T) { mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(nil, fmt.Errorf("pop")) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "pop", err) @@ -2155,7 +2155,7 @@ func TestBroadcastContractAPINoInterface(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "FF10303", err) @@ -2182,7 +2182,7 @@ func TestBroadcastContractAPIInterfaceIDFail(t *testing.T) { mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(nil, fmt.Errorf("pop")) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.EqualError(t, err, "pop") @@ -2209,7 +2209,7 @@ func TestBroadcastContractAPIInterfaceIDNotFound(t *testing.T) { mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(nil, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "FF10303.*"+api.Interface.ID.String(), err) @@ -2237,7 +2237,7 @@ func TestBroadcastContractAPIInterfaceNameFail(t *testing.T) { mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFI", mock.Anything, "ns1", "my-ffi", "1").Return(nil, fmt.Errorf("pop")) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.EqualError(t, err, "pop") @@ -2265,7 +2265,7 @@ func TestBroadcastContractAPIInterfaceNameNotFound(t *testing.T) { mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFI", mock.Anything, "ns1", "my-ffi", "1").Return(nil, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "FF10303.*my-ffi", err) @@ -2291,7 +2291,7 @@ func TestBroadcastContractAPIInterfaceNoVersion(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) - _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", "ns1", api, false) + _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) assert.Regexp(t, "FF10303.*my-ffi", err) diff --git a/internal/database/sqlcommon/contractapis_sql.go b/internal/database/sqlcommon/contractapis_sql.go index 4e5b30752..51ce704c0 100644 --- a/internal/database/sqlcommon/contractapis_sql.go +++ b/internal/database/sqlcommon/contractapis_sql.go @@ -154,9 +154,10 @@ func (s *SQLCommon) getContractAPIPred(ctx context.Context, desc string, pred in return api, nil } -func (s *SQLCommon) GetContractAPIs(ctx context.Context, ns string, filter database.AndFilter) (contractAPIs []*core.ContractAPI, res *database.FilterResult, err error) { +func (s *SQLCommon) GetContractAPIs(ctx context.Context, namespace string, filter database.AndFilter) (contractAPIs []*core.ContractAPI, res *database.FilterResult, err error) { - query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(contractAPIsColumns...).From(contractapisTable).Where(sq.Eq{"namespace": ns}), filter, contractAPIsFilterFieldMap, []interface{}{"sequence"}) + query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(contractAPIsColumns...).From(contractapisTable), + filter, contractAPIsFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -180,10 +181,10 @@ func (s *SQLCommon) GetContractAPIs(ctx context.Context, ns string, filter datab } -func (s *SQLCommon) GetContractAPIByID(ctx context.Context, id *fftypes.UUID) (*core.ContractAPI, error) { - return s.getContractAPIPred(ctx, id.String(), sq.Eq{"id": id}) +func (s *SQLCommon) GetContractAPIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.ContractAPI, error) { + return s.getContractAPIPred(ctx, id.String(), sq.Eq{"id": id, "namespace": namespace}) } -func (s *SQLCommon) GetContractAPIByName(ctx context.Context, ns, name string) (*core.ContractAPI, error) { - return s.getContractAPIPred(ctx, ns+":"+name, sq.And{sq.Eq{"namespace": ns}, sq.Eq{"name": name}}) +func (s *SQLCommon) GetContractAPIByName(ctx context.Context, namespace, name string) (*core.ContractAPI, error) { + return s.getContractAPIPred(ctx, namespace+":"+name, sq.Eq{"namespace": namespace, "name": name}) } diff --git a/internal/database/sqlcommon/contractapis_sql_test.go b/internal/database/sqlcommon/contractapis_sql_test.go index 154fb681a..e977fc596 100644 --- a/internal/database/sqlcommon/contractapis_sql_test.go +++ b/internal/database/sqlcommon/contractapis_sql_test.go @@ -61,7 +61,7 @@ func TestContractAPIE2EWithDB(t *testing.T) { assert.NoError(t, err) // Check we get the exact same ContractAPI back - dataRead, err := s.GetContractAPIByID(ctx, apiID) + dataRead, err := s.GetContractAPIByID(ctx, "ns1", apiID) assert.NoError(t, err) assert.NotNil(t, dataRead) assert.Equal(t, *apiID, *dataRead.ID) @@ -72,7 +72,7 @@ func TestContractAPIE2EWithDB(t *testing.T) { assert.NoError(t, err) // Check we get the exact same ContractAPI back - dataRead, err = s.GetContractAPIByID(ctx, apiID) + dataRead, err = s.GetContractAPIByID(ctx, "ns1", apiID) assert.NoError(t, err) assert.NotNil(t, dataRead) assert.Equal(t, *apiID, *dataRead.ID) @@ -145,7 +145,7 @@ func TestContractAPIDBFailScan(t *testing.T) { s, mock := newMockProvider().init() apiID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("only one")) - _, err := s.GetContractAPIByID(context.Background(), apiID) + _, err := s.GetContractAPIByID(context.Background(), "ns1", apiID) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -154,7 +154,7 @@ func TestContractAPIDBSelectFail(t *testing.T) { s, mock := newMockProvider().init() apiID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetContractAPIByID(context.Background(), apiID) + _, err := s.GetContractAPIByID(context.Background(), "ns1", apiID) assert.Regexp(t, "pop", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -163,7 +163,7 @@ func TestContractAPIDBNoRows(t *testing.T) { s, mock := newMockProvider().init() apiID := fftypes.NewUUID() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"id", "interface_id", "ledger", "location", "name", "namespace", "message_id"})) - _, err := s.GetContractAPIByID(context.Background(), apiID) + _, err := s.GetContractAPIByID(context.Background(), "ns1", apiID) assert.NoError(t, err) assert.NoError(t, mock.ExpectationsWereMet()) } diff --git a/internal/txcommon/event_enrich.go b/internal/txcommon/event_enrich.go index 39ce3b351..bec9bf46d 100644 --- a/internal/txcommon/event_enrich.go +++ b/internal/txcommon/event_enrich.go @@ -47,7 +47,7 @@ func (t *transactionHelper) EnrichEvent(ctx context.Context, event *core.Event) } e.BlockchainEvent = be case core.EventTypeContractAPIConfirmed: - contractAPI, err := t.database.GetContractAPIByID(ctx, event.Reference) + contractAPI, err := t.database.GetContractAPIByID(ctx, t.namespace, event.Reference) if err != nil { return nil, err } diff --git a/internal/txcommon/event_enrich_test.go b/internal/txcommon/event_enrich_test.go index ff84389d2..0af31defd 100644 --- a/internal/txcommon/event_enrich_test.go +++ b/internal/txcommon/event_enrich_test.go @@ -213,7 +213,7 @@ func TestEnrichContractAPISubmitted(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetContractAPIByID", mock.Anything, ref1).Return(&core.ContractAPI{ + mdi.On("GetContractAPIByID", mock.Anything, "ns1", ref1).Return(&core.ContractAPI{ ID: ref1, }, nil) @@ -239,7 +239,7 @@ func TestEnrichContractAPItFail(t *testing.T) { ev1 := fftypes.NewUUID() // Setup enrichment - mdi.On("GetContractAPIByID", mock.Anything, ref1).Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractAPIByID", mock.Anything, "ns1", ref1).Return(nil, fmt.Errorf("pop")) event := &core.Event{ ID: ev1, diff --git a/mocks/contractmocks/manager.go b/mocks/contractmocks/manager.go index 18223121a..d3bb9803f 100644 --- a/mocks/contractmocks/manager.go +++ b/mocks/contractmocks/manager.go @@ -65,13 +65,13 @@ func (_m *Manager) AddContractListener(ctx context.Context, ns string, listener return r0, r1 } -// BroadcastContractAPI provides a mock function with given fields: ctx, httpServerURL, ns, api, waitConfirm -func (_m *Manager) BroadcastContractAPI(ctx context.Context, httpServerURL string, ns string, api *core.ContractAPI, waitConfirm bool) (*core.ContractAPI, error) { - ret := _m.Called(ctx, httpServerURL, ns, api, waitConfirm) +// BroadcastContractAPI provides a mock function with given fields: ctx, httpServerURL, api, waitConfirm +func (_m *Manager) BroadcastContractAPI(ctx context.Context, httpServerURL string, api *core.ContractAPI, waitConfirm bool) (*core.ContractAPI, error) { + ret := _m.Called(ctx, httpServerURL, api, waitConfirm) var r0 *core.ContractAPI - if rf, ok := ret.Get(0).(func(context.Context, string, string, *core.ContractAPI, bool) *core.ContractAPI); ok { - r0 = rf(ctx, httpServerURL, ns, api, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, string, *core.ContractAPI, bool) *core.ContractAPI); ok { + r0 = rf(ctx, httpServerURL, api, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractAPI) @@ -79,8 +79,8 @@ func (_m *Manager) BroadcastContractAPI(ctx context.Context, httpServerURL strin } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, *core.ContractAPI, bool) error); ok { - r1 = rf(ctx, httpServerURL, ns, api, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, string, *core.ContractAPI, bool) error); ok { + r1 = rf(ctx, httpServerURL, api, waitConfirm) } else { r1 = ret.Error(1) } @@ -148,13 +148,13 @@ func (_m *Manager) GenerateFFI(ctx context.Context, ns string, generationRequest return r0, r1 } -// GetContractAPI provides a mock function with given fields: ctx, httpServerURL, ns, apiName -func (_m *Manager) GetContractAPI(ctx context.Context, httpServerURL string, ns string, apiName string) (*core.ContractAPI, error) { - ret := _m.Called(ctx, httpServerURL, ns, apiName) +// GetContractAPI provides a mock function with given fields: ctx, httpServerURL, apiName +func (_m *Manager) GetContractAPI(ctx context.Context, httpServerURL string, apiName string) (*core.ContractAPI, error) { + ret := _m.Called(ctx, httpServerURL, apiName) var r0 *core.ContractAPI - if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *core.ContractAPI); ok { - r0 = rf(ctx, httpServerURL, ns, apiName) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.ContractAPI); ok { + r0 = rf(ctx, httpServerURL, apiName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractAPI) @@ -162,8 +162,8 @@ func (_m *Manager) GetContractAPI(ctx context.Context, httpServerURL string, ns } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { - r1 = rf(ctx, httpServerURL, ns, apiName) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, httpServerURL, apiName) } else { r1 = ret.Error(1) } @@ -171,13 +171,13 @@ func (_m *Manager) GetContractAPI(ctx context.Context, httpServerURL string, ns return r0, r1 } -// GetContractAPIInterface provides a mock function with given fields: ctx, ns, apiName -func (_m *Manager) GetContractAPIInterface(ctx context.Context, ns string, apiName string) (*core.FFI, error) { - ret := _m.Called(ctx, ns, apiName) +// GetContractAPIInterface provides a mock function with given fields: ctx, apiName +func (_m *Manager) GetContractAPIInterface(ctx context.Context, apiName string) (*core.FFI, error) { + ret := _m.Called(ctx, apiName) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.FFI); ok { - r0 = rf(ctx, ns, apiName) + if rf, ok := ret.Get(0).(func(context.Context, string) *core.FFI); ok { + r0 = rf(ctx, apiName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -185,8 +185,8 @@ func (_m *Manager) GetContractAPIInterface(ctx context.Context, ns string, apiNa } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, ns, apiName) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, apiName) } else { r1 = ret.Error(1) } @@ -226,13 +226,13 @@ func (_m *Manager) GetContractAPIListeners(ctx context.Context, ns string, apiNa return r0, r1, r2 } -// GetContractAPIs provides a mock function with given fields: ctx, httpServerURL, ns, filter -func (_m *Manager) GetContractAPIs(ctx context.Context, httpServerURL string, ns string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) { - ret := _m.Called(ctx, httpServerURL, ns, filter) +// GetContractAPIs provides a mock function with given fields: ctx, httpServerURL, filter +func (_m *Manager) GetContractAPIs(ctx context.Context, httpServerURL string, filter database.AndFilter) ([]*core.ContractAPI, *database.FilterResult, error) { + ret := _m.Called(ctx, httpServerURL, filter) var r0 []*core.ContractAPI - if rf, ok := ret.Get(0).(func(context.Context, string, string, database.AndFilter) []*core.ContractAPI); ok { - r0 = rf(ctx, httpServerURL, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.ContractAPI); ok { + r0 = rf(ctx, httpServerURL, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.ContractAPI) @@ -240,8 +240,8 @@ func (_m *Manager) GetContractAPIs(ctx context.Context, httpServerURL string, ns } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, httpServerURL, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, httpServerURL, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -249,8 +249,8 @@ func (_m *Manager) GetContractAPIs(ctx context.Context, httpServerURL string, ns } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, string, database.AndFilter) error); ok { - r2 = rf(ctx, httpServerURL, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { + r2 = rf(ctx, httpServerURL, filter) } else { r2 = ret.Error(2) } @@ -437,13 +437,13 @@ func (_m *Manager) GetFFIs(ctx context.Context, filter database.AndFilter) ([]*c return r0, r1, r2 } -// InvokeContract provides a mock function with given fields: ctx, ns, req, waitConfirm -func (_m *Manager) InvokeContract(ctx context.Context, ns string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { - ret := _m.Called(ctx, ns, req, waitConfirm) +// InvokeContract provides a mock function with given fields: ctx, req, waitConfirm +func (_m *Manager) InvokeContract(ctx context.Context, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { + ret := _m.Called(ctx, req, waitConfirm) var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, *core.ContractCallRequest, bool) interface{}); ok { - r0 = rf(ctx, ns, req, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.ContractCallRequest, bool) interface{}); ok { + r0 = rf(ctx, req, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(interface{}) @@ -451,8 +451,8 @@ func (_m *Manager) InvokeContract(ctx context.Context, ns string, req *core.Cont } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.ContractCallRequest, bool) error); ok { - r1 = rf(ctx, ns, req, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.ContractCallRequest, bool) error); ok { + r1 = rf(ctx, req, waitConfirm) } else { r1 = ret.Error(1) } @@ -460,13 +460,13 @@ func (_m *Manager) InvokeContract(ctx context.Context, ns string, req *core.Cont return r0, r1 } -// InvokeContractAPI provides a mock function with given fields: ctx, ns, apiName, methodPath, req, waitConfirm -func (_m *Manager) InvokeContractAPI(ctx context.Context, ns string, apiName string, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { - ret := _m.Called(ctx, ns, apiName, methodPath, req, waitConfirm) +// InvokeContractAPI provides a mock function with given fields: ctx, apiName, methodPath, req, waitConfirm +func (_m *Manager) InvokeContractAPI(ctx context.Context, apiName string, methodPath string, req *core.ContractCallRequest, waitConfirm bool) (interface{}, error) { + ret := _m.Called(ctx, apiName, methodPath, req, waitConfirm) var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, *core.ContractCallRequest, bool) interface{}); ok { - r0 = rf(ctx, ns, apiName, methodPath, req, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, string, string, *core.ContractCallRequest, bool) interface{}); ok { + r0 = rf(ctx, apiName, methodPath, req, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(interface{}) @@ -474,8 +474,8 @@ func (_m *Manager) InvokeContractAPI(ctx context.Context, ns string, apiName str } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, *core.ContractCallRequest, bool) error); ok { - r1 = rf(ctx, ns, apiName, methodPath, req, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, string, string, *core.ContractCallRequest, bool) error); ok { + r1 = rf(ctx, apiName, methodPath, req, waitConfirm) } else { r1 = ret.Error(1) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index 062ec930a..cbc06e45e 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -392,13 +392,13 @@ func (_m *Plugin) GetChartHistogram(ctx context.Context, namespace string, inter return r0, r1 } -// GetContractAPIByID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetContractAPIByID(ctx context.Context, id *fftypes.UUID) (*core.ContractAPI, error) { - ret := _m.Called(ctx, id) +// GetContractAPIByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetContractAPIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.ContractAPI, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.ContractAPI - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.ContractAPI); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.ContractAPI); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractAPI) @@ -406,8 +406,8 @@ func (_m *Plugin) GetContractAPIByID(ctx context.Context, id *fftypes.UUID) (*co } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 6709eafad..789e2f63a 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -458,9 +458,16 @@ type iFFIEventCollection interface { } type iContractAPICollection interface { + // UpsertFFIEvent - Upsert a contract API UpsertContractAPI(ctx context.Context, cd *core.ContractAPI) error + + // GetContractAPIs - Get contract APIs GetContractAPIs(ctx context.Context, namespace string, filter AndFilter) ([]*core.ContractAPI, *FilterResult, error) - GetContractAPIByID(ctx context.Context, id *fftypes.UUID) (*core.ContractAPI, error) + + // GetContractAPIByID - Get a contract API by ID + GetContractAPIByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.ContractAPI, error) + + // GetContractAPIByName - Get a contract API by name GetContractAPIByName(ctx context.Context, namespace, name string) (*core.ContractAPI, error) } From 8ab22b3af6061557cca6baf219bc07e801281830 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 14:43:42 -0400 Subject: [PATCH 8/9] Add namespace to contract listener database queries Signed-off-by: Andrew Richardson --- .../route_delete_contract_listener.go | 2 +- .../route_delete_contract_listener_test.go | 2 +- .../route_get_contract_api_listeners.go | 2 +- .../route_get_contract_api_listeners_test.go | 2 +- ...ute_get_contract_listener_by_name_or_id.go | 2 +- ...et_contract_listener_by_name_or_id_test.go | 2 +- .../route_get_contract_listener_test.go | 2 +- .../apiserver/route_get_contract_listeners.go | 2 +- .../route_post_contract_api_listeners.go | 2 +- .../route_post_contract_api_listeners_test.go | 2 +- .../route_post_contract_interface_generate.go | 2 +- ...e_post_contract_interface_generate_test.go | 2 +- .../route_post_new_contract_listener.go | 2 +- .../route_post_new_contract_listener_test.go | 2 +- internal/contracts/manager.go | 72 ++++----- internal/contracts/manager_test.go | 140 +++++++++--------- .../sqlcommon/contractlisteners_sql.go | 20 +-- .../sqlcommon/contractlisteners_sql_test.go | 30 ++-- internal/events/batch_pin_complete.go | 2 +- internal/events/blockchain_event.go | 41 ++--- internal/events/blockchain_event_test.go | 51 +------ internal/events/network_action.go | 2 +- internal/events/token_pool_created.go | 2 +- internal/events/tokens_approved.go | 2 +- internal/events/tokens_transferred.go | 2 +- mocks/contractmocks/manager.go | 102 ++++++------- mocks/databasemocks/plugin.go | 56 +++---- pkg/database/plugin.go | 20 +-- 28 files changed, 252 insertions(+), 318 deletions(-) diff --git a/internal/apiserver/route_delete_contract_listener.go b/internal/apiserver/route_delete_contract_listener.go index 1befaf364..38a9a4677 100644 --- a/internal/apiserver/route_delete_contract_listener.go +++ b/internal/apiserver/route_delete_contract_listener.go @@ -37,7 +37,7 @@ var deleteContractListener = &ffapi.Route{ JSONOutputCodes: []int{http.StatusNoContent}, // Sync operation, no output Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - err = cr.or.Contracts().DeleteContractListenerByNameOrID(cr.ctx, extractNamespace(r.PP), r.PP["nameOrId"]) + err = cr.or.Contracts().DeleteContractListenerByNameOrID(cr.ctx, r.PP["nameOrId"]) return nil, err }, }, diff --git a/internal/apiserver/route_delete_contract_listener_test.go b/internal/apiserver/route_delete_contract_listener_test.go index 3872049b7..4ed6362a2 100644 --- a/internal/apiserver/route_delete_contract_listener_test.go +++ b/internal/apiserver/route_delete_contract_listener_test.go @@ -35,7 +35,7 @@ func TestDeleteContractListenerByID(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("DeleteContractListenerByNameOrID", mock.Anything, "mynamespace", id.String()). + mcm.On("DeleteContractListenerByNameOrID", mock.Anything, id.String()). Return(nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_api_listeners.go b/internal/apiserver/route_get_contract_api_listeners.go index c0283b090..818716dda 100644 --- a/internal/apiserver/route_get_contract_api_listeners.go +++ b/internal/apiserver/route_get_contract_api_listeners.go @@ -41,7 +41,7 @@ var getContractAPIListeners = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.ContractListenerQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Contracts().GetContractAPIListeners(cr.ctx, extractNamespace(r.PP), r.PP["apiName"], r.PP["eventPath"], cr.filter)) + return filterResult(cr.or.Contracts().GetContractAPIListeners(cr.ctx, r.PP["apiName"], r.PP["eventPath"], cr.filter)) }, }, } diff --git a/internal/apiserver/route_get_contract_api_listeners_test.go b/internal/apiserver/route_get_contract_api_listeners_test.go index aadf4869f..133cb0960 100644 --- a/internal/apiserver/route_get_contract_api_listeners_test.go +++ b/internal/apiserver/route_get_contract_api_listeners_test.go @@ -39,7 +39,7 @@ func TestGetContractAPIListeners(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractAPIListeners", mock.Anything, "ns1", "banana", "peeled", mock.Anything). + mcm.On("GetContractAPIListeners", mock.Anything, "banana", "peeled", mock.Anything). Return([]*core.ContractListener{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_listener_by_name_or_id.go b/internal/apiserver/route_get_contract_listener_by_name_or_id.go index 732ab09a7..1704e4cbb 100644 --- a/internal/apiserver/route_get_contract_listener_by_name_or_id.go +++ b/internal/apiserver/route_get_contract_listener_by_name_or_id.go @@ -38,7 +38,7 @@ var getContractListenerByNameOrID = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Contracts().GetContractListenerByNameOrID(cr.ctx, extractNamespace(r.PP), r.PP["nameOrId"]) + return cr.or.Contracts().GetContractListenerByNameOrID(cr.ctx, r.PP["nameOrId"]) }, }, } diff --git a/internal/apiserver/route_get_contract_listener_by_name_or_id_test.go b/internal/apiserver/route_get_contract_listener_by_name_or_id_test.go index d8a306fa8..360ab2bfd 100644 --- a/internal/apiserver/route_get_contract_listener_by_name_or_id_test.go +++ b/internal/apiserver/route_get_contract_listener_by_name_or_id_test.go @@ -36,7 +36,7 @@ func TestGetContractListenerByNameOrID(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractListenerByNameOrID", mock.Anything, "mynamespace", id.String()). + mcm.On("GetContractListenerByNameOrID", mock.Anything, id.String()). Return(&core.ContractListener{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_listener_test.go b/internal/apiserver/route_get_contract_listener_test.go index a7c448a5c..23cf97b7b 100644 --- a/internal/apiserver/route_get_contract_listener_test.go +++ b/internal/apiserver/route_get_contract_listener_test.go @@ -34,7 +34,7 @@ func TestGetContractListener(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GetContractListeners", mock.Anything, "mynamespace", mock.Anything). + mcm.On("GetContractListeners", mock.Anything, mock.Anything). Return([]*core.ContractListener{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_contract_listeners.go b/internal/apiserver/route_get_contract_listeners.go index 03d32c576..baab05240 100644 --- a/internal/apiserver/route_get_contract_listeners.go +++ b/internal/apiserver/route_get_contract_listeners.go @@ -38,7 +38,7 @@ var getContractListeners = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.ContractListenerQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return filterResult(cr.or.Contracts().GetContractListeners(cr.ctx, extractNamespace(r.PP), cr.filter)) + return filterResult(cr.or.Contracts().GetContractListeners(cr.ctx, cr.filter)) }, }, } diff --git a/internal/apiserver/route_post_contract_api_listeners.go b/internal/apiserver/route_post_contract_api_listeners.go index 212e32975..c2130fe87 100644 --- a/internal/apiserver/route_post_contract_api_listeners.go +++ b/internal/apiserver/route_post_contract_api_listeners.go @@ -39,7 +39,7 @@ var postContractAPIListeners = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Contracts().AddContractAPIListener(cr.ctx, extractNamespace(r.PP), r.PP["apiName"], r.PP["eventPath"], r.Input.(*core.ContractListener)) + return cr.or.Contracts().AddContractAPIListener(cr.ctx, r.PP["apiName"], r.PP["eventPath"], r.Input.(*core.ContractListener)) }, }, } diff --git a/internal/apiserver/route_post_contract_api_listeners_test.go b/internal/apiserver/route_post_contract_api_listeners_test.go index 723459520..ff70fcb53 100644 --- a/internal/apiserver/route_post_contract_api_listeners_test.go +++ b/internal/apiserver/route_post_contract_api_listeners_test.go @@ -39,7 +39,7 @@ func TestPostContractAPIListen(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("AddContractAPIListener", mock.Anything, "ns1", "banana", "peeled", mock.AnythingOfType("*core.ContractListener")).Return(&core.ContractListener{}, nil) + mcm.On("AddContractAPIListener", mock.Anything, "banana", "peeled", mock.AnythingOfType("*core.ContractListener")).Return(&core.ContractListener{}, nil) r.ServeHTTP(res, req) assert.Equal(t, 200, res.Result().StatusCode) diff --git a/internal/apiserver/route_post_contract_interface_generate.go b/internal/apiserver/route_post_contract_interface_generate.go index 3d3d1805a..a2577dc8f 100644 --- a/internal/apiserver/route_post_contract_interface_generate.go +++ b/internal/apiserver/route_post_contract_interface_generate.go @@ -37,7 +37,7 @@ var postContractInterfaceGenerate = &ffapi.Route{ Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { generationRequest := r.Input.(*core.FFIGenerationRequest) - return cr.or.Contracts().GenerateFFI(cr.ctx, extractNamespace(r.PP), generationRequest) + return cr.or.Contracts().GenerateFFI(cr.ctx, generationRequest) }, }, } diff --git a/internal/apiserver/route_post_contract_interface_generate_test.go b/internal/apiserver/route_post_contract_interface_generate_test.go index 64073cb60..e0329ddec 100644 --- a/internal/apiserver/route_post_contract_interface_generate_test.go +++ b/internal/apiserver/route_post_contract_interface_generate_test.go @@ -39,7 +39,7 @@ func TestPostContractInterfaceGenerate(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("GenerateFFI", mock.Anything, "ns1", mock.Anything). + mcm.On("GenerateFFI", mock.Anything, mock.Anything). Return(&core.FFI{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_new_contract_listener.go b/internal/apiserver/route_post_new_contract_listener.go index 1e1f709d0..2afcfba9c 100644 --- a/internal/apiserver/route_post_new_contract_listener.go +++ b/internal/apiserver/route_post_new_contract_listener.go @@ -36,7 +36,7 @@ var postNewContractListener = &ffapi.Route{ JSONOutputCodes: []int{http.StatusOK}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - return cr.or.Contracts().AddContractListener(cr.ctx, extractNamespace(r.PP), r.Input.(*core.ContractListenerInput)) + return cr.or.Contracts().AddContractListener(cr.ctx, r.Input.(*core.ContractListenerInput)) }, }, } diff --git a/internal/apiserver/route_post_new_contract_listener_test.go b/internal/apiserver/route_post_new_contract_listener_test.go index e95fbff2e..25e39e69a 100644 --- a/internal/apiserver/route_post_new_contract_listener_test.go +++ b/internal/apiserver/route_post_new_contract_listener_test.go @@ -39,7 +39,7 @@ func TestPostNewContractListener(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mcm.On("AddContractListener", mock.Anything, "mynamespace", mock.AnythingOfType("*core.ContractListenerInput")). + mcm.On("AddContractListener", mock.Anything, mock.AnythingOfType("*core.ContractListenerInput")). Return(&core.ContractListener{}, nil, nil) r.ServeHTTP(res, req) diff --git a/internal/contracts/manager.go b/internal/contracts/manager.go index f393285a3..3d6537f76 100644 --- a/internal/contracts/manager.go +++ b/internal/contracts/manager.go @@ -54,13 +54,13 @@ type Manager interface { ValidateFFIAndSetPathnames(ctx context.Context, ffi *core.FFI) error - AddContractListener(ctx context.Context, ns string, listener *core.ContractListenerInput) (output *core.ContractListener, err error) - AddContractAPIListener(ctx context.Context, ns, apiName, eventPath string, listener *core.ContractListener) (output *core.ContractListener, err error) - GetContractListenerByNameOrID(ctx context.Context, ns, nameOrID string) (*core.ContractListener, error) - GetContractListeners(ctx context.Context, ns string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) - GetContractAPIListeners(ctx context.Context, ns string, apiName, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) - DeleteContractListenerByNameOrID(ctx context.Context, ns, nameOrID string) error - GenerateFFI(ctx context.Context, ns string, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) + AddContractListener(ctx context.Context, listener *core.ContractListenerInput) (output *core.ContractListener, err error) + AddContractAPIListener(ctx context.Context, apiName, eventPath string, listener *core.ContractListener) (output *core.ContractListener, err error) + GetContractListenerByNameOrID(ctx context.Context, nameOrID string) (*core.ContractListener, error) + GetContractListeners(ctx context.Context, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) + GetContractAPIListeners(ctx context.Context, apiName, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) + DeleteContractListenerByNameOrID(ctx context.Context, nameOrID string) error + GenerateFFI(ctx context.Context, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) // From operations.OperationHandler PrepareOperation(ctx context.Context, op *core.Operation) (*core.PreparedOperation, error) @@ -147,10 +147,6 @@ func (cm *contractManager) BroadcastFFI(ctx context.Context, ffi *core.FFI, wait return ffi, nil } -func (cm *contractManager) scopeNS(ns string, filter database.AndFilter) database.AndFilter { - return filter.Condition(filter.Builder().Eq("namespace", ns)) -} - func (cm *contractManager) GetFFI(ctx context.Context, name, version string) (*core.FFI, error) { return cm.database.GetFFI(ctx, cm.namespace, name, version) } @@ -483,11 +479,11 @@ func (cm *contractManager) validateInvokeContractRequest(ctx context.Context, re return nil } -func (cm *contractManager) resolveEvent(ctx context.Context, ns string, ffi *core.FFIReference, eventPath string) (*core.FFISerializedEvent, error) { +func (cm *contractManager) resolveEvent(ctx context.Context, ffi *core.FFIReference, eventPath string) (*core.FFISerializedEvent, error) { if err := cm.resolveFFIReference(ctx, ffi); err != nil { return nil, err } - event, err := cm.database.GetFFIEvent(ctx, ns, ffi.ID, eventPath) + event, err := cm.database.GetFFIEvent(ctx, cm.namespace, ffi.ID, eventPath) if err != nil { return nil, err } else if event == nil { @@ -496,13 +492,10 @@ func (cm *contractManager) resolveEvent(ctx context.Context, ns string, ffi *cor return &core.FFISerializedEvent{FFIEventDefinition: event.FFIEventDefinition}, nil } -func (cm *contractManager) AddContractListener(ctx context.Context, ns string, listener *core.ContractListenerInput) (output *core.ContractListener, err error) { +func (cm *contractManager) AddContractListener(ctx context.Context, listener *core.ContractListenerInput) (output *core.ContractListener, err error) { listener.ID = fftypes.NewUUID() - listener.Namespace = ns + listener.Namespace = cm.namespace - if err := core.ValidateFFNameField(ctx, ns, "namespace"); err != nil { - return nil, err - } if listener.Name != "" { if err := core.ValidateFFNameField(ctx, listener.Name, "name"); err != nil { return nil, err @@ -524,10 +517,10 @@ func (cm *contractManager) AddContractListener(ctx context.Context, ns string, l err = cm.database.RunAsGroup(ctx, func(ctx context.Context) (err error) { // Namespace + Name must be unique if listener.Name != "" { - if existing, err := cm.database.GetContractListener(ctx, ns, listener.Name); err != nil { + if existing, err := cm.database.GetContractListener(ctx, cm.namespace, listener.Name); err != nil { return err } else if existing != nil { - return i18n.NewError(ctx, coremsgs.MsgContractListenerNameExists, ns, listener.Name) + return i18n.NewError(ctx, coremsgs.MsgContractListenerNameExists, cm.namespace, listener.Name) } } @@ -536,7 +529,7 @@ func (cm *contractManager) AddContractListener(ctx context.Context, ns string, l return i18n.NewError(ctx, coremsgs.MsgListenerNoEvent) } // Copy the event definition into the listener - if listener.Event, err = cm.resolveEvent(ctx, ns, listener.Interface, listener.EventPath); err != nil { + if listener.Event, err = cm.resolveEvent(ctx, listener.Interface, listener.EventPath); err != nil { return err } } else { @@ -546,8 +539,7 @@ func (cm *contractManager) AddContractListener(ctx context.Context, ns string, l // Namespace + Topic + Location + Signature must be unique listener.Signature = cm.blockchain.GenerateEventSignature(ctx, &listener.Event.FFIEventDefinition) fb := database.ContractListenerQueryFactory.NewFilter(ctx) - if existing, _, err := cm.database.GetContractListeners(ctx, fb.And( - fb.Eq("namespace", listener.Namespace), + if existing, _, err := cm.database.GetContractListeners(ctx, cm.namespace, fb.And( fb.Eq("topic", listener.Topic), fb.Eq("location", listener.Location.Bytes()), fb.Eq("signature", listener.Signature), @@ -578,8 +570,8 @@ func (cm *contractManager) AddContractListener(ctx context.Context, ns string, l return &listener.ContractListener, err } -func (cm *contractManager) AddContractAPIListener(ctx context.Context, ns, apiName, eventPath string, listener *core.ContractListener) (output *core.ContractListener, err error) { - api, err := cm.database.GetContractAPIByName(ctx, ns, apiName) +func (cm *contractManager) AddContractAPIListener(ctx context.Context, apiName, eventPath string, listener *core.ContractListener) (output *core.ContractListener, err error) { + api, err := cm.database.GetContractAPIByName(ctx, cm.namespace, apiName) if err != nil { return nil, err } else if api == nil || api.Interface == nil { @@ -593,19 +585,19 @@ func (cm *contractManager) AddContractAPIListener(ctx context.Context, ns, apiNa input.Location = api.Location } - return cm.AddContractListener(ctx, ns, input) + return cm.AddContractListener(ctx, input) } -func (cm *contractManager) GetContractListenerByNameOrID(ctx context.Context, ns, nameOrID string) (listener *core.ContractListener, err error) { +func (cm *contractManager) GetContractListenerByNameOrID(ctx context.Context, nameOrID string) (listener *core.ContractListener, err error) { id, err := fftypes.ParseUUID(ctx, nameOrID) if err != nil { if err := core.ValidateFFNameField(ctx, nameOrID, "name"); err != nil { return nil, err } - if listener, err = cm.database.GetContractListener(ctx, ns, nameOrID); err != nil { + if listener, err = cm.database.GetContractListener(ctx, cm.namespace, nameOrID); err != nil { return nil, err } - } else if listener, err = cm.database.GetContractListenerByID(ctx, id); err != nil { + } else if listener, err = cm.database.GetContractListenerByID(ctx, cm.namespace, id); err != nil { return nil, err } if listener == nil { @@ -614,18 +606,18 @@ func (cm *contractManager) GetContractListenerByNameOrID(ctx context.Context, ns return listener, nil } -func (cm *contractManager) GetContractListeners(ctx context.Context, ns string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { - return cm.database.GetContractListeners(ctx, cm.scopeNS(ns, filter)) +func (cm *contractManager) GetContractListeners(ctx context.Context, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { + return cm.database.GetContractListeners(ctx, cm.namespace, filter) } -func (cm *contractManager) GetContractAPIListeners(ctx context.Context, ns string, apiName, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { - api, err := cm.database.GetContractAPIByName(ctx, ns, apiName) +func (cm *contractManager) GetContractAPIListeners(ctx context.Context, apiName, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { + api, err := cm.database.GetContractAPIByName(ctx, cm.namespace, apiName) if err != nil { return nil, nil, err } else if api == nil || api.Interface == nil { return nil, nil, i18n.NewError(ctx, coremsgs.Msg404NotFound) } - event, err := cm.resolveEvent(ctx, ns, api.Interface, eventPath) + event, err := cm.resolveEvent(ctx, api.Interface, eventPath) if err != nil { return nil, nil, err } @@ -640,19 +632,19 @@ func (cm *contractManager) GetContractAPIListeners(ctx context.Context, ns strin if !api.Location.IsNil() { f = fb.And(f, fb.Eq("location", api.Location.Bytes())) } - return cm.database.GetContractListeners(ctx, cm.scopeNS(ns, f)) + return cm.database.GetContractListeners(ctx, cm.namespace, f) } -func (cm *contractManager) DeleteContractListenerByNameOrID(ctx context.Context, ns, nameOrID string) error { +func (cm *contractManager) DeleteContractListenerByNameOrID(ctx context.Context, nameOrID string) error { return cm.database.RunAsGroup(ctx, func(ctx context.Context) (err error) { - listener, err := cm.GetContractListenerByNameOrID(ctx, ns, nameOrID) + listener, err := cm.GetContractListenerByNameOrID(ctx, nameOrID) if err != nil { return err } if err = cm.blockchain.DeleteContractListener(ctx, listener); err != nil { return err } - return cm.database.DeleteContractListenerByID(ctx, listener.ID) + return cm.database.DeleteContractListenerByID(ctx, cm.namespace, listener.ID) }) } @@ -673,8 +665,8 @@ func (cm *contractManager) checkParamSchema(ctx context.Context, input interface return nil } -func (cm *contractManager) GenerateFFI(ctx context.Context, ns string, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) { - generationRequest.Namespace = ns +func (cm *contractManager) GenerateFFI(ctx context.Context, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) { + generationRequest.Namespace = cm.namespace return cm.blockchain.GenerateFFI(ctx, generationRequest) } diff --git a/internal/contracts/manager_test.go b/internal/contracts/manager_test.go index b98a03488..d23b8aefc 100644 --- a/internal/contracts/manager_test.go +++ b/internal/contracts/manager_test.go @@ -589,11 +589,11 @@ func TestAddContractListenerInline(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), sub).Return(nil) mdi.On("InsertContractListener", context.Background(), &sub.ContractListener).Return(nil) - result, err := cm.AddContractListener(context.Background(), "ns", sub) + result, err := cm.AddContractListener(context.Background(), sub) assert.NoError(t, err) assert.NotNil(t, result.ID) assert.NotNil(t, result.Event) @@ -638,13 +638,13 @@ func TestAddContractListenerByEventPath(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), sub).Return(nil) mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(event, nil) mdi.On("InsertContractListener", context.Background(), &sub.ContractListener).Return(nil) - result, err := cm.AddContractListener(context.Background(), "ns1", sub) + result, err := cm.AddContractListener(context.Background(), sub) assert.NoError(t, err) assert.NotNil(t, result.ID) assert.NotNil(t, result.Event) @@ -672,7 +672,7 @@ func TestAddContractListenerBadLocation(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(nil, fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns1", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -701,7 +701,7 @@ func TestAddContractListenerFFILookupFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(nil, fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns1", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -732,7 +732,7 @@ func TestAddContractListenerEventLookupFail(t *testing.T) { mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(nil, fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns1", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -763,7 +763,7 @@ func TestAddContractListenerEventLookupNotFound(t *testing.T) { mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, sub.EventPath).Return(nil, nil) - _, err := cm.AddContractListener(context.Background(), "ns1", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF10370", err) mbi.AssertExpectations(t) @@ -785,20 +785,12 @@ func TestAddContractListenerMissingEventOrID(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - _, err := cm.AddContractListener(context.Background(), "ns2", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF10317", err) mbi.AssertExpectations(t) } -func TestAddContractListenerBadNamespace(t *testing.T) { - cm := newTestContractManager() - sub := &core.ContractListenerInput{} - - _, err := cm.AddContractListener(context.Background(), "!bad", sub) - assert.Regexp(t, "FF00140.*'namespace'", err) -} - func TestAddContractListenerBadName(t *testing.T) { cm := newTestContractManager() sub := &core.ContractListenerInput{ @@ -807,7 +799,7 @@ func TestAddContractListenerBadName(t *testing.T) { }, } - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF00140.*'name'", err) } @@ -817,7 +809,7 @@ func TestAddContractListenerMissingTopic(t *testing.T) { ContractListener: core.ContractListener{}, } - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF00140.*'topic'", err) } @@ -838,9 +830,9 @@ func TestAddContractListenerNameConflict(t *testing.T) { } mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(&core.ContractListener{}, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(&core.ContractListener{}, nil) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF10312", err) mbi.AssertExpectations(t) @@ -864,9 +856,9 @@ func TestAddContractListenerNameError(t *testing.T) { } mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(nil, fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -890,9 +882,9 @@ func TestAddContractListenerTopicConflict(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return([]*core.ContractListener{{}}, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return([]*core.ContractListener{{}}, nil, nil) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "FF10383", err) mbi.AssertExpectations(t) @@ -916,9 +908,9 @@ func TestAddContractListenerTopicError(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, fmt.Errorf("pop")) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -952,9 +944,9 @@ func TestAddContractListenerValidateFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.Regexp(t, "does not validate", err) mbi.AssertExpectations(t) @@ -988,10 +980,10 @@ func TestAddContractListenerBlockchainFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), sub).Return(fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -1025,11 +1017,11 @@ func TestAddContractListenerUpsertSubFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), sub.Location).Return(sub.Location, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), sub).Return(nil) mdi.On("InsertContractListener", context.Background(), &sub.ContractListener).Return(fmt.Errorf("pop")) - _, err := cm.AddContractListener(context.Background(), "ns", sub) + _, err := cm.AddContractListener(context.Background(), sub) assert.EqualError(t, err, "pop") mbi.AssertExpectations(t) @@ -1059,12 +1051,12 @@ func TestAddContractAPIListener(t *testing.T) { }, } - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(api, nil) mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(listener.Location, nil) mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) - mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(event, nil) + mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, "changed").Return(event, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) mbi.On("AddContractListener", context.Background(), mock.MatchedBy(func(l *core.ContractListenerInput) bool { return *l.Interface.ID == *interfaceID && l.EventPath == "changed" && l.Topic == "test-topic" })).Return(nil) @@ -1072,7 +1064,7 @@ func TestAddContractAPIListener(t *testing.T) { return *l.Interface.ID == *interfaceID && l.Event.Name == "changed" && l.Topic == "test-topic" })).Return(nil) - _, err := cm.AddContractAPIListener(context.Background(), "ns", "simple", "changed", listener) + _, err := cm.AddContractAPIListener(context.Background(), "simple", "changed", listener) assert.NoError(t, err) mbi.AssertExpectations(t) @@ -1090,9 +1082,9 @@ func TestAddContractAPIListenerNotFound(t *testing.T) { Topic: "test-topic", } - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(nil, nil) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(nil, nil) - _, err := cm.AddContractAPIListener(context.Background(), "ns", "simple", "changed", listener) + _, err := cm.AddContractAPIListener(context.Background(), "simple", "changed", listener) assert.Regexp(t, "FF10109", err) mdi.AssertExpectations(t) @@ -1109,9 +1101,9 @@ func TestAddContractAPIListenerFail(t *testing.T) { Topic: "test-topic", } - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(nil, fmt.Errorf("pop")) - _, err := cm.AddContractAPIListener(context.Background(), "ns", "simple", "changed", listener) + _, err := cm.AddContractAPIListener(context.Background(), "simple", "changed", listener) assert.EqualError(t, err, "pop") mdi.AssertExpectations(t) @@ -1563,9 +1555,9 @@ func TestGetContractListenerByNameOrID(t *testing.T) { mdi := cm.database.(*databasemocks.Plugin) id := fftypes.NewUUID() - mdi.On("GetContractListenerByID", context.Background(), id).Return(&core.ContractListener{}, nil) + mdi.On("GetContractListenerByID", context.Background(), "ns1", id).Return(&core.ContractListener{}, nil) - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", id.String()) + _, err := cm.GetContractListenerByNameOrID(context.Background(), id.String()) assert.NoError(t, err) } @@ -1574,9 +1566,9 @@ func TestGetContractListenerByNameOrIDFail(t *testing.T) { mdi := cm.database.(*databasemocks.Plugin) id := fftypes.NewUUID() - mdi.On("GetContractListenerByID", context.Background(), id).Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractListenerByID", context.Background(), "ns1", id).Return(nil, fmt.Errorf("pop")) - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", id.String()) + _, err := cm.GetContractListenerByNameOrID(context.Background(), id.String()) assert.EqualError(t, err, "pop") } @@ -1584,16 +1576,16 @@ func TestGetContractListenerByName(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(&core.ContractListener{}, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(&core.ContractListener{}, nil) - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", "sub1") + _, err := cm.GetContractListenerByNameOrID(context.Background(), "sub1") assert.NoError(t, err) } func TestGetContractListenerBadName(t *testing.T) { cm := newTestContractManager() - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", "!bad") + _, err := cm.GetContractListenerByNameOrID(context.Background(), "!bad") assert.Regexp(t, "FF00140", err) } @@ -1601,9 +1593,9 @@ func TestGetContractListenerByNameFail(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(nil, fmt.Errorf("pop")) - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", "sub1") + _, err := cm.GetContractListenerByNameOrID(context.Background(), "sub1") assert.EqualError(t, err, "pop") } @@ -1611,9 +1603,9 @@ func TestGetContractListenerNotFound(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(nil, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(nil, nil) - _, err := cm.GetContractListenerByNameOrID(context.Background(), "ns", "sub1") + _, err := cm.GetContractListenerByNameOrID(context.Background(), "sub1") assert.Regexp(t, "FF10109", err) } @@ -1621,10 +1613,10 @@ func TestGetContractListeners(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) - _, _, err := cm.GetContractListeners(context.Background(), "ns", f.And()) + _, _, err := cm.GetContractListeners(context.Background(), f.And()) assert.NoError(t, err) } @@ -1648,14 +1640,14 @@ func TestGetContractAPIListeners(t *testing.T) { }, } - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(api, nil) mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) - mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(event, nil) + mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, "changed").Return(event, nil) mbi.On("GenerateEventSignature", context.Background(), mock.Anything).Return("changed") - mdi.On("GetContractListeners", context.Background(), mock.Anything).Return(nil, nil, nil) + mdi.On("GetContractListeners", context.Background(), "ns1", mock.Anything).Return(nil, nil, nil) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) - _, _, err := cm.GetContractAPIListeners(context.Background(), "ns", "simple", "changed", f.And()) + _, _, err := cm.GetContractAPIListeners(context.Background(), "simple", "changed", f.And()) assert.NoError(t, err) mbi.AssertExpectations(t) @@ -1666,10 +1658,10 @@ func TestGetContractAPIListenersNotFound(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(nil, nil) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(nil, nil) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) - _, _, err := cm.GetContractAPIListeners(context.Background(), "ns", "simple", "changed", f.And()) + _, _, err := cm.GetContractAPIListeners(context.Background(), "simple", "changed", f.And()) assert.Regexp(t, "FF10109", err) mdi.AssertExpectations(t) @@ -1679,10 +1671,10 @@ func TestGetContractAPIListenersFail(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(nil, fmt.Errorf("pop")) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(nil, fmt.Errorf("pop")) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) - _, _, err := cm.GetContractAPIListeners(context.Background(), "ns", "simple", "changed", f.And()) + _, _, err := cm.GetContractAPIListeners(context.Background(), "simple", "changed", f.And()) assert.EqualError(t, err, "pop") mdi.AssertExpectations(t) @@ -1702,12 +1694,12 @@ func TestGetContractAPIListenersEventNotFound(t *testing.T) { }.String()), } - mdi.On("GetContractAPIByName", context.Background(), "ns", "simple").Return(api, nil) + mdi.On("GetContractAPIByName", context.Background(), "ns1", "simple").Return(api, nil) mdi.On("GetFFIByID", context.Background(), "ns1", interfaceID).Return(&core.FFI{}, nil) - mdi.On("GetFFIEvent", context.Background(), "ns", interfaceID, "changed").Return(nil, nil) + mdi.On("GetFFIEvent", context.Background(), "ns1", interfaceID, "changed").Return(nil, nil) f := database.ContractListenerQueryFactory.NewFilter(context.Background()) - _, _, err := cm.GetContractAPIListeners(context.Background(), "ns", "simple", "changed", f.And()) + _, _, err := cm.GetContractAPIListeners(context.Background(), "simple", "changed", f.And()) assert.Regexp(t, "FF10370", err) mdi.AssertExpectations(t) @@ -1722,11 +1714,11 @@ func TestDeleteContractListener(t *testing.T) { ID: fftypes.NewUUID(), } - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(sub, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(sub, nil) mbi.On("DeleteContractListener", context.Background(), sub).Return(nil) - mdi.On("DeleteContractListenerByID", context.Background(), sub.ID).Return(nil) + mdi.On("DeleteContractListenerByID", context.Background(), "ns1", sub.ID).Return(nil) - err := cm.DeleteContractListenerByNameOrID(context.Background(), "ns", "sub1") + err := cm.DeleteContractListenerByNameOrID(context.Background(), "sub1") assert.NoError(t, err) } @@ -1739,11 +1731,11 @@ func TestDeleteContractListenerBlockchainFail(t *testing.T) { ID: fftypes.NewUUID(), } - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(sub, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(sub, nil) mbi.On("DeleteContractListener", context.Background(), sub).Return(fmt.Errorf("pop")) - mdi.On("DeleteContractListenerByID", context.Background(), sub.ID).Return(nil) + mdi.On("DeleteContractListenerByID", context.Background(), "ns1", sub.ID).Return(nil) - err := cm.DeleteContractListenerByNameOrID(context.Background(), "ns", "sub1") + err := cm.DeleteContractListenerByNameOrID(context.Background(), "sub1") assert.EqualError(t, err, "pop") } @@ -1751,9 +1743,9 @@ func TestDeleteContractListenerNotFound(t *testing.T) { cm := newTestContractManager() mdi := cm.database.(*databasemocks.Plugin) - mdi.On("GetContractListener", context.Background(), "ns", "sub1").Return(nil, nil) + mdi.On("GetContractListener", context.Background(), "ns1", "sub1").Return(nil, nil) - err := cm.DeleteContractListenerByNameOrID(context.Background(), "ns", "sub1") + err := cm.DeleteContractListenerByNameOrID(context.Background(), "sub1") assert.Regexp(t, "FF10109", err) } @@ -2347,7 +2339,7 @@ func TestGenerateFFI(t *testing.T) { mbi.On("GenerateFFI", mock.Anything, mock.Anything).Return(&core.FFI{ Name: "generated", }, nil) - ffi, err := cm.GenerateFFI(context.Background(), "default", &core.FFIGenerationRequest{}) + ffi, err := cm.GenerateFFI(context.Background(), &core.FFIGenerationRequest{}) assert.NoError(t, err) assert.NotNil(t, ffi) assert.Equal(t, "generated", ffi.Name) diff --git a/internal/database/sqlcommon/contractlisteners_sql.go b/internal/database/sqlcommon/contractlisteners_sql.go index 47e5fd994..3acbb97e5 100644 --- a/internal/database/sqlcommon/contractlisteners_sql.go +++ b/internal/database/sqlcommon/contractlisteners_sql.go @@ -138,22 +138,22 @@ func (s *SQLCommon) getContractListenerPred(ctx context.Context, desc string, pr return sub, nil } -func (s *SQLCommon) GetContractListener(ctx context.Context, ns, name string) (sub *core.ContractListener, err error) { - return s.getContractListenerPred(ctx, fmt.Sprintf("%s:%s", ns, name), sq.Eq{"namespace": ns, "name": name}) +func (s *SQLCommon) GetContractListener(ctx context.Context, namespace, name string) (sub *core.ContractListener, err error) { + return s.getContractListenerPred(ctx, fmt.Sprintf("%s:%s", namespace, name), sq.Eq{"namespace": namespace, "name": name}) } -func (s *SQLCommon) GetContractListenerByID(ctx context.Context, id *fftypes.UUID) (sub *core.ContractListener, err error) { - return s.getContractListenerPred(ctx, id.String(), sq.Eq{"id": id}) +func (s *SQLCommon) GetContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) (sub *core.ContractListener, err error) { + return s.getContractListenerPred(ctx, id.String(), sq.Eq{"id": id, "namespace": namespace}) } -func (s *SQLCommon) GetContractListenerByBackendID(ctx context.Context, id string) (sub *core.ContractListener, err error) { - return s.getContractListenerPred(ctx, id, sq.Eq{"backend_id": id}) +func (s *SQLCommon) GetContractListenerByBackendID(ctx context.Context, namespace, id string) (sub *core.ContractListener, err error) { + return s.getContractListenerPred(ctx, id, sq.Eq{"backend_id": id, "namespace": namespace}) } -func (s *SQLCommon) GetContractListeners(ctx context.Context, filter database.Filter) ([]*core.ContractListener, *database.FilterResult, error) { +func (s *SQLCommon) GetContractListeners(ctx context.Context, namespace string, filter database.Filter) ([]*core.ContractListener, *database.FilterResult, error) { query, fop, fi, err := s.filterSelect(ctx, "", sq.Select(contractListenerColumns...).From(contractlistenersTable), - filter, contractListenerFilterFieldMap, []interface{}{"sequence"}) + filter, contractListenerFilterFieldMap, []interface{}{"sequence"}, sq.Eq{"namespace": namespace}) if err != nil { return nil, nil, err } @@ -176,14 +176,14 @@ func (s *SQLCommon) GetContractListeners(ctx context.Context, filter database.Fi return subs, s.queryRes(ctx, contractlistenersTable, tx, fop, fi), err } -func (s *SQLCommon) DeleteContractListenerByID(ctx context.Context, id *fftypes.UUID) (err error) { +func (s *SQLCommon) DeleteContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) (err error) { ctx, tx, autoCommit, err := s.beginOrUseTx(ctx) if err != nil { return err } defer s.rollbackTx(ctx, tx, autoCommit) - sub, err := s.GetContractListenerByID(ctx, id) + sub, err := s.GetContractListenerByID(ctx, namespace, id) if err == nil && sub != nil { err = s.deleteTx(ctx, contractlistenersTable, tx, sq.Delete(contractlistenersTable).Where(sq.Eq{"id": id}), func() { diff --git a/internal/database/sqlcommon/contractlisteners_sql_test.go b/internal/database/sqlcommon/contractlisteners_sql_test.go index fb4c61d8a..06dfddb95 100644 --- a/internal/database/sqlcommon/contractlisteners_sql_test.go +++ b/internal/database/sqlcommon/contractlisteners_sql_test.go @@ -71,7 +71,7 @@ func TestContractListenerE2EWithDB(t *testing.T) { filter := fb.And( fb.Eq("backendid", sub.BackendID), ) - subs, res, err := s.GetContractListeners(ctx, filter.Count(true)) + subs, res, err := s.GetContractListeners(ctx, "ns", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(subs)) assert.Equal(t, int64(1), *res.TotalCount) @@ -85,13 +85,13 @@ func TestContractListenerE2EWithDB(t *testing.T) { assert.Equal(t, string(subJson), string(subReadJson)) // Query back the listener (by ID) - subRead, err = s.GetContractListenerByID(ctx, sub.ID) + subRead, err = s.GetContractListenerByID(ctx, "ns", sub.ID) assert.NoError(t, err) subReadJson, _ = json.Marshal(subRead) assert.Equal(t, string(subJson), string(subReadJson)) // Query back the listener (by protocol ID) - subRead, err = s.GetContractListenerByBackendID(ctx, sub.BackendID) + subRead, err = s.GetContractListenerByBackendID(ctx, "ns", sub.BackendID) assert.NoError(t, err) subReadJson, _ = json.Marshal(subRead) assert.Equal(t, string(subJson), string(subReadJson)) @@ -100,7 +100,7 @@ func TestContractListenerE2EWithDB(t *testing.T) { filter = fb.And( fb.Eq("backendid", sub.BackendID), ) - subs, res, err = s.GetContractListeners(ctx, filter.Count(true)) + subs, res, err = s.GetContractListeners(ctx, "ns", filter.Count(true)) assert.NoError(t, err) assert.Equal(t, 1, len(subs)) assert.Equal(t, int64(1), *res.TotalCount) @@ -108,9 +108,9 @@ func TestContractListenerE2EWithDB(t *testing.T) { assert.Equal(t, string(subJson), string(subReadJson)) // Test delete, and refind no return - err = s.DeleteContractListenerByID(ctx, sub.ID) + err = s.DeleteContractListenerByID(ctx, "ns", sub.ID) assert.NoError(t, err) - subs, _, err = s.GetContractListeners(ctx, filter) + subs, _, err = s.GetContractListeners(ctx, "ns", filter) assert.NoError(t, err) assert.Equal(t, 0, len(subs)) } @@ -146,7 +146,7 @@ func TestUpsertContractListenerFailCommit(t *testing.T) { func TestGetContractListenerByIDSelectFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) - _, err := s.GetContractListenerByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetContractListenerByID(context.Background(), "ns", fftypes.NewUUID()) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -154,7 +154,7 @@ func TestGetContractListenerByIDSelectFail(t *testing.T) { func TestGetContractListenerByIDNotFound(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"backendid"})) - msg, err := s.GetContractListenerByID(context.Background(), fftypes.NewUUID()) + msg, err := s.GetContractListenerByID(context.Background(), "ns", fftypes.NewUUID()) assert.NoError(t, err) assert.Nil(t, msg) assert.NoError(t, mock.ExpectationsWereMet()) @@ -163,7 +163,7 @@ func TestGetContractListenerByIDNotFound(t *testing.T) { func TestGetContractListenerByIDScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"backendid"}).AddRow("only one")) - _, err := s.GetContractListenerByID(context.Background(), fftypes.NewUUID()) + _, err := s.GetContractListenerByID(context.Background(), "ns", fftypes.NewUUID()) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -172,7 +172,7 @@ func TestGetContractListenersQueryFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnError(fmt.Errorf("pop")) f := database.ContractListenerQueryFactory.NewFilter(context.Background()).Eq("backendid", "") - _, _, err := s.GetContractListeners(context.Background(), f) + _, _, err := s.GetContractListeners(context.Background(), "ns", f) assert.Regexp(t, "FF10115", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -180,7 +180,7 @@ func TestGetContractListenersQueryFail(t *testing.T) { func TestGetContractListenersBuildQueryFail(t *testing.T) { s, _ := newMockProvider().init() f := database.ContractListenerQueryFactory.NewFilter(context.Background()).Eq("backendid", map[bool]bool{true: false}) - _, _, err := s.GetContractListeners(context.Background(), f) + _, _, err := s.GetContractListeners(context.Background(), "ns", f) assert.Regexp(t, "FF00143.*id", err) } @@ -188,7 +188,7 @@ func TestGetContractListenersScanFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{"backendid"}).AddRow("only one")) f := database.ContractListenerQueryFactory.NewFilter(context.Background()).Eq("backendid", "") - _, _, err := s.GetContractListeners(context.Background(), f) + _, _, err := s.GetContractListeners(context.Background(), "ns", f) assert.Regexp(t, "FF10121", err) assert.NoError(t, mock.ExpectationsWereMet()) } @@ -196,7 +196,7 @@ func TestGetContractListenersScanFail(t *testing.T) { func TestContractListenerDeleteBeginFail(t *testing.T) { s, mock := newMockProvider().init() mock.ExpectBegin().WillReturnError(fmt.Errorf("pop")) - err := s.DeleteContractListenerByID(context.Background(), fftypes.NewUUID()) + err := s.DeleteContractListenerByID(context.Background(), "ns", fftypes.NewUUID()) assert.Regexp(t, "FF10114", err) } @@ -207,7 +207,7 @@ func TestContractListenerDeleteFail(t *testing.T) { fftypes.NewUUID(), nil, []byte("{}"), "ns1", "sub1", "123", "{}", "sig", "topic1", nil, fftypes.Now()), ) mock.ExpectExec("DELETE .*").WillReturnError(fmt.Errorf("pop")) - err := s.DeleteContractListenerByID(context.Background(), fftypes.NewUUID()) + err := s.DeleteContractListenerByID(context.Background(), "ns", fftypes.NewUUID()) assert.Regexp(t, "FF10118", err) } @@ -231,7 +231,7 @@ func TestContractListenerOptions(t *testing.T) { err := s.InsertContractListener(ctx, l) assert.NoError(t, err) - li, err := s.GetContractListenerByID(ctx, l.ID) + li, err := s.GetContractListenerByID(ctx, "ns", l.ID) assert.NoError(t, err) assert.Equal(t, l.Options, li.Options) diff --git a/internal/events/batch_pin_complete.go b/internal/events/batch_pin_complete.go index 5c2eeb856..8ff025762 100644 --- a/internal/events/batch_pin_complete.go +++ b/internal/events/batch_pin_complete.go @@ -65,7 +65,7 @@ func (em *eventManager) BatchPinComplete(batchPin *blockchain.BatchPin, signingK ID: batchPin.TransactionID, BlockchainID: batchPin.Event.BlockchainTXID, }) - if err := em.maybePersistBlockchainEvent(ctx, chainEvent); err != nil { + if err := em.maybePersistBlockchainEvent(ctx, chainEvent, nil); err != nil { return err } em.emitBlockchainEventMetric(&batchPin.Event) diff --git a/internal/events/blockchain_event.go b/internal/events/blockchain_event.go index a944fa807..aac2914ce 100644 --- a/internal/events/blockchain_event.go +++ b/internal/events/blockchain_event.go @@ -46,13 +46,7 @@ func buildBlockchainEvent(ns string, subID *fftypes.UUID, event *blockchain.Even func (em *eventManager) getChainListenerByProtocolIDCached(ctx context.Context, protocolID string) (*core.ContractListener, error) { return em.getChainListenerCached(fmt.Sprintf("pid:%s", protocolID), func() (*core.ContractListener, error) { - return em.database.GetContractListenerByBackendID(ctx, protocolID) - }) -} - -func (em *eventManager) getChainListenerByIDCached(ctx context.Context, id *fftypes.UUID) (*core.ContractListener, error) { - return em.getChainListenerCached(fmt.Sprintf("id:%s", id), func() (*core.ContractListener, error) { - return em.database.GetContractListenerByID(ctx, id) + return em.database.GetContractListenerByBackendID(ctx, em.namespace, protocolID) }) } @@ -70,24 +64,20 @@ func (em *eventManager) getChainListenerCached(cacheKey string, getter func() (* return listener, err } -func (em *eventManager) getTopicForChainListener(ctx context.Context, listenerID *fftypes.UUID) (string, error) { - if listenerID == nil { - return core.SystemBatchPinTopic, nil - } - listener, err := em.getChainListenerByIDCached(ctx, listenerID) - if err != nil { - return "", err +func (em *eventManager) getTopicForChainListener(listener *core.ContractListener) string { + if listener == nil { + return core.SystemBatchPinTopic } var topic string if listener != nil && listener.Topic != "" { topic = listener.Topic } else { - topic = listenerID.String() + topic = listener.ID.String() } - return topic, nil + return topic } -func (em *eventManager) maybePersistBlockchainEvent(ctx context.Context, chainEvent *core.BlockchainEvent) error { +func (em *eventManager) maybePersistBlockchainEvent(ctx context.Context, chainEvent *core.BlockchainEvent, listener *core.ContractListener) error { if existing, err := em.database.GetBlockchainEventByProtocolID(ctx, chainEvent.Namespace, chainEvent.Listener, chainEvent.ProtocolID); err != nil { return err } else if existing != nil { @@ -99,10 +89,7 @@ func (em *eventManager) maybePersistBlockchainEvent(ctx context.Context, chainEv if err := em.txHelper.InsertBlockchainEvent(ctx, chainEvent); err != nil { return err } - topic, err := em.getTopicForChainListener(ctx, chainEvent.Listener) - if err != nil { - return err - } + topic := em.getTopicForChainListener(listener) ffEvent := core.NewEvent(core.EventTypeBlockchainEventReceived, chainEvent.Namespace, chainEvent.ID, chainEvent.TX.ID, topic) if err := em.database.InsertEvent(ctx, ffEvent); err != nil { return err @@ -119,23 +106,23 @@ func (em *eventManager) emitBlockchainEventMetric(event *blockchain.Event) { func (em *eventManager) BlockchainEvent(event *blockchain.EventWithSubscription) error { return em.retry.Do(em.ctx, "persist blockchain event", func(attempt int) (bool, error) { err := em.database.RunAsGroup(em.ctx, func(ctx context.Context) error { - sub, err := em.getChainListenerByProtocolIDCached(ctx, event.Subscription) + listener, err := em.getChainListenerByProtocolIDCached(ctx, event.Subscription) if err != nil { return err } - if sub == nil { + if listener == nil { log.L(ctx).Warnf("Event received from unknown subscription %s", event.Subscription) return nil // no retry } - if sub.Namespace != em.namespace { - log.L(em.ctx).Debugf("Ignoring blockchain event from different namespace '%s'", sub.Namespace) + if listener.Namespace != em.namespace { + log.L(em.ctx).Debugf("Ignoring blockchain event from different namespace '%s'", listener.Namespace) return nil } - chainEvent := buildBlockchainEvent(sub.Namespace, sub.ID, &event.Event, &core.BlockchainTransactionRef{ + chainEvent := buildBlockchainEvent(listener.Namespace, listener.ID, &event.Event, &core.BlockchainTransactionRef{ BlockchainID: event.BlockchainTXID, }) - if err := em.maybePersistBlockchainEvent(ctx, chainEvent); err != nil { + if err := em.maybePersistBlockchainEvent(ctx, chainEvent, listener); err != nil { return err } em.emitBlockchainEventMetric(&event.Event) diff --git a/internal/events/blockchain_event_test.go b/internal/events/blockchain_event_test.go index c48f13c62..a9fe04176 100644 --- a/internal/events/blockchain_event_test.go +++ b/internal/events/blockchain_event_test.go @@ -56,8 +56,8 @@ func TestContractEventWithRetries(t *testing.T) { var eventID *fftypes.UUID mdi := em.database.(*databasemocks.Plugin) - mdi.On("GetContractListenerByBackendID", mock.Anything, "sb-1").Return(nil, fmt.Errorf("pop")).Once() - mdi.On("GetContractListenerByBackendID", mock.Anything, "sb-1").Return(sub, nil).Times(1) // cached + mdi.On("GetContractListenerByBackendID", mock.Anything, "ns1", "sb-1").Return(nil, fmt.Errorf("pop")).Once() + mdi.On("GetContractListenerByBackendID", mock.Anything, "ns1", "sb-1").Return(sub, nil).Times(1) // cached mth := em.txHelper.(*txcommonmocks.Helper) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", sub.ID, ev.ProtocolID).Return(nil, nil) mth.On("InsertBlockchainEvent", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")).Once() @@ -65,7 +65,6 @@ func TestContractEventWithRetries(t *testing.T) { eventID = e.ID return *e.Listener == *sub.ID && e.Name == "Changed" && e.Namespace == "ns1" })).Return(nil).Times(2) - mdi.On("GetContractListenerByID", mock.Anything, sub.ID).Return(sub, nil) mdi.On("InsertEvent", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")).Once() mdi.On("InsertEvent", mock.Anything, mock.MatchedBy(func(e *core.Event) bool { return e.Type == core.EventTypeBlockchainEventReceived && e.Reference != nil && e.Reference == eventID && e.Topic == "topic1" @@ -97,7 +96,7 @@ func TestContractEventUnknownSubscription(t *testing.T) { } mdi := em.database.(*databasemocks.Plugin) - mdi.On("GetContractListenerByBackendID", mock.Anything, "sb-1").Return(nil, nil) + mdi.On("GetContractListenerByBackendID", mock.Anything, "ns1", "sb-1").Return(nil, nil) err := em.BlockchainEvent(ev) assert.NoError(t, err) @@ -129,7 +128,7 @@ func TestContractEventWrongNS(t *testing.T) { } mdi := em.database.(*databasemocks.Plugin) - mdi.On("GetContractListenerByBackendID", mock.Anything, "sb-1").Return(sub, nil) + mdi.On("GetContractListenerByBackendID", mock.Anything, "ns1", "sb-1").Return(sub, nil) err := em.BlockchainEvent(ev) assert.NoError(t, err) @@ -157,7 +156,7 @@ func TestPersistBlockchainEventDuplicate(t *testing.T) { mdi := em.database.(*databasemocks.Plugin) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", ev.Listener, ev.ProtocolID).Return(&core.BlockchainEvent{}, nil) - err := em.maybePersistBlockchainEvent(em.ctx, ev) + err := em.maybePersistBlockchainEvent(em.ctx, ev, nil) assert.NoError(t, err) mdi.AssertExpectations(t) @@ -183,42 +182,12 @@ func TestPersistBlockchainEventLookupFail(t *testing.T) { mdi := em.database.(*databasemocks.Plugin) mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", ev.Listener, ev.ProtocolID).Return(nil, fmt.Errorf("pop")) - err := em.maybePersistBlockchainEvent(em.ctx, ev) + err := em.maybePersistBlockchainEvent(em.ctx, ev, nil) assert.EqualError(t, err, "pop") mdi.AssertExpectations(t) } -func TestPersistBlockchainEventChainListenerLookupFail(t *testing.T) { - em, cancel := newTestEventManager(t) - defer cancel() - - ev := &core.BlockchainEvent{ - Name: "Changed", - Namespace: "ns1", - ProtocolID: "10/20/30", - Output: fftypes.JSONObject{ - "value": "1", - }, - Info: fftypes.JSONObject{ - "blockNumber": "10", - }, - Listener: fftypes.NewUUID(), - } - - mdi := em.database.(*databasemocks.Plugin) - mth := em.txHelper.(*txcommonmocks.Helper) - mdi.On("GetBlockchainEventByProtocolID", mock.Anything, "ns1", ev.Listener, ev.ProtocolID).Return(nil, nil) - mth.On("InsertBlockchainEvent", mock.Anything, mock.Anything).Return(nil) - mdi.On("GetContractListenerByID", mock.Anything, ev.Listener).Return(nil, fmt.Errorf("pop")) - - err := em.maybePersistBlockchainEvent(em.ctx, ev) - assert.Regexp(t, "pop", err) - - mdi.AssertExpectations(t) - mth.AssertExpectations(t) -} - func TestGetTopicForChainListenerFallback(t *testing.T) { em, cancel := newTestEventManager(t) defer cancel() @@ -229,14 +198,8 @@ func TestGetTopicForChainListenerFallback(t *testing.T) { Topic: "", } - mdi := em.database.(*databasemocks.Plugin) - mdi.On("GetContractListenerByID", mock.Anything, mock.Anything).Return(sub, nil) - - topic, err := em.getTopicForChainListener(em.ctx, sub.ID) - assert.NoError(t, err) + topic := em.getTopicForChainListener(sub) assert.Equal(t, sub.ID.String(), topic) - - mdi.AssertExpectations(t) } func TestBlockchainEventMetric(t *testing.T) { diff --git a/internal/events/network_action.go b/internal/events/network_action.go index e3006c5ed..8337a48c8 100644 --- a/internal/events/network_action.go +++ b/internal/events/network_action.go @@ -69,7 +69,7 @@ func (em *eventManager) BlockchainNetworkAction(action string, event *blockchain chainEvent := buildBlockchainEvent(core.LegacySystemNamespace, nil, event, &core.BlockchainTransactionRef{ BlockchainID: event.BlockchainTXID, }) - err = em.maybePersistBlockchainEvent(em.ctx, chainEvent) + err = em.maybePersistBlockchainEvent(em.ctx, chainEvent, nil) } return true, err }) diff --git a/internal/events/token_pool_created.go b/internal/events/token_pool_created.go index 911127d3d..2ba21c771 100644 --- a/internal/events/token_pool_created.go +++ b/internal/events/token_pool_created.go @@ -57,7 +57,7 @@ func (em *eventManager) confirmPool(ctx context.Context, pool *core.TokenPool, e Type: pool.TX.Type, BlockchainID: ev.BlockchainTXID, }) - if err := em.maybePersistBlockchainEvent(ctx, chainEvent); err != nil { + if err := em.maybePersistBlockchainEvent(ctx, chainEvent, nil); err != nil { return err } em.emitBlockchainEventMetric(ev) diff --git a/internal/events/tokens_approved.go b/internal/events/tokens_approved.go index 059d91c31..185eac364 100644 --- a/internal/events/tokens_approved.go +++ b/internal/events/tokens_approved.go @@ -109,7 +109,7 @@ func (em *eventManager) persistTokenApproval(ctx context.Context, approval *toke Type: approval.TX.Type, BlockchainID: approval.Event.BlockchainTXID, }) - if err := em.maybePersistBlockchainEvent(ctx, chainEvent); err != nil { + if err := em.maybePersistBlockchainEvent(ctx, chainEvent, nil); err != nil { return false, err } em.emitBlockchainEventMetric(&approval.Event) diff --git a/internal/events/tokens_transferred.go b/internal/events/tokens_transferred.go index 36bc33e8e..070d05d29 100644 --- a/internal/events/tokens_transferred.go +++ b/internal/events/tokens_transferred.go @@ -109,7 +109,7 @@ func (em *eventManager) persistTokenTransfer(ctx context.Context, transfer *toke Type: transfer.TX.Type, BlockchainID: transfer.Event.BlockchainTXID, }) - if err := em.maybePersistBlockchainEvent(ctx, chainEvent); err != nil { + if err := em.maybePersistBlockchainEvent(ctx, chainEvent, nil); err != nil { return false, err } em.emitBlockchainEventMetric(&transfer.Event) diff --git a/mocks/contractmocks/manager.go b/mocks/contractmocks/manager.go index d3bb9803f..5a65fedde 100644 --- a/mocks/contractmocks/manager.go +++ b/mocks/contractmocks/manager.go @@ -19,13 +19,13 @@ type Manager struct { mock.Mock } -// AddContractAPIListener provides a mock function with given fields: ctx, ns, apiName, eventPath, listener -func (_m *Manager) AddContractAPIListener(ctx context.Context, ns string, apiName string, eventPath string, listener *core.ContractListener) (*core.ContractListener, error) { - ret := _m.Called(ctx, ns, apiName, eventPath, listener) +// AddContractAPIListener provides a mock function with given fields: ctx, apiName, eventPath, listener +func (_m *Manager) AddContractAPIListener(ctx context.Context, apiName string, eventPath string, listener *core.ContractListener) (*core.ContractListener, error) { + ret := _m.Called(ctx, apiName, eventPath, listener) var r0 *core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, *core.ContractListener) *core.ContractListener); ok { - r0 = rf(ctx, ns, apiName, eventPath, listener) + if rf, ok := ret.Get(0).(func(context.Context, string, string, *core.ContractListener) *core.ContractListener); ok { + r0 = rf(ctx, apiName, eventPath, listener) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractListener) @@ -33,8 +33,8 @@ func (_m *Manager) AddContractAPIListener(ctx context.Context, ns string, apiNam } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, *core.ContractListener) error); ok { - r1 = rf(ctx, ns, apiName, eventPath, listener) + if rf, ok := ret.Get(1).(func(context.Context, string, string, *core.ContractListener) error); ok { + r1 = rf(ctx, apiName, eventPath, listener) } else { r1 = ret.Error(1) } @@ -42,13 +42,13 @@ func (_m *Manager) AddContractAPIListener(ctx context.Context, ns string, apiNam return r0, r1 } -// AddContractListener provides a mock function with given fields: ctx, ns, listener -func (_m *Manager) AddContractListener(ctx context.Context, ns string, listener *core.ContractListenerInput) (*core.ContractListener, error) { - ret := _m.Called(ctx, ns, listener) +// AddContractListener provides a mock function with given fields: ctx, listener +func (_m *Manager) AddContractListener(ctx context.Context, listener *core.ContractListenerInput) (*core.ContractListener, error) { + ret := _m.Called(ctx, listener) var r0 *core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string, *core.ContractListenerInput) *core.ContractListener); ok { - r0 = rf(ctx, ns, listener) + if rf, ok := ret.Get(0).(func(context.Context, *core.ContractListenerInput) *core.ContractListener); ok { + r0 = rf(ctx, listener) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractListener) @@ -56,8 +56,8 @@ func (_m *Manager) AddContractListener(ctx context.Context, ns string, listener } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.ContractListenerInput) error); ok { - r1 = rf(ctx, ns, listener) + if rf, ok := ret.Get(1).(func(context.Context, *core.ContractListenerInput) error); ok { + r1 = rf(ctx, listener) } else { r1 = ret.Error(1) } @@ -111,13 +111,13 @@ func (_m *Manager) BroadcastFFI(ctx context.Context, ffi *core.FFI, waitConfirm return r0, r1 } -// DeleteContractListenerByNameOrID provides a mock function with given fields: ctx, ns, nameOrID -func (_m *Manager) DeleteContractListenerByNameOrID(ctx context.Context, ns string, nameOrID string) error { - ret := _m.Called(ctx, ns, nameOrID) +// DeleteContractListenerByNameOrID provides a mock function with given fields: ctx, nameOrID +func (_m *Manager) DeleteContractListenerByNameOrID(ctx context.Context, nameOrID string) error { + ret := _m.Called(ctx, nameOrID) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, ns, nameOrID) + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, nameOrID) } else { r0 = ret.Error(0) } @@ -125,13 +125,13 @@ func (_m *Manager) DeleteContractListenerByNameOrID(ctx context.Context, ns stri return r0 } -// GenerateFFI provides a mock function with given fields: ctx, ns, generationRequest -func (_m *Manager) GenerateFFI(ctx context.Context, ns string, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) { - ret := _m.Called(ctx, ns, generationRequest) +// GenerateFFI provides a mock function with given fields: ctx, generationRequest +func (_m *Manager) GenerateFFI(ctx context.Context, generationRequest *core.FFIGenerationRequest) (*core.FFI, error) { + ret := _m.Called(ctx, generationRequest) var r0 *core.FFI - if rf, ok := ret.Get(0).(func(context.Context, string, *core.FFIGenerationRequest) *core.FFI); ok { - r0 = rf(ctx, ns, generationRequest) + if rf, ok := ret.Get(0).(func(context.Context, *core.FFIGenerationRequest) *core.FFI); ok { + r0 = rf(ctx, generationRequest) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.FFI) @@ -139,8 +139,8 @@ func (_m *Manager) GenerateFFI(ctx context.Context, ns string, generationRequest } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.FFIGenerationRequest) error); ok { - r1 = rf(ctx, ns, generationRequest) + if rf, ok := ret.Get(1).(func(context.Context, *core.FFIGenerationRequest) error); ok { + r1 = rf(ctx, generationRequest) } else { r1 = ret.Error(1) } @@ -194,13 +194,13 @@ func (_m *Manager) GetContractAPIInterface(ctx context.Context, apiName string) return r0, r1 } -// GetContractAPIListeners provides a mock function with given fields: ctx, ns, apiName, eventPath, filter -func (_m *Manager) GetContractAPIListeners(ctx context.Context, ns string, apiName string, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, apiName, eventPath, filter) +// GetContractAPIListeners provides a mock function with given fields: ctx, apiName, eventPath, filter +func (_m *Manager) GetContractAPIListeners(ctx context.Context, apiName string, eventPath string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { + ret := _m.Called(ctx, apiName, eventPath, filter) var r0 []*core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string, string, string, database.AndFilter) []*core.ContractListener); ok { - r0 = rf(ctx, ns, apiName, eventPath, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, string, database.AndFilter) []*core.ContractListener); ok { + r0 = rf(ctx, apiName, eventPath, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.ContractListener) @@ -208,8 +208,8 @@ func (_m *Manager) GetContractAPIListeners(ctx context.Context, ns string, apiNa } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, string, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, apiName, eventPath, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, string, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, apiName, eventPath, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -217,8 +217,8 @@ func (_m *Manager) GetContractAPIListeners(ctx context.Context, ns string, apiNa } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, string, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, apiName, eventPath, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, string, database.AndFilter) error); ok { + r2 = rf(ctx, apiName, eventPath, filter) } else { r2 = ret.Error(2) } @@ -258,13 +258,13 @@ func (_m *Manager) GetContractAPIs(ctx context.Context, httpServerURL string, fi return r0, r1, r2 } -// GetContractListenerByNameOrID provides a mock function with given fields: ctx, ns, nameOrID -func (_m *Manager) GetContractListenerByNameOrID(ctx context.Context, ns string, nameOrID string) (*core.ContractListener, error) { - ret := _m.Called(ctx, ns, nameOrID) +// GetContractListenerByNameOrID provides a mock function with given fields: ctx, nameOrID +func (_m *Manager) GetContractListenerByNameOrID(ctx context.Context, nameOrID string) (*core.ContractListener, error) { + ret := _m.Called(ctx, nameOrID) var r0 *core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.ContractListener); ok { - r0 = rf(ctx, ns, nameOrID) + if rf, ok := ret.Get(0).(func(context.Context, string) *core.ContractListener); ok { + r0 = rf(ctx, nameOrID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractListener) @@ -272,8 +272,8 @@ func (_m *Manager) GetContractListenerByNameOrID(ctx context.Context, ns string, } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, ns, nameOrID) + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, nameOrID) } else { r1 = ret.Error(1) } @@ -281,13 +281,13 @@ func (_m *Manager) GetContractListenerByNameOrID(ctx context.Context, ns string, return r0, r1 } -// GetContractListeners provides a mock function with given fields: ctx, ns, filter -func (_m *Manager) GetContractListeners(ctx context.Context, ns string, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { - ret := _m.Called(ctx, ns, filter) +// GetContractListeners provides a mock function with given fields: ctx, filter +func (_m *Manager) GetContractListeners(ctx context.Context, filter database.AndFilter) ([]*core.ContractListener, *database.FilterResult, error) { + ret := _m.Called(ctx, filter) var r0 []*core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string, database.AndFilter) []*core.ContractListener); ok { - r0 = rf(ctx, ns, filter) + if rf, ok := ret.Get(0).(func(context.Context, database.AndFilter) []*core.ContractListener); ok { + r0 = rf(ctx, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.ContractListener) @@ -295,8 +295,8 @@ func (_m *Manager) GetContractListeners(ctx context.Context, ns string, filter d } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, string, database.AndFilter) *database.FilterResult); ok { - r1 = rf(ctx, ns, filter) + if rf, ok := ret.Get(1).(func(context.Context, database.AndFilter) *database.FilterResult); ok { + r1 = rf(ctx, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -304,8 +304,8 @@ func (_m *Manager) GetContractListeners(ctx context.Context, ns string, filter d } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, database.AndFilter) error); ok { - r2 = rf(ctx, ns, filter) + if rf, ok := ret.Get(2).(func(context.Context, database.AndFilter) error); ok { + r2 = rf(ctx, filter) } else { r2 = ret.Error(2) } diff --git a/mocks/databasemocks/plugin.go b/mocks/databasemocks/plugin.go index cbc06e45e..1d81a0d04 100644 --- a/mocks/databasemocks/plugin.go +++ b/mocks/databasemocks/plugin.go @@ -51,13 +51,13 @@ func (_m *Plugin) DeleteBlob(ctx context.Context, sequence int64) error { return r0 } -// DeleteContractListenerByID provides a mock function with given fields: ctx, id -func (_m *Plugin) DeleteContractListenerByID(ctx context.Context, id *fftypes.UUID) error { - ret := _m.Called(ctx, id) +// DeleteContractListenerByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) DeleteContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) error { + ret := _m.Called(ctx, namespace, id) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) error); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) error); ok { + r0 = rf(ctx, namespace, id) } else { r0 = ret.Error(0) } @@ -493,13 +493,13 @@ func (_m *Plugin) GetContractListener(ctx context.Context, namespace string, nam return r0, r1 } -// GetContractListenerByBackendID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetContractListenerByBackendID(ctx context.Context, id string) (*core.ContractListener, error) { - ret := _m.Called(ctx, id) +// GetContractListenerByBackendID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetContractListenerByBackendID(ctx context.Context, namespace string, id string) (*core.ContractListener, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, string) *core.ContractListener); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.ContractListener); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractListener) @@ -507,8 +507,8 @@ func (_m *Plugin) GetContractListenerByBackendID(ctx context.Context, id string) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } @@ -516,13 +516,13 @@ func (_m *Plugin) GetContractListenerByBackendID(ctx context.Context, id string) return r0, r1 } -// GetContractListenerByID provides a mock function with given fields: ctx, id -func (_m *Plugin) GetContractListenerByID(ctx context.Context, id *fftypes.UUID) (*core.ContractListener, error) { - ret := _m.Called(ctx, id) +// GetContractListenerByID provides a mock function with given fields: ctx, namespace, id +func (_m *Plugin) GetContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) (*core.ContractListener, error) { + ret := _m.Called(ctx, namespace, id) var r0 *core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.ContractListener); ok { - r0 = rf(ctx, id) + if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.ContractListener); ok { + r0 = rf(ctx, namespace, id) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.ContractListener) @@ -530,8 +530,8 @@ func (_m *Plugin) GetContractListenerByID(ctx context.Context, id *fftypes.UUID) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { - r1 = rf(ctx, id) + if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { + r1 = rf(ctx, namespace, id) } else { r1 = ret.Error(1) } @@ -539,13 +539,13 @@ func (_m *Plugin) GetContractListenerByID(ctx context.Context, id *fftypes.UUID) return r0, r1 } -// GetContractListeners provides a mock function with given fields: ctx, filter -func (_m *Plugin) GetContractListeners(ctx context.Context, filter database.Filter) ([]*core.ContractListener, *database.FilterResult, error) { - ret := _m.Called(ctx, filter) +// GetContractListeners provides a mock function with given fields: ctx, namespace, filter +func (_m *Plugin) GetContractListeners(ctx context.Context, namespace string, filter database.Filter) ([]*core.ContractListener, *database.FilterResult, error) { + ret := _m.Called(ctx, namespace, filter) var r0 []*core.ContractListener - if rf, ok := ret.Get(0).(func(context.Context, database.Filter) []*core.ContractListener); ok { - r0 = rf(ctx, filter) + if rf, ok := ret.Get(0).(func(context.Context, string, database.Filter) []*core.ContractListener); ok { + r0 = rf(ctx, namespace, filter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.ContractListener) @@ -553,8 +553,8 @@ func (_m *Plugin) GetContractListeners(ctx context.Context, filter database.Filt } var r1 *database.FilterResult - if rf, ok := ret.Get(1).(func(context.Context, database.Filter) *database.FilterResult); ok { - r1 = rf(ctx, filter) + if rf, ok := ret.Get(1).(func(context.Context, string, database.Filter) *database.FilterResult); ok { + r1 = rf(ctx, namespace, filter) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*database.FilterResult) @@ -562,8 +562,8 @@ func (_m *Plugin) GetContractListeners(ctx context.Context, filter database.Filt } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, database.Filter) error); ok { - r2 = rf(ctx, filter) + if rf, ok := ret.Get(2).(func(context.Context, string, database.Filter) error); ok { + r2 = rf(ctx, namespace, filter) } else { r2 = ret.Error(2) } diff --git a/pkg/database/plugin.go b/pkg/database/plugin.go index 789e2f63a..61e5be8b3 100644 --- a/pkg/database/plugin.go +++ b/pkg/database/plugin.go @@ -472,23 +472,23 @@ type iContractAPICollection interface { } type iContractListenerCollection interface { - // InsertContractListener - upsert a subscription to an external smart contract + // InsertContractListener - upsert a listener to an external smart contract InsertContractListener(ctx context.Context, sub *core.ContractListener) (err error) - // GetContractListener - get smart contract subscription by name + // GetContractListener - get contract listener by name GetContractListener(ctx context.Context, namespace, name string) (sub *core.ContractListener, err error) - // GetContractListenerByID - get smart contract subscription by ID - GetContractListenerByID(ctx context.Context, id *fftypes.UUID) (sub *core.ContractListener, err error) + // GetContractListenerByID - get contract listener by ID + GetContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) (sub *core.ContractListener, err error) - // GetContractListenerByBackendID - get smart contract subscription by backend ID - GetContractListenerByBackendID(ctx context.Context, id string) (sub *core.ContractListener, err error) + // GetContractListenerByBackendID - get contract listener by backend ID + GetContractListenerByBackendID(ctx context.Context, namespace, id string) (sub *core.ContractListener, err error) - // GetContractListeners - get smart contract subscriptions - GetContractListeners(ctx context.Context, filter Filter) ([]*core.ContractListener, *FilterResult, error) + // GetContractListeners - get contract listeners + GetContractListeners(ctx context.Context, namespace string, filter Filter) ([]*core.ContractListener, *FilterResult, error) - // DeleteContractListener - delete a subscription to an external smart contract - DeleteContractListenerByID(ctx context.Context, id *fftypes.UUID) (err error) + // DeleteContractListener - delete a contract listener + DeleteContractListenerByID(ctx context.Context, namespace string, id *fftypes.UUID) (err error) } type iBlockchainEventCollection interface { From 43ed4a7d9e37ba61ae23f297493c8c055adcf879 Mon Sep 17 00:00:00 2001 From: Andrew Richardson Date: Wed, 22 Jun 2022 15:59:14 -0400 Subject: [PATCH 9/9] Remove namespace from more manager calls Signed-off-by: Andrew Richardson --- .../apiserver/route_get_chart_histogram.go | 2 +- .../route_get_chart_histogram_test.go | 2 +- internal/apiserver/route_get_data_blob.go | 2 +- .../apiserver/route_get_data_blob_test.go | 2 +- internal/apiserver/route_post_data.go | 4 +- internal/apiserver/route_post_data_test.go | 12 ++-- internal/apiserver/route_post_new_datatype.go | 2 +- .../apiserver/route_post_new_datatype_test.go | 4 +- internal/apiserver/route_post_op_retry.go | 2 +- .../apiserver/route_post_op_retry_test.go | 2 +- internal/apiserver/routes.go | 8 --- internal/apiserver/server.go | 29 ++++---- internal/broadcast/datatype.go | 8 +-- internal/broadcast/datatype_test.go | 20 +++--- internal/broadcast/definition.go | 18 ++--- internal/broadcast/definition_test.go | 11 ++- internal/broadcast/manager.go | 10 +-- internal/broadcast/tokenpool.go | 4 +- internal/broadcast/tokenpool_test.go | 8 +-- internal/contracts/manager.go | 4 +- internal/contracts/manager_test.go | 16 ++--- internal/data/blobstore.go | 21 +++--- internal/data/blobstore_test.go | 50 +++++-------- internal/data/data_manager.go | 45 ++++++------ internal/data/data_manager_test.go | 28 ++++---- .../definition_handler_datatype.go | 2 +- .../definition_handler_datatype_test.go | 12 ++-- internal/events/batch_pin_complete.go | 2 +- internal/events/batch_pin_complete_test.go | 8 +-- internal/events/event_manager.go | 2 +- internal/events/persist_batch.go | 2 +- internal/events/ss_callbacks.go | 6 +- internal/events/ss_callbacks_test.go | 9 +-- internal/events/token_pool_created.go | 2 +- internal/events/token_pool_created_test.go | 2 +- internal/namespace/manager.go | 2 +- internal/namespace/manager_test.go | 2 +- internal/networkmap/register_identity.go | 4 +- internal/networkmap/register_identity_test.go | 7 -- internal/networkmap/register_node_test.go | 1 - internal/networkmap/register_org_test.go | 1 - internal/networkmap/update_identity.go | 2 +- internal/networkmap/update_identity_test.go | 2 - internal/operations/manager.go | 24 +++---- internal/operations/manager_test.go | 16 ++--- internal/orchestrator/bound_callbacks.go | 4 +- internal/orchestrator/bound_callbacks_test.go | 4 +- internal/orchestrator/chart.go | 4 +- internal/orchestrator/chart_test.go | 14 ++-- internal/orchestrator/orchestrator.go | 2 +- internal/shareddownload/download_manager.go | 22 +++--- .../shareddownload/download_manager_test.go | 10 +-- internal/shareddownload/operations.go | 36 ++++------ internal/shareddownload/operations_test.go | 7 +- mocks/blockchainmocks/plugin.go | 10 +-- mocks/broadcastmocks/manager.go | 70 +++++++++---------- mocks/datamocks/manager.go | 56 +++++++-------- mocks/eventmocks/event_manager.go | 14 ++-- mocks/operationmocks/manager.go | 24 +++---- mocks/orchestratormocks/orchestrator.go | 14 ++-- mocks/shareddownloadmocks/callbacks.go | 14 ++-- mocks/shareddownloadmocks/manager.go | 20 +++--- 62 files changed, 348 insertions(+), 399 deletions(-) diff --git a/internal/apiserver/route_get_chart_histogram.go b/internal/apiserver/route_get_chart_histogram.go index 0dedc515b..025cb24d0 100644 --- a/internal/apiserver/route_get_chart_histogram.go +++ b/internal/apiserver/route_get_chart_histogram.go @@ -58,7 +58,7 @@ var getChartHistogram = &ffapi.Route{ if err != nil { return nil, i18n.NewError(cr.ctx, coremsgs.MsgInvalidChartNumberParam, "buckets") } - return cr.or.GetChartHistogram(cr.ctx, extractNamespace(r.PP), startTime.UnixNano(), endTime.UnixNano(), buckets, database.CollectionName(r.PP["collection"])) + return cr.or.GetChartHistogram(cr.ctx, startTime.UnixNano(), endTime.UnixNano(), buckets, database.CollectionName(r.PP["collection"])) }, }, } diff --git a/internal/apiserver/route_get_chart_histogram_test.go b/internal/apiserver/route_get_chart_histogram_test.go index 9c5de8add..d9c6575fa 100644 --- a/internal/apiserver/route_get_chart_histogram_test.go +++ b/internal/apiserver/route_get_chart_histogram_test.go @@ -69,7 +69,7 @@ func TestGetChartHistogramSuccess(t *testing.T) { startTime, _ := fftypes.ParseTimeString("1234567890") endtime, _ := fftypes.ParseTimeString("1234567891") - o.On("GetChartHistogram", mock.Anything, "mynamespace", startTime.UnixNano(), endtime.UnixNano(), int64(30), database.CollectionName("test")). + o.On("GetChartHistogram", mock.Anything, startTime.UnixNano(), endtime.UnixNano(), int64(30), database.CollectionName("test")). Return([]*core.ChartHistogram{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_get_data_blob.go b/internal/apiserver/route_get_data_blob.go index 43d779bcc..3cd7388ff 100644 --- a/internal/apiserver/route_get_data_blob.go +++ b/internal/apiserver/route_get_data_blob.go @@ -41,7 +41,7 @@ var getDataBlob = &ffapi.Route{ Extensions: &coreExtensions{ FilterFactory: database.MessageQueryFactory, CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - blob, reader, err := cr.or.Data().DownloadBlob(cr.ctx, extractNamespace(r.PP), r.PP["dataid"]) + blob, reader, err := cr.or.Data().DownloadBlob(cr.ctx, r.PP["dataid"]) if err == nil { r.ResponseHeaders.Set(core.HTTPHeadersBlobHashSHA256, blob.Hash.String()) if blob.Size > 0 { diff --git a/internal/apiserver/route_get_data_blob_test.go b/internal/apiserver/route_get_data_blob_test.go index c76a4ed86..9e823a435 100644 --- a/internal/apiserver/route_get_data_blob_test.go +++ b/internal/apiserver/route_get_data_blob_test.go @@ -38,7 +38,7 @@ func TestGetDataBlob(t *testing.T) { res := httptest.NewRecorder() blobHash := fftypes.NewRandB32() - mdm.On("DownloadBlob", mock.Anything, "mynamespace", "abcd1234"). + mdm.On("DownloadBlob", mock.Anything, "abcd1234"). Return(&core.Blob{ Hash: blobHash, Size: 12345, diff --git a/internal/apiserver/route_post_data.go b/internal/apiserver/route_post_data.go index da5d39caf..46e031211 100644 --- a/internal/apiserver/route_post_data.go +++ b/internal/apiserver/route_post_data.go @@ -47,7 +47,7 @@ var postData = &ffapi.Route{ JSONOutputCodes: []int{http.StatusCreated}, Extensions: &coreExtensions{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { - output, err = cr.or.Data().UploadJSON(cr.ctx, extractNamespace(r.PP), r.Input.(*core.DataRefOrValue)) + output, err = cr.or.Data().UploadJSON(cr.ctx, r.Input.(*core.DataRefOrValue)) return output, err }, CoreFormUploadHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { @@ -71,7 +71,7 @@ var postData = &ffapi.Route{ } data.Value = fftypes.JSONAnyPtr(metadata) } - output, err = cr.or.Data().UploadBlob(cr.ctx, extractNamespace(r.PP), data, r.Part, strings.EqualFold(r.FP["autometa"], "true")) + output, err = cr.or.Data().UploadBlob(cr.ctx, data, r.Part, strings.EqualFold(r.FP["autometa"], "true")) return output, err }, }, diff --git a/internal/apiserver/route_post_data_test.go b/internal/apiserver/route_post_data_test.go index cc0563a09..9068f6159 100644 --- a/internal/apiserver/route_post_data_test.go +++ b/internal/apiserver/route_post_data_test.go @@ -42,7 +42,7 @@ func TestPostDataJSON(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mdm.On("UploadJSON", mock.Anything, "ns1", mock.AnythingOfType("*core.DataRefOrValue")). + mdm.On("UploadJSON", mock.Anything, mock.AnythingOfType("*core.DataRefOrValue")). Return(&core.Data{}, nil) r.ServeHTTP(res, req) @@ -60,7 +60,7 @@ func TestPostDataJSONDefaultNS(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mdm.On("UploadJSON", mock.Anything, "default", mock.AnythingOfType("*core.DataRefOrValue")). + mdm.On("UploadJSON", mock.Anything, mock.AnythingOfType("*core.DataRefOrValue")). Return(&core.Data{}, nil) r.ServeHTTP(res, req) @@ -85,7 +85,7 @@ func TestPostDataBinary(t *testing.T) { res := httptest.NewRecorder() - mdm.On("UploadBlob", mock.Anything, "ns1", mock.AnythingOfType("*core.DataRefOrValue"), mock.AnythingOfType("*ffapi.Multipart"), false). + mdm.On("UploadBlob", mock.Anything, mock.AnythingOfType("*core.DataRefOrValue"), mock.AnythingOfType("*ffapi.Multipart"), false). Return(&core.Data{}, nil) r.ServeHTTP(res, req) @@ -125,7 +125,7 @@ func TestPostDataBinaryObjAutoMeta(t *testing.T) { res := httptest.NewRecorder() - mdm.On("UploadBlob", mock.Anything, "ns1", mock.MatchedBy(func(d *core.DataRefOrValue) bool { + mdm.On("UploadBlob", mock.Anything, mock.MatchedBy(func(d *core.DataRefOrValue) bool { assert.Equal(t, `{"filename":"anything"}`, string(*d.Value)) assert.Equal(t, core.ValidatorTypeJSON, d.Validator) assert.Equal(t, "fileinfo", d.Datatype.Name) @@ -159,7 +159,7 @@ func TestPostDataBinaryStringMetadata(t *testing.T) { res := httptest.NewRecorder() - mdm.On("UploadBlob", mock.Anything, "ns1", mock.MatchedBy(func(d *core.DataRefOrValue) bool { + mdm.On("UploadBlob", mock.Anything, mock.MatchedBy(func(d *core.DataRefOrValue) bool { assert.Equal(t, `"string metadata"`, string(*d.Value)) assert.Equal(t, "", string(d.Validator)) assert.Nil(t, d.Datatype) @@ -192,7 +192,7 @@ func TestPostDataTrailingMetadata(t *testing.T) { res := httptest.NewRecorder() - mdm.On("UploadBlob", mock.Anything, "ns1", mock.Anything, mock.AnythingOfType("*ffapi.Multipart"), false). + mdm.On("UploadBlob", mock.Anything, mock.Anything, mock.AnythingOfType("*ffapi.Multipart"), false). Return(&core.Data{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_new_datatype.go b/internal/apiserver/route_post_new_datatype.go index fd50a3a47..53e591247 100644 --- a/internal/apiserver/route_post_new_datatype.go +++ b/internal/apiserver/route_post_new_datatype.go @@ -41,7 +41,7 @@ var postNewDatatype = &ffapi.Route{ CoreJSONHandler: func(r *ffapi.APIRequest, cr *coreRequest) (output interface{}, err error) { waitConfirm := strings.EqualFold(r.QP["confirm"], "true") r.SuccessStatus = syncRetcode(waitConfirm) - _, err = cr.or.Broadcast().BroadcastDatatype(cr.ctx, extractNamespace(r.PP), r.Input.(*core.Datatype), waitConfirm) + _, err = cr.or.Broadcast().BroadcastDatatype(cr.ctx, r.Input.(*core.Datatype), waitConfirm) return r.Input, err }, }, diff --git a/internal/apiserver/route_post_new_datatype_test.go b/internal/apiserver/route_post_new_datatype_test.go index a670d4d07..c047b42bc 100644 --- a/internal/apiserver/route_post_new_datatype_test.go +++ b/internal/apiserver/route_post_new_datatype_test.go @@ -39,7 +39,7 @@ func TestPostNewDatatypes(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mbm.On("BroadcastDatatype", mock.Anything, "ns1", mock.AnythingOfType("*core.Datatype"), false). + mbm.On("BroadcastDatatype", mock.Anything, mock.AnythingOfType("*core.Datatype"), false). Return(&core.Message{}, nil) r.ServeHTTP(res, req) @@ -57,7 +57,7 @@ func TestPostNewDatatypesSync(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mbm.On("BroadcastDatatype", mock.Anything, "ns1", mock.AnythingOfType("*core.Datatype"), true). + mbm.On("BroadcastDatatype", mock.Anything, mock.AnythingOfType("*core.Datatype"), true). Return(&core.Message{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/route_post_op_retry.go b/internal/apiserver/route_post_op_retry.go index 6336668a2..5c8cedddc 100644 --- a/internal/apiserver/route_post_op_retry.go +++ b/internal/apiserver/route_post_op_retry.go @@ -43,7 +43,7 @@ var postOpRetry = &ffapi.Route{ if err != nil { return nil, err } - return cr.or.Operations().RetryOperation(cr.ctx, extractNamespace(r.PP), opid) + return cr.or.Operations().RetryOperation(cr.ctx, opid) }, }, } diff --git a/internal/apiserver/route_post_op_retry_test.go b/internal/apiserver/route_post_op_retry_test.go index a96336b94..308dc4d76 100644 --- a/internal/apiserver/route_post_op_retry_test.go +++ b/internal/apiserver/route_post_op_retry_test.go @@ -41,7 +41,7 @@ func TestPostOpRetry(t *testing.T) { req.Header.Set("Content-Type", "application/json; charset=utf-8") res := httptest.NewRecorder() - mom.On("RetryOperation", mock.Anything, "ns1", opID). + mom.On("RetryOperation", mock.Anything, opID). Return(&core.Operation{}, nil) r.ServeHTTP(res, req) diff --git a/internal/apiserver/routes.go b/internal/apiserver/routes.go index 780656371..094ab816e 100644 --- a/internal/apiserver/routes.go +++ b/internal/apiserver/routes.go @@ -19,7 +19,6 @@ package apiserver import ( "context" - "github.com/hyperledger/firefly-common/pkg/config" "github.com/hyperledger/firefly-common/pkg/ffapi" "github.com/hyperledger/firefly/internal/coreconfig" "github.com/hyperledger/firefly/internal/coremsgs" @@ -176,10 +175,3 @@ func namespacedRoutes(routes []*ffapi.Route) []*ffapi.Route { } return append(routes, newRoutes...) } - -func extractNamespace(pathParams map[string]string) string { - if ns, ok := pathParams["ns"]; ok { - return ns - } - return config.GetString(coreconfig.NamespacesDefault) -} diff --git a/internal/apiserver/server.go b/internal/apiserver/server.go index 1dc914961..32dcdb65e 100644 --- a/internal/apiserver/server.go +++ b/internal/apiserver/server.go @@ -218,6 +218,19 @@ func (as *apiServer) contractSwaggerGenerator(mgr namespace.Manager, apiBaseURL } } +func getOrchestrator(mgr namespace.Manager, tag string, r *ffapi.APIRequest) orchestrator.Orchestrator { + if tag == routeTagDefaultNamespace { + return mgr.Orchestrator(config.GetString(coreconfig.NamespacesDefault)) + } + if tag == routeTagNonDefaultNamespace { + vars := mux.Vars(r.Req) + if ns, ok := vars["ns"]; ok { + return mgr.Orchestrator(ns) + } + } + return nil +} + func (as *apiServer) routeHandler(hf *ffapi.HandlerFactory, mgr namespace.Manager, apiBaseURL string, route *ffapi.Route) http.HandlerFunc { // We extend the base ffapi functionality, with standardized DB filter support for all core resources. // We also pass the Orchestrator context through @@ -231,15 +244,9 @@ func (as *apiServer) routeHandler(hf *ffapi.HandlerFactory, mgr namespace.Manage } } - var or orchestrator.Orchestrator - if route.Tag == routeTagDefaultNamespace || route.Tag == routeTagNonDefaultNamespace { - vars := mux.Vars(r.Req) - or = mgr.Orchestrator(extractNamespace(vars)) - } - cr := &coreRequest{ mgr: mgr, - or: or, + or: getOrchestrator(mgr, route.Tag, r), ctx: r.Req.Context(), filter: filter, apiBaseURL: apiBaseURL, @@ -248,15 +255,9 @@ func (as *apiServer) routeHandler(hf *ffapi.HandlerFactory, mgr namespace.Manage } if ce.CoreFormUploadHandler != nil { route.FormUploadHandler = func(r *ffapi.APIRequest) (output interface{}, err error) { - var or orchestrator.Orchestrator - if route.Tag == routeTagDefaultNamespace || route.Tag == routeTagNonDefaultNamespace { - vars := mux.Vars(r.Req) - or = mgr.Orchestrator(extractNamespace(vars)) - } - cr := &coreRequest{ mgr: mgr, - or: or, + or: getOrchestrator(mgr, route.Tag, r), ctx: r.Req.Context(), apiBaseURL: apiBaseURL, } diff --git a/internal/broadcast/datatype.go b/internal/broadcast/datatype.go index 988bddf15..2255bfa1a 100644 --- a/internal/broadcast/datatype.go +++ b/internal/broadcast/datatype.go @@ -23,12 +23,12 @@ import ( "github.com/hyperledger/firefly/pkg/core" ) -func (bm *broadcastManager) BroadcastDatatype(ctx context.Context, ns string, datatype *core.Datatype, waitConfirm bool) (*core.Message, error) { +func (bm *broadcastManager) BroadcastDatatype(ctx context.Context, datatype *core.Datatype, waitConfirm bool) (*core.Message, error) { // Validate the input data definition data datatype.ID = fftypes.NewUUID() datatype.Created = fftypes.Now() - datatype.Namespace = ns + datatype.Namespace = bm.namespace if datatype.Validator == "" { datatype.Validator = core.ValidatorTypeJSON } @@ -41,10 +41,10 @@ func (bm *broadcastManager) BroadcastDatatype(ctx context.Context, ns string, da datatype.Hash = datatype.Value.Hash() // Verify the data type is now all valid, before we broadcast it - if err := bm.data.CheckDatatype(ctx, ns, datatype); err != nil { + if err := bm.data.CheckDatatype(ctx, datatype); err != nil { return nil, err } - msg, err := bm.BroadcastDefinitionAsNode(ctx, ns, datatype, core.SystemTagDefineDatatype, waitConfirm) + msg, err := bm.BroadcastDefinitionAsNode(ctx, datatype, core.SystemTagDefineDatatype, waitConfirm) if msg != nil { datatype.Message = msg.Header.ID } diff --git a/internal/broadcast/datatype_test.go b/internal/broadcast/datatype_test.go index 6878330cb..9b28e42d5 100644 --- a/internal/broadcast/datatype_test.go +++ b/internal/broadcast/datatype_test.go @@ -34,7 +34,7 @@ import ( func TestBroadcastDatatypeBadType(t *testing.T) { bm, cancel := newTestBroadcast(t) defer cancel() - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Validator: core.ValidatorType("wrong"), }, false) assert.Regexp(t, "FF00111.*validator", err) @@ -45,7 +45,7 @@ func TestBroadcastDatatypeNSGetFail(t *testing.T) { defer cancel() mdm := bm.data.(*datamocks.Manager) mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(fmt.Errorf("pop")) - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Name: "name1", Namespace: "ns1", Version: "0.0.1", @@ -59,10 +59,10 @@ func TestBroadcastDatatypeBadValue(t *testing.T) { defer cancel() mdm := bm.data.(*datamocks.Manager) mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mim := bm.identity.(*identitymanagermocks.Manager) mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(nil) - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Namespace: "ns1", Name: "ent1", Version: "0.0.1", @@ -80,9 +80,9 @@ func TestBroadcastUpsertFail(t *testing.T) { mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(nil) mdm.On("WriteNewMessage", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Namespace: "ns1", Name: "ent1", Version: "0.0.1", @@ -104,9 +104,9 @@ func TestBroadcastDatatypeInvalid(t *testing.T) { mim.On("ResolveInputIdentity", mock.Anything, mock.Anything).Return(nil) mdi.On("UpsertData", mock.Anything, mock.Anything, database.UpsertOptimizationNew).Return(nil) mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(fmt.Errorf("pop")) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Namespace: "ns1", Name: "ent1", Version: "0.0.1", @@ -123,10 +123,10 @@ func TestBroadcastOk(t *testing.T) { mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(nil) mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mdm.On("WriteNewMessage", mock.Anything, mock.Anything, mock.Anything).Return(nil) - _, err := bm.BroadcastDatatype(context.Background(), "ns1", &core.Datatype{ + _, err := bm.BroadcastDatatype(context.Background(), &core.Datatype{ Namespace: "ns1", Name: "ent1", Version: "0.0.1", diff --git a/internal/broadcast/definition.go b/internal/broadcast/definition.go index ae47e6519..361b08784 100644 --- a/internal/broadcast/definition.go +++ b/internal/broadcast/definition.go @@ -28,39 +28,39 @@ import ( "github.com/hyperledger/firefly/pkg/core" ) -func (bm *broadcastManager) BroadcastDefinitionAsNode(ctx context.Context, ns string, def core.Definition, tag string, waitConfirm bool) (msg *core.Message, err error) { - return bm.BroadcastDefinition(ctx, ns, def, &core.SignerRef{ /* resolve to node default */ }, tag, waitConfirm) +func (bm *broadcastManager) BroadcastDefinitionAsNode(ctx context.Context, def core.Definition, tag string, waitConfirm bool) (msg *core.Message, err error) { + return bm.BroadcastDefinition(ctx, def, &core.SignerRef{ /* resolve to node default */ }, tag, waitConfirm) } -func (bm *broadcastManager) BroadcastDefinition(ctx context.Context, ns string, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) { +func (bm *broadcastManager) BroadcastDefinition(ctx context.Context, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) { err = bm.identity.ResolveInputSigningIdentity(ctx, signingIdentity) if err != nil { return nil, err } - return bm.broadcastDefinitionCommon(ctx, ns, def, signingIdentity, tag, waitConfirm) + return bm.broadcastDefinitionCommon(ctx, def, signingIdentity, tag, waitConfirm) } // BroadcastIdentityClaim is a special form of BroadcastDefinitionAsNode where the signing identity does not need to have been pre-registered // The blockchain "key" will be normalized, but the "author" will pass through unchecked -func (bm *broadcastManager) BroadcastIdentityClaim(ctx context.Context, ns string, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) { +func (bm *broadcastManager) BroadcastIdentityClaim(ctx context.Context, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) { signingIdentity.Key, err = bm.identity.NormalizeSigningKey(ctx, signingIdentity.Key, identity.KeyNormalizationBlockchainPlugin) if err != nil { return nil, err } - return bm.broadcastDefinitionCommon(ctx, ns, def, signingIdentity, tag, waitConfirm) + return bm.broadcastDefinitionCommon(ctx, def, signingIdentity, tag, waitConfirm) } -func (bm *broadcastManager) broadcastDefinitionCommon(ctx context.Context, ns string, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { +func (bm *broadcastManager) broadcastDefinitionCommon(ctx context.Context, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { // Serialize it into a data object, as a piece of data we can write to a message d := &core.Data{ Validator: core.ValidatorTypeSystemDefinition, ID: fftypes.NewUUID(), - Namespace: ns, + Namespace: bm.namespace, Created: fftypes.Now(), } b, err := json.Marshal(&def) @@ -77,7 +77,7 @@ func (bm *broadcastManager) broadcastDefinitionCommon(ctx context.Context, ns st Message: &core.MessageInOut{ Message: core.Message{ Header: core.MessageHeader{ - Namespace: ns, + Namespace: bm.namespace, Type: core.MessageTypeDefinition, SignerRef: *signingIdentity, Topics: core.FFStringArray{def.Topic()}, diff --git a/internal/broadcast/definition_test.go b/internal/broadcast/definition_test.go index 5eac2968e..40f4f2593 100644 --- a/internal/broadcast/definition_test.go +++ b/internal/broadcast/definition_test.go @@ -38,7 +38,7 @@ func TestBroadcastDefinitionAsNodeConfirm(t *testing.T) { mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(nil) msa.On("WaitForMessage", bm.ctx, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("pop")) - _, err := bm.BroadcastDefinitionAsNode(bm.ctx, "ns1", &core.Namespace{}, core.SystemTagDefineNamespace, true) + _, err := bm.BroadcastDefinitionAsNode(bm.ctx, &core.Namespace{}, core.SystemTagDefineNamespace, true) assert.EqualError(t, err, "pop") msa.AssertExpectations(t) @@ -55,7 +55,7 @@ func TestBroadcastIdentityClaim(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "0x1234", identity.KeyNormalizationBlockchainPlugin).Return("", nil) msa.On("WaitForMessage", bm.ctx, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("pop")) - _, err := bm.BroadcastIdentityClaim(bm.ctx, "ns1", &core.IdentityClaim{ + _, err := bm.BroadcastIdentityClaim(bm.ctx, &core.IdentityClaim{ Identity: &core.Identity{}, }, &core.SignerRef{ Key: "0x1234", @@ -74,7 +74,7 @@ func TestBroadcastIdentityClaimFail(t *testing.T) { mim.On("NormalizeSigningKey", mock.Anything, "0x1234", identity.KeyNormalizationBlockchainPlugin).Return("", fmt.Errorf("pop")) - _, err := bm.BroadcastIdentityClaim(bm.ctx, "ns1", &core.IdentityClaim{ + _, err := bm.BroadcastIdentityClaim(bm.ctx, &core.IdentityClaim{ Identity: &core.Identity{}, }, &core.SignerRef{ Key: "0x1234", @@ -90,12 +90,11 @@ func TestBroadcastDatatypeDefinitionAsNodeConfirm(t *testing.T) { msa := bm.syncasync.(*syncasyncmocks.Bridge) mim := bm.identity.(*identitymanagermocks.Manager) - ns := "customNamespace" mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(nil) msa.On("WaitForMessage", bm.ctx, mock.Anything, mock.Anything).Return(nil, fmt.Errorf("pop")) - _, err := bm.BroadcastDefinitionAsNode(bm.ctx, ns, &core.Datatype{}, core.SystemTagDefineNamespace, true) + _, err := bm.BroadcastDefinitionAsNode(bm.ctx, &core.Datatype{}, core.SystemTagDefineNamespace, true) assert.EqualError(t, err, "pop") msa.AssertExpectations(t) @@ -108,7 +107,7 @@ func TestBroadcastDefinitionBadIdentity(t *testing.T) { mim := bm.identity.(*identitymanagermocks.Manager) mim.On("ResolveInputSigningIdentity", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - _, err := bm.BroadcastDefinition(bm.ctx, "ns1", &core.Namespace{}, &core.SignerRef{ + _, err := bm.BroadcastDefinition(bm.ctx, &core.Namespace{}, &core.SignerRef{ Author: "wrong", Key: "wrong", }, core.SystemTagDefineNamespace, false) diff --git a/internal/broadcast/manager.go b/internal/broadcast/manager.go index a012d0b04..580e54915 100644 --- a/internal/broadcast/manager.go +++ b/internal/broadcast/manager.go @@ -46,12 +46,12 @@ type Manager interface { core.Named NewBroadcast(in *core.MessageInOut) sysmessaging.MessageSender - BroadcastDatatype(ctx context.Context, ns string, datatype *core.Datatype, waitConfirm bool) (msg *core.Message, err error) + BroadcastDatatype(ctx context.Context, datatype *core.Datatype, waitConfirm bool) (msg *core.Message, err error) BroadcastMessage(ctx context.Context, in *core.MessageInOut, waitConfirm bool) (out *core.Message, err error) - BroadcastDefinitionAsNode(ctx context.Context, ns string, def core.Definition, tag string, waitConfirm bool) (msg *core.Message, err error) - BroadcastDefinition(ctx context.Context, ns string, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) - BroadcastIdentityClaim(ctx context.Context, ns string, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) - BroadcastTokenPool(ctx context.Context, ns string, pool *core.TokenPoolAnnouncement, waitConfirm bool) (msg *core.Message, err error) + BroadcastDefinitionAsNode(ctx context.Context, def core.Definition, tag string, waitConfirm bool) (msg *core.Message, err error) + BroadcastDefinition(ctx context.Context, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) + BroadcastIdentityClaim(ctx context.Context, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (msg *core.Message, err error) + BroadcastTokenPool(ctx context.Context, pool *core.TokenPoolAnnouncement, waitConfirm bool) (msg *core.Message, err error) Start() error WaitStop() diff --git a/internal/broadcast/tokenpool.go b/internal/broadcast/tokenpool.go index b848f96e7..5a6dc400a 100644 --- a/internal/broadcast/tokenpool.go +++ b/internal/broadcast/tokenpool.go @@ -22,7 +22,7 @@ import ( "github.com/hyperledger/firefly/pkg/core" ) -func (bm *broadcastManager) BroadcastTokenPool(ctx context.Context, ns string, pool *core.TokenPoolAnnouncement, waitConfirm bool) (msg *core.Message, err error) { +func (bm *broadcastManager) BroadcastTokenPool(ctx context.Context, pool *core.TokenPoolAnnouncement, waitConfirm bool) (msg *core.Message, err error) { if err := pool.Pool.Validate(ctx); err != nil { return nil, err } @@ -30,7 +30,7 @@ func (bm *broadcastManager) BroadcastTokenPool(ctx context.Context, ns string, p return nil, err } - msg, err = bm.BroadcastDefinitionAsNode(ctx, ns, pool, core.SystemTagDefinePool, waitConfirm) + msg, err = bm.BroadcastDefinitionAsNode(ctx, pool, core.SystemTagDefinePool, waitConfirm) if msg != nil { pool.Pool.Message = msg.Header.ID } diff --git a/internal/broadcast/tokenpool_test.go b/internal/broadcast/tokenpool_test.go index afada1625..aaeaa1904 100644 --- a/internal/broadcast/tokenpool_test.go +++ b/internal/broadcast/tokenpool_test.go @@ -48,7 +48,7 @@ func TestBroadcastTokenPoolNSGetFail(t *testing.T) { mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(fmt.Errorf("pop")) - _, err := bm.BroadcastTokenPool(context.Background(), "ns1", pool, false) + _, err := bm.BroadcastTokenPool(context.Background(), pool, false) assert.EqualError(t, err, "pop") mdm.AssertExpectations(t) @@ -71,7 +71,7 @@ func TestBroadcastTokenPoolInvalid(t *testing.T) { }, } - _, err := bm.BroadcastTokenPool(context.Background(), "ns1", pool, false) + _, err := bm.BroadcastTokenPool(context.Background(), pool, false) assert.Regexp(t, "FF00140", err) mdi.AssertExpectations(t) @@ -99,7 +99,7 @@ func TestBroadcastTokenPoolBroadcastFail(t *testing.T) { mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) mdm.On("WriteNewMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - _, err := bm.BroadcastTokenPool(context.Background(), "ns1", pool, false) + _, err := bm.BroadcastTokenPool(context.Background(), pool, false) assert.EqualError(t, err, "pop") mdm.AssertExpectations(t) @@ -127,7 +127,7 @@ func TestBroadcastTokenPoolOk(t *testing.T) { mdm.On("VerifyNamespaceExists", mock.Anything, "ns1").Return(nil) mdm.On("WriteNewMessage", mock.Anything, mock.Anything).Return(nil) - _, err := bm.BroadcastTokenPool(context.Background(), "ns1", pool, false) + _, err := bm.BroadcastTokenPool(context.Background(), pool, false) assert.NoError(t, err) mdm.AssertExpectations(t) diff --git a/internal/contracts/manager.go b/internal/contracts/manager.go index 3d6537f76..6cc9c3e2e 100644 --- a/internal/contracts/manager.go +++ b/internal/contracts/manager.go @@ -139,7 +139,7 @@ func (cm *contractManager) BroadcastFFI(ctx context.Context, ffi *core.FFI, wait } output = ffi - msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, cm.namespace, ffi, core.SystemTagDefineFFI, waitConfirm) + msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, ffi, core.SystemTagDefineFFI, waitConfirm) if err != nil { return nil, err } @@ -373,7 +373,7 @@ func (cm *contractManager) BroadcastContractAPI(ctx context.Context, httpServerU return nil, err } - msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, cm.namespace, api, core.SystemTagDefineContractAPI, waitConfirm) + msg, err := cm.broadcast.BroadcastDefinitionAsNode(ctx, api, core.SystemTagDefineContractAPI, waitConfirm) if err != nil { return nil, err } diff --git a/internal/contracts/manager_test.go b/internal/contracts/manager_test.go index d23b8aefc..743d95268 100644 --- a/internal/contracts/manager_test.go +++ b/internal/contracts/manager_test.go @@ -119,7 +119,7 @@ func TestBroadcastFFI(t *testing.T) { ID: fftypes.NewUUID(), }, } - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) ffi := &core.FFI{ Name: "test", Version: "1.0.0", @@ -153,7 +153,7 @@ func TestBroadcastFFIInvalid(t *testing.T) { ID: fftypes.NewUUID(), }, } - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) ffi := &core.FFI{ Name: "test", Version: "1.0.0", @@ -186,7 +186,7 @@ func TestBroadcastFFIExists(t *testing.T) { ID: fftypes.NewUUID(), }, } - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(msg, nil) ffi := &core.FFI{ Name: "test", Version: "1.0.0", @@ -205,7 +205,7 @@ func TestBroadcastFFIFail(t *testing.T) { mdb.On("GetFFI", mock.Anything, "ns1", "test", "1.0.0").Return(nil, nil) mim.On("GetOrgKey", mock.Anything).Return("key", nil) - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(nil, fmt.Errorf("pop")) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.FFI"), core.SystemTagDefineFFI, false).Return(nil, fmt.Errorf("pop")) ffi := &core.FFI{ Name: "test", Version: "1.0.0", @@ -1946,7 +1946,7 @@ func TestBroadcastContractAPI(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) api, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) @@ -2017,7 +2017,7 @@ func TestBroadcastContractAPIExisting(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(existing, nil) mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) @@ -2090,7 +2090,7 @@ func TestBroadcastContractAPIInterfaceName(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFI", mock.Anything, "ns1", "my-ffi", "1").Return(&core.FFI{ID: interfaceID}, nil) - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(msg, nil) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) @@ -2121,7 +2121,7 @@ func TestBroadcastContractAPIFail(t *testing.T) { mbi.On("NormalizeContractLocation", context.Background(), api.Location).Return(api.Location, nil) mdb.On("GetContractAPIByName", mock.Anything, api.Namespace, api.Name).Return(nil, nil) mdb.On("GetFFIByID", mock.Anything, "ns1", api.Interface.ID).Return(&core.FFI{}, nil) - mbm.On("BroadcastDefinitionAsNode", mock.Anything, "ns1", mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(nil, fmt.Errorf("pop")) + mbm.On("BroadcastDefinitionAsNode", mock.Anything, mock.AnythingOfType("*core.ContractAPI"), core.SystemTagDefineContractAPI, false).Return(nil, fmt.Errorf("pop")) _, err := cm.BroadcastContractAPI(context.Background(), "http://localhost/api", api, false) diff --git a/internal/data/blobstore.go b/internal/data/blobstore.go index 9cfb032c6..f541de457 100644 --- a/internal/data/blobstore.go +++ b/internal/data/blobstore.go @@ -41,7 +41,7 @@ type blobStore struct { exchange dataexchange.Plugin } -func (bs *blobStore) uploadVerifyBlob(ctx context.Context, ns string, id *fftypes.UUID, reader io.Reader) (hash *fftypes.Bytes32, written int64, payloadRef string, err error) { +func (bs *blobStore) uploadVerifyBlob(ctx context.Context, id *fftypes.UUID, reader io.Reader) (hash *fftypes.Bytes32, written int64, payloadRef string, err error) { hashCalc := sha256.New() dxReader, dx := io.Pipe() storeAndHash := io.MultiWriter(hashCalc, dx) @@ -55,7 +55,7 @@ func (bs *blobStore) uploadVerifyBlob(ctx context.Context, ns string, id *fftype copyDone <- err }() - payloadRef, uploadHash, uploadSize, dxErr := bs.exchange.UploadBlob(ctx, ns, *id, dxReader) + payloadRef, uploadHash, uploadSize, dxErr := bs.exchange.UploadBlob(ctx, bs.dm.namespace, *id, dxReader) dxReader.Close() copyErr := <-copyDone if dxErr != nil { @@ -79,11 +79,11 @@ func (bs *blobStore) uploadVerifyBlob(ctx context.Context, ns string, id *fftype } -func (bs *blobStore) UploadBlob(ctx context.Context, ns string, inData *core.DataRefOrValue, mpart *ffapi.Multipart, autoMeta bool) (*core.Data, error) { +func (bs *blobStore) UploadBlob(ctx context.Context, inData *core.DataRefOrValue, mpart *ffapi.Multipart, autoMeta bool) (*core.Data, error) { data := &core.Data{ ID: fftypes.NewUUID(), - Namespace: ns, + Namespace: bs.dm.namespace, Created: fftypes.Now(), Validator: inData.Validator, Datatype: inData.Datatype, @@ -91,10 +91,10 @@ func (bs *blobStore) UploadBlob(ctx context.Context, ns string, inData *core.Dat } data.ID = fftypes.NewUUID() - data.Namespace = ns + data.Namespace = bs.dm.namespace data.Created = fftypes.Now() - hash, blobSize, payloadRef, err := bs.uploadVerifyBlob(ctx, ns, data.ID, mpart.Data) + hash, blobSize, payloadRef, err := bs.uploadVerifyBlob(ctx, data.ID, mpart.Data) if err != nil { return nil, err } @@ -119,7 +119,7 @@ func (bs *blobStore) UploadBlob(ctx context.Context, ns string, inData *core.Dat Created: fftypes.Now(), } - err = bs.dm.checkValidation(ctx, ns, data.Validator, data.Datatype, data.Value) + err = bs.dm.checkValidation(ctx, data.Validator, data.Datatype, data.Value) if err == nil { err = data.Seal(ctx, blob) } @@ -142,11 +142,8 @@ func (bs *blobStore) UploadBlob(ctx context.Context, ns string, inData *core.Dat return data, nil } -func (bs *blobStore) DownloadBlob(ctx context.Context, ns, dataID string) (*core.Blob, io.ReadCloser, error) { +func (bs *blobStore) DownloadBlob(ctx context.Context, dataID string) (*core.Blob, io.ReadCloser, error) { - if err := core.ValidateFFNameField(ctx, ns, "namespace"); err != nil { - return nil, nil, err - } id, err := fftypes.ParseUUID(ctx, dataID) if err != nil { return nil, nil, err @@ -156,7 +153,7 @@ func (bs *blobStore) DownloadBlob(ctx context.Context, ns, dataID string) (*core if err != nil { return nil, nil, err } - if data == nil || data.Namespace != ns { + if data == nil { return nil, nil, i18n.NewError(ctx, coremsgs.Msg404NoResult) } if data.Blob == nil || data.Blob.Hash == nil { diff --git a/internal/data/blobstore_test.go b/internal/data/blobstore_test.go index ce3ee758c..c6246f27f 100644 --- a/internal/data/blobstore_test.go +++ b/internal/data/blobstore_test.go @@ -70,7 +70,7 @@ func TestUploadBlobOk(t *testing.T) { dxUpload.ReturnArguments = mock.Arguments{fmt.Sprintf("ns1/%s", uuid), &hash, int64(len(b)), err} } - data, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader(b)}, false) + data, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader(b)}, false) assert.NoError(t, err) // Check the hashes and other details of the data @@ -111,7 +111,7 @@ func TestUploadBlobAutoMetaOk(t *testing.T) { dxUpload.ReturnArguments = mock.Arguments{fmt.Sprintf("ns1/%s", uuid), &hash, int64(len(readBytes)), err} } - data, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{ + data, err := dm.UploadBlob(ctx, &core.DataRefOrValue{ Value: fftypes.JSONAnyPtr(`{"custom": "value1"}`), }, &ffapi.Multipart{ Data: bytes.NewReader([]byte(`hello`)), @@ -146,7 +146,7 @@ func TestUploadBlobBadValidator(t *testing.T) { dxUpload.ReturnArguments = mock.Arguments{fmt.Sprintf("ns1/%s", uuid), &hash, int64(len(readBytes)), err} } - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{ Value: fftypes.JSONAnyPtr(`{"custom": "value1"}`), Validator: "wrong", }, &ffapi.Multipart{ @@ -172,7 +172,7 @@ func TestUploadBlobReadFail(t *testing.T) { assert.NoError(t, err) } - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: iotest.ErrReader(fmt.Errorf("pop"))}, false) + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: iotest.ErrReader(fmt.Errorf("pop"))}, false) assert.Regexp(t, "FF10217.*pop", err) } @@ -185,7 +185,7 @@ func TestUploadBlobWriteFailDoesNotRead(t *testing.T) { mdx := dm.exchange.(*dataexchangemocks.Plugin) mdx.On("UploadBlob", ctx, "ns1", mock.Anything, mock.Anything).Return("", nil, int64(0), fmt.Errorf("pop")) - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(`any old data`))}, false) + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(`any old data`))}, false) assert.Regexp(t, "pop", err) } @@ -203,7 +203,7 @@ func TestUploadBlobHashMismatchCalculated(t *testing.T) { assert.Nil(t, err) } - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) assert.Regexp(t, "FF10238", err) } @@ -222,7 +222,7 @@ func TestUploadBlobSizeMismatch(t *testing.T) { assert.Nil(t, err) } - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) assert.Regexp(t, "FF10323", err) } @@ -243,7 +243,7 @@ func TestUploadBlobUpsertFail(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) mdi.On("RunAsGroup", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - _, err := dm.UploadBlob(ctx, "ns1", &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) + _, err := dm.UploadBlob(ctx, &core.DataRefOrValue{}, &ffapi.Multipart{Data: bytes.NewReader([]byte(b))}, false) assert.Regexp(t, "pop", err) } @@ -274,7 +274,7 @@ func TestDownloadBlobOk(t *testing.T) { ioutil.NopCloser(bytes.NewReader([]byte("some blob"))), nil) - blob, reader, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + blob, reader, err := dm.DownloadBlob(ctx, dataID.String()) assert.NoError(t, err) assert.Equal(t, blobHash.String(), blob.Hash.String()) b, err := ioutil.ReadAll(reader) @@ -301,7 +301,7 @@ func TestDownloadBlobNotFound(t *testing.T) { }, nil) mdi.On("GetBlobMatchingHash", ctx, blobHash).Return(nil, nil) - _, _, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + _, _, err := dm.DownloadBlob(ctx, dataID.String()) assert.Regexp(t, "FF10239", err) } @@ -324,7 +324,7 @@ func TestDownloadBlobLookupErr(t *testing.T) { }, nil) mdi.On("GetBlobMatchingHash", ctx, blobHash).Return(nil, fmt.Errorf("pop")) - _, _, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + _, _, err := dm.DownloadBlob(ctx, dataID.String()) assert.Regexp(t, "pop", err) } @@ -343,12 +343,12 @@ func TestDownloadBlobNoBlob(t *testing.T) { Blob: &core.BlobRef{}, }, nil) - _, _, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + _, _, err := dm.DownloadBlob(ctx, dataID.String()) assert.Regexp(t, "FF10241", err) } -func TestDownloadBlobNSMismatch(t *testing.T) { +func TestDownloadBlobNoData(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() @@ -356,13 +356,9 @@ func TestDownloadBlobNSMismatch(t *testing.T) { dataID := fftypes.NewUUID() mdi := dm.database.(*databasemocks.Plugin) - mdi.On("GetDataByID", ctx, "ns1", dataID, false).Return(&core.Data{ - ID: dataID, - Namespace: "ns2", - Blob: &core.BlobRef{}, - }, nil) + mdi.On("GetDataByID", ctx, "ns1", dataID, false).Return(nil, nil) - _, _, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + _, _, err := dm.DownloadBlob(ctx, dataID.String()) assert.Regexp(t, "FF10143", err) } @@ -377,29 +373,17 @@ func TestDownloadBlobDataLookupErr(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) mdi.On("GetDataByID", ctx, "ns1", dataID, false).Return(nil, fmt.Errorf("pop")) - _, _, err := dm.DownloadBlob(ctx, "ns1", dataID.String()) + _, _, err := dm.DownloadBlob(ctx, dataID.String()) assert.Regexp(t, "pop", err) } -func TestDownloadBlobBadNS(t *testing.T) { - - dm, ctx, cancel := newTestDataManager(t) - defer cancel() - - dataID := fftypes.NewUUID() - - _, _, err := dm.DownloadBlob(ctx, "!wrong", dataID.String()) - assert.Regexp(t, "FF00140.*namespace", err) - -} - func TestDownloadBlobBadID(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() - _, _, err := dm.DownloadBlob(ctx, "ns1", "!uuid") + _, _, err := dm.DownloadBlob(ctx, "!uuid") assert.Regexp(t, "FF00138", err) } diff --git a/internal/data/data_manager.go b/internal/data/data_manager.go index d311a2cea..d7b3b7276 100644 --- a/internal/data/data_manager.go +++ b/internal/data/data_manager.go @@ -37,7 +37,7 @@ import ( ) type Manager interface { - CheckDatatype(ctx context.Context, ns string, datatype *core.Datatype) error + CheckDatatype(ctx context.Context, datatype *core.Datatype) error ValidateAll(ctx context.Context, data core.DataArray) (valid bool, err error) GetMessageWithDataCached(ctx context.Context, msgID *fftypes.UUID, options ...CacheReadOption) (msg *core.Message, data core.DataArray, foundAllData bool, err error) GetMessageDataCached(ctx context.Context, msg *core.Message, options ...CacheReadOption) (data core.DataArray, foundAll bool, err error) @@ -49,9 +49,9 @@ type Manager interface { WriteNewMessage(ctx context.Context, newMsg *NewMessage) error VerifyNamespaceExists(ctx context.Context, ns string) error - UploadJSON(ctx context.Context, ns string, inData *core.DataRefOrValue) (*core.Data, error) - UploadBlob(ctx context.Context, ns string, inData *core.DataRefOrValue, blob *ffapi.Multipart, autoMeta bool) (*core.Data, error) - DownloadBlob(ctx context.Context, ns, dataID string) (*core.Blob, io.ReadCloser, error) + UploadJSON(ctx context.Context, inData *core.DataRefOrValue) (*core.Data, error) + UploadBlob(ctx context.Context, inData *core.DataRefOrValue, blob *ffapi.Multipart, autoMeta bool) (*core.Data, error) + DownloadBlob(ctx context.Context, dataID string) (*core.Blob, io.ReadCloser, error) HydrateBatch(ctx context.Context, persistedBatch *core.BatchPersisted) (*core.Batch, error) WaitStop() } @@ -139,8 +139,8 @@ func NewDataManager(ctx context.Context, ns string, di database.Plugin, pi share return dm, nil } -func (dm *dataManager) CheckDatatype(ctx context.Context, ns string, datatype *core.Datatype) error { - _, err := newJSONValidator(ctx, ns, datatype) +func (dm *dataManager) CheckDatatype(ctx context.Context, datatype *core.Datatype) error { + _, err := newJSONValidator(ctx, dm.namespace, datatype) return err } @@ -172,32 +172,32 @@ func (dm *dataManager) VerifyNamespaceExists(ctx context.Context, ns string) err } // getValidatorForDatatype only returns database errors - not found (of all kinds) is a nil -func (dm *dataManager) getValidatorForDatatype(ctx context.Context, ns string, validator core.ValidatorType, datatypeRef *core.DatatypeRef) (Validator, error) { +func (dm *dataManager) getValidatorForDatatype(ctx context.Context, validator core.ValidatorType, datatypeRef *core.DatatypeRef) (Validator, error) { if validator == "" { validator = core.ValidatorTypeJSON } - if ns == "" || datatypeRef == nil || datatypeRef.Name == "" || datatypeRef.Version == "" { - log.L(ctx).Warnf("Invalid datatype reference '%s:%s:%s'", validator, ns, datatypeRef) + if datatypeRef == nil || datatypeRef.Name == "" || datatypeRef.Version == "" { + log.L(ctx).Warnf("Invalid datatype reference '%s:%s:%s'", validator, dm.namespace, datatypeRef) return nil, nil } - key := fmt.Sprintf("%s:%s:%s", validator, ns, datatypeRef) + key := fmt.Sprintf("%s:%s:%s", validator, dm.namespace, datatypeRef) if cached := dm.validatorCache.Get(key); cached != nil { cached.Extend(dm.validatorCacheTTL) return cached.Value().(Validator), nil } - datatype, err := dm.database.GetDatatypeByName(ctx, ns, datatypeRef.Name, datatypeRef.Version) + datatype, err := dm.database.GetDatatypeByName(ctx, dm.namespace, datatypeRef.Name, datatypeRef.Version) if err != nil { return nil, err } if datatype == nil { return nil, nil } - v, err := newJSONValidator(ctx, ns, datatype) + v, err := newJSONValidator(ctx, dm.namespace, datatype) if err != nil { - log.L(ctx).Errorf("Invalid validator stored for '%s:%s:%s': %s", validator, ns, datatypeRef, err) + log.L(ctx).Errorf("Invalid validator stored for '%s:%s:%s': %s", validator, dm.namespace, datatypeRef, err) return nil, nil } @@ -336,7 +336,7 @@ func (dm *dataManager) getMessageData(ctx context.Context, msg *core.Message) (d func (dm *dataManager) ValidateAll(ctx context.Context, data core.DataArray) (valid bool, err error) { for _, d := range data { if d.Datatype != nil && d.Validator != core.ValidatorTypeNone { - v, err := dm.getValidatorForDatatype(ctx, d.Namespace, d.Validator, d.Datatype) + v, err := dm.getValidatorForDatatype(ctx, d.Validator, d.Datatype) if err != nil { return false, err } @@ -388,7 +388,7 @@ func (dm *dataManager) resolveBlob(ctx context.Context, blobRef *core.BlobRef) ( return nil, nil } -func (dm *dataManager) checkValidation(ctx context.Context, ns string, validator core.ValidatorType, datatype *core.DatatypeRef, value *fftypes.JSONAny) error { +func (dm *dataManager) checkValidation(ctx context.Context, validator core.ValidatorType, datatype *core.DatatypeRef, value *fftypes.JSONAny) error { if validator == "" { validator = core.ValidatorTypeJSON } @@ -401,7 +401,7 @@ func (dm *dataManager) checkValidation(ctx context.Context, ns string, validator return i18n.NewError(ctx, coremsgs.MsgDatatypeNotFound, datatype) } if validator != core.ValidatorTypeNone { - v, err := dm.getValidatorForDatatype(ctx, ns, validator, datatype) + v, err := dm.getValidatorForDatatype(ctx, validator, datatype) if err != nil { return err } @@ -417,14 +417,14 @@ func (dm *dataManager) checkValidation(ctx context.Context, ns string, validator return nil } -func (dm *dataManager) validateInputData(ctx context.Context, ns string, inData *core.DataRefOrValue) (data *core.Data, err error) { +func (dm *dataManager) validateInputData(ctx context.Context, inData *core.DataRefOrValue) (data *core.Data, err error) { validator := inData.Validator datatype := inData.Datatype value := inData.Value blobRef := inData.Blob - if err := dm.checkValidation(ctx, ns, validator, datatype, value); err != nil { + if err := dm.checkValidation(ctx, validator, datatype, value); err != nil { return nil, err } @@ -437,7 +437,7 @@ func (dm *dataManager) validateInputData(ctx context.Context, ns string, inData data = &core.Data{ Validator: validator, Datatype: datatype, - Namespace: ns, + Namespace: dm.namespace, Value: value, Blob: blobRef, } @@ -448,8 +448,8 @@ func (dm *dataManager) validateInputData(ctx context.Context, ns string, inData return data, nil } -func (dm *dataManager) UploadJSON(ctx context.Context, ns string, inData *core.DataRefOrValue) (*core.Data, error) { - data, err := dm.validateInputData(ctx, ns, inData) +func (dm *dataManager) UploadJSON(ctx context.Context, inData *core.DataRefOrValue) (*core.Data, error) { + data, err := dm.validateInputData(ctx, inData) if err != nil { return nil, err } @@ -470,7 +470,6 @@ func (dm *dataManager) ResolveInlineData(ctx context.Context, newMessage *NewMes } inData := newMessage.Message.InlineData - msg := newMessage.Message newMessage.AllData = make(core.DataArray, len(newMessage.Message.InlineData)) for i, dataOrValue := range inData { var d *core.Data @@ -489,7 +488,7 @@ func (dm *dataManager) ResolveInlineData(ctx context.Context, newMessage *NewMes } case dataOrValue.Value != nil || dataOrValue.Blob != nil: // We've got a Value, so we can validate + store it - if d, err = dm.validateInputData(ctx, msg.Header.Namespace, dataOrValue); err != nil { + if d, err = dm.validateInputData(ctx, dataOrValue); err != nil { return err } newMessage.NewData = append(newMessage.NewData, d) diff --git a/internal/data/data_manager_test.go b/internal/data/data_manager_test.go index 6fad794e6..d2376a2fa 100644 --- a/internal/data/data_manager_test.go +++ b/internal/data/data_manager_test.go @@ -105,7 +105,7 @@ func TestValidateE2E(t *testing.T) { assert.Regexp(t, "FF10198", err) assert.False(t, isValid) - v, err := dm.getValidatorForDatatype(ctx, data.Namespace, data.Validator, data.Datatype) + v, err := dm.getValidatorForDatatype(ctx, data.Validator, data.Datatype) err = v.Validate(ctx, data) assert.Regexp(t, "FF10198", err) @@ -141,7 +141,7 @@ func TestWriteNewMessageE2E(t *testing.T) { }).Return(nil) mdi.On("InsertDataArray", mock.Anything, mock.Anything).Return(nil).Once() - data1, err := dm.UploadJSON(ctx, "ns1", &core.DataRefOrValue{ + data1, err := dm.UploadJSON(ctx, &core.DataRefOrValue{ Value: fftypes.JSONAnyPtr(`"message 1 - data A"`), Validator: core.ValidatorTypeJSON, Datatype: &core.DatatypeRef{ @@ -235,11 +235,11 @@ func TestValidatorLookupCached(t *testing.T) { Namespace: "0.0.1", } mdi.On("GetDatatypeByName", mock.Anything, "ns1", "customer", "0.0.1").Return(dt, nil).Once() - lookup1, err := dm.getValidatorForDatatype(ctx, "ns1", core.ValidatorTypeJSON, ref) + lookup1, err := dm.getValidatorForDatatype(ctx, core.ValidatorTypeJSON, ref) assert.NoError(t, err) assert.Equal(t, "customer", lookup1.(*jsonValidator).datatype.Name) - lookup2, err := dm.getValidatorForDatatype(ctx, "ns1", core.ValidatorTypeJSON, ref) + lookup2, err := dm.getValidatorForDatatype(ctx, core.ValidatorTypeJSON, ref) assert.NoError(t, err) assert.Equal(t, lookup1, lookup2) @@ -378,7 +378,7 @@ func TestCheckDatatypeVerifiesTheSchema(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() - err := dm.CheckDatatype(ctx, "ns1", &core.Datatype{}) + err := dm.CheckDatatype(ctx, &core.Datatype{}) assert.Regexp(t, "FF10196", err) } @@ -611,7 +611,7 @@ func TestUploadJSONLoadDatatypeFail(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) mdi.On("GetDatatypeByName", ctx, "ns1", "customer", "0.0.1").Return(nil, fmt.Errorf("pop")) - _, err := dm.UploadJSON(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.UploadJSON(ctx, &core.DataRefOrValue{ Datatype: &core.DatatypeRef{ Name: "customer", Version: "0.0.1", @@ -624,7 +624,7 @@ func TestUploadJSONLoadInsertDataFail(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() dm.messageWriter.close() - _, err := dm.UploadJSON(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.UploadJSON(ctx, &core.DataRefOrValue{ Value: fftypes.JSONAnyPtr(`{}`), }) assert.Regexp(t, "FF00154", err) @@ -634,7 +634,7 @@ func TestValidateAndStoreLoadNilRef(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Validator: core.ValidatorTypeJSON, Datatype: nil, }) @@ -647,7 +647,7 @@ func TestValidateAndStoreLoadValidatorUnknown(t *testing.T) { defer cancel() mdi := dm.database.(*databasemocks.Plugin) mdi.On("GetDatatypeByName", mock.Anything, "ns1", "customer", "0.0.1").Return(nil, nil) - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Validator: "wrong!", Datatype: &core.DatatypeRef{ Name: "customer", @@ -664,7 +664,7 @@ func TestValidateAndStoreLoadBadRef(t *testing.T) { defer cancel() mdi := dm.database.(*databasemocks.Plugin) mdi.On("GetDatatypeByName", mock.Anything, "ns1", "customer", "0.0.1").Return(nil, nil) - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Datatype: &core.DatatypeRef{ // Missing name }, @@ -678,7 +678,7 @@ func TestValidateAndStoreNotFound(t *testing.T) { defer cancel() mdi := dm.database.(*databasemocks.Plugin) mdi.On("GetDatatypeByName", mock.Anything, "ns1", "customer", "0.0.1").Return(nil, nil) - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Datatype: &core.DatatypeRef{ Name: "customer", Version: "0.0.1", @@ -694,7 +694,7 @@ func TestValidateAndStoreBlobError(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) blobHash := fftypes.NewRandB32() mdi.On("GetBlobMatchingHash", mock.Anything, blobHash).Return(nil, fmt.Errorf("pop")) - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Blob: &core.BlobRef{ Hash: blobHash, }, @@ -709,7 +709,7 @@ func TestValidateAndStoreBlobNotFound(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) blobHash := fftypes.NewRandB32() mdi.On("GetBlobMatchingHash", mock.Anything, blobHash).Return(nil, nil) - _, err := dm.validateInputData(ctx, "ns1", &core.DataRefOrValue{ + _, err := dm.validateInputData(ctx, &core.DataRefOrValue{ Blob: &core.BlobRef{ Hash: blobHash, }, @@ -742,7 +742,7 @@ func TestGetValidatorForDatatypeNilRef(t *testing.T) { dm, ctx, cancel := newTestDataManager(t) defer cancel() - v, err := dm.getValidatorForDatatype(ctx, "", "", nil) + v, err := dm.getValidatorForDatatype(ctx, "", nil) assert.Nil(t, v) assert.NoError(t, err) diff --git a/internal/definitions/definition_handler_datatype.go b/internal/definitions/definition_handler_datatype.go index 5da9820e4..0f6c79a00 100644 --- a/internal/definitions/definition_handler_datatype.go +++ b/internal/definitions/definition_handler_datatype.go @@ -34,7 +34,7 @@ func (dh *definitionHandlers) handleDatatypeBroadcast(ctx context.Context, state if err := dt.Validate(ctx, true); err != nil { return HandlerResult{Action: ActionReject}, i18n.NewError(ctx, coremsgs.MsgDefRejectedValidateFail, "datatype", dt.ID, err) } - if err := dh.data.CheckDatatype(ctx, dt.Namespace, &dt); err != nil { + if err := dh.data.CheckDatatype(ctx, &dt); err != nil { return HandlerResult{Action: ActionReject}, i18n.NewError(ctx, coremsgs.MsgDefRejectedSchemaFail, "datatype", dt.ID, err) } diff --git a/internal/definitions/definition_handler_datatype_test.go b/internal/definitions/definition_handler_datatype_test.go index cdbb4f132..a901bde15 100644 --- a/internal/definitions/definition_handler_datatype_test.go +++ b/internal/definitions/definition_handler_datatype_test.go @@ -49,7 +49,7 @@ func TestHandleDefinitionBroadcastDatatypeOk(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mbi := dh.database.(*databasemocks.Plugin) mbi.On("GetDatatypeByName", mock.Anything, "ns1", "name1", "ver1").Return(nil, nil) mbi.On("UpsertDatatype", mock.Anything, mock.Anything, false).Return(nil) @@ -87,7 +87,7 @@ func TestHandleDefinitionBroadcastDatatypeEventFail(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mbi := dh.database.(*databasemocks.Plugin) mbi.On("GetDatatypeByName", mock.Anything, "ns1", "name1", "ver1").Return(nil, nil) mbi.On("UpsertDatatype", mock.Anything, mock.Anything, false).Return(nil) @@ -152,7 +152,7 @@ func TestHandleDefinitionBroadcastBadSchema(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(fmt.Errorf("pop")) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) action, err := dh.HandleDefinitionBroadcast(context.Background(), bs, &core.Message{ Header: core.MessageHeader{ Tag: core.SystemTagDefineDatatype, @@ -207,7 +207,7 @@ func TestHandleDefinitionBroadcastDatatypeLookupFail(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mbi := dh.database.(*databasemocks.Plugin) mbi.On("GetDatatypeByName", mock.Anything, "ns1", "name1", "ver1").Return(nil, fmt.Errorf("pop")) action, err := dh.HandleDefinitionBroadcast(context.Background(), bs, &core.Message{ @@ -243,7 +243,7 @@ func TestHandleDefinitionBroadcastUpsertFail(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mbi := dh.database.(*databasemocks.Plugin) mbi.On("GetDatatypeByName", mock.Anything, "ns1", "name1", "ver1").Return(nil, nil) mbi.On("UpsertDatatype", mock.Anything, mock.Anything, false).Return(fmt.Errorf("pop")) @@ -279,7 +279,7 @@ func TestHandleDefinitionBroadcastDatatypeDuplicate(t *testing.T) { } mdm := dh.data.(*datamocks.Manager) - mdm.On("CheckDatatype", mock.Anything, "ns1", mock.Anything).Return(nil) + mdm.On("CheckDatatype", mock.Anything, mock.Anything).Return(nil) mbi := dh.database.(*databasemocks.Plugin) mbi.On("GetDatatypeByName", mock.Anything, "ns1", "name1", "ver1").Return(dt, nil) action, err := dh.HandleDefinitionBroadcast(context.Background(), bs, &core.Message{ diff --git a/internal/events/batch_pin_complete.go b/internal/events/batch_pin_complete.go index 8ff025762..7f16ab7bd 100644 --- a/internal/events/batch_pin_complete.go +++ b/internal/events/batch_pin_complete.go @@ -83,7 +83,7 @@ func (em *eventManager) BatchPinComplete(batchPin *blockchain.BatchPin, signingK } // Kick off a download for broadcast batches if the batch isn't already persisted if !private && batch == nil { - if err := em.sharedDownload.InitiateDownloadBatch(ctx, batchPin.Namespace, batchPin.TransactionID, batchPin.BatchPayloadRef); err != nil { + if err := em.sharedDownload.InitiateDownloadBatch(ctx, batchPin.TransactionID, batchPin.BatchPayloadRef); err != nil { return err } } diff --git a/internal/events/batch_pin_complete_test.go b/internal/events/batch_pin_complete_test.go index b2513c429..7ded6ccb8 100644 --- a/internal/events/batch_pin_complete_test.go +++ b/internal/events/batch_pin_complete_test.go @@ -135,7 +135,7 @@ func TestBatchPinCompleteOkBroadcast(t *testing.T) { mdi.On("InsertPins", mock.Anything, mock.Anything).Return(nil).Once() msd := em.sharedDownload.(*shareddownloadmocks.Manager) mdi.On("GetBatchByID", mock.Anything, "ns1", mock.Anything).Return(nil, nil) - msd.On("InitiateDownloadBatch", mock.Anything, "ns1", batchPin.TransactionID, batchPin.BatchPayloadRef).Return(nil) + msd.On("InitiateDownloadBatch", mock.Anything, batchPin.TransactionID, batchPin.BatchPayloadRef).Return(nil) err := em.BatchPinComplete(batchPin, &core.VerifierRef{ Type: core.VerifierTypeEthAddress, @@ -358,7 +358,7 @@ func TestSequencedBroadcastInitiateDownloadFail(t *testing.T) { mdi.On("InsertPins", mock.Anything, mock.Anything).Return(nil) mdi.On("GetBatchByID", mock.Anything, "ns1", mock.Anything).Return(nil, nil) msd := em.sharedDownload.(*shareddownloadmocks.Manager) - msd.On("InitiateDownloadBatch", mock.Anything, "ns1", batchPin.TransactionID, batchPin.BatchPayloadRef).Return(fmt.Errorf("pop")) + msd.On("InitiateDownloadBatch", mock.Anything, batchPin.TransactionID, batchPin.BatchPayloadRef).Return(fmt.Errorf("pop")) err := em.BatchPinComplete(batchPin, &core.VerifierRef{ Type: core.VerifierTypeEthAddress, @@ -753,7 +753,7 @@ func TestPersistBatchDataWithPublicInitiateDownload(t *testing.T) { mdi.On("GetBlobMatchingHash", mock.Anything, blob.Hash).Return(nil, nil) msd := em.sharedDownload.(*shareddownloadmocks.Manager) - msd.On("InitiateDownloadBlob", mock.Anything, batch.Namespace, batch.Payload.TX.ID, data.ID, "ref1").Return(nil) + msd.On("InitiateDownloadBlob", mock.Anything, batch.Payload.TX.ID, data.ID, "ref1").Return(nil) valid, err := em.checkAndInitiateBlobDownloads(context.Background(), batch, 0, data) assert.Nil(t, err) @@ -780,7 +780,7 @@ func TestPersistBatchDataWithPublicInitiateDownloadFail(t *testing.T) { mdi.On("GetBlobMatchingHash", mock.Anything, blob.Hash).Return(nil, nil) msd := em.sharedDownload.(*shareddownloadmocks.Manager) - msd.On("InitiateDownloadBlob", mock.Anything, batch.Namespace, batch.Payload.TX.ID, data.ID, "ref1").Return(fmt.Errorf("pop")) + msd.On("InitiateDownloadBlob", mock.Anything, batch.Payload.TX.ID, data.ID, "ref1").Return(fmt.Errorf("pop")) valid, err := em.checkAndInitiateBlobDownloads(context.Background(), batch, 0, data) assert.Regexp(t, "pop", err) diff --git a/internal/events/event_manager.go b/internal/events/event_manager.go index 65be588df..56a7c4001 100644 --- a/internal/events/event_manager.go +++ b/internal/events/event_manager.go @@ -73,7 +73,7 @@ type EventManager interface { DXEvent(dx dataexchange.Plugin, event dataexchange.DXEvent) // Bound sharedstorage callbacks - SharedStorageBatchDownloaded(ss sharedstorage.Plugin, ns, payloadRef string, data []byte) (*fftypes.UUID, error) + SharedStorageBatchDownloaded(ss sharedstorage.Plugin, payloadRef string, data []byte) (*fftypes.UUID, error) SharedStorageBlobDownloaded(ss sharedstorage.Plugin, hash fftypes.Bytes32, size int64, payloadRef string) // Bound token callbacks diff --git a/internal/events/persist_batch.go b/internal/events/persist_batch.go index 1fa6553b6..ca6a1aa16 100644 --- a/internal/events/persist_batch.go +++ b/internal/events/persist_batch.go @@ -173,7 +173,7 @@ func (em *eventManager) checkAndInitiateBlobDownloads(ctx context.Context, batch log.L(ctx).Errorf("Invalid data entry %d id=%s in batch '%s' - missing public blob reference", i, data.ID, batch.ID) return false, nil } - if err = em.sharedDownload.InitiateDownloadBlob(ctx, data.Namespace, batch.Payload.TX.ID, data.ID, data.Blob.Public); err != nil { + if err = em.sharedDownload.InitiateDownloadBlob(ctx, batch.Payload.TX.ID, data.ID, data.Blob.Public); err != nil { return false, err } } diff --git a/internal/events/ss_callbacks.go b/internal/events/ss_callbacks.go index 5e9744814..906fdddf4 100644 --- a/internal/events/ss_callbacks.go +++ b/internal/events/ss_callbacks.go @@ -26,7 +26,7 @@ import ( "github.com/hyperledger/firefly/pkg/sharedstorage" ) -func (em *eventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, ns, payloadRef string, data []byte) (*fftypes.UUID, error) { +func (em *eventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, payloadRef string, data []byte) (*fftypes.UUID, error) { l := log.L(em.ctx) @@ -39,8 +39,8 @@ func (em *eventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, ns } l.Infof("Shared storage batch downloaded from %s '%s' id=%s (len=%d)", ss.Name(), payloadRef, batch.ID, len(data)) - if batch.Namespace != ns { - l.Errorf("Invalid batch '%s'. Namespace in batch '%s' does not match pin namespace '%s'", batch.ID, batch.Namespace, ns) + if batch.Namespace != em.namespace { + l.Errorf("Invalid batch '%s'. Namespace in batch '%s' does not match pin namespace '%s'", batch.ID, batch.Namespace, em.namespace) return nil, nil // This is not retryable. skip this batch } diff --git a/internal/events/ss_callbacks_test.go b/internal/events/ss_callbacks_test.go index 223c6dbb5..86bef1388 100644 --- a/internal/events/ss_callbacks_test.go +++ b/internal/events/ss_callbacks_test.go @@ -51,7 +51,7 @@ func TestSharedStorageBatchDownloadedOk(t *testing.T) { mdm := em.data.(*datamocks.Manager) mdm.On("UpdateMessageCache", mock.Anything, mock.Anything).Return() - bid, err := em.SharedStorageBatchDownloaded(mss, batch.Namespace, "payload1", b) + bid, err := em.SharedStorageBatchDownloaded(mss, "payload1", b) assert.NoError(t, err) assert.Equal(t, batch.ID, bid) @@ -78,7 +78,7 @@ func TestSharedStorageBatchDownloadedPersistFail(t *testing.T) { mdi.On("UpsertBatch", em.ctx, mock.Anything).Return(fmt.Errorf("pop")) mss.On("Name").Return("utdx").Maybe() - _, err := em.SharedStorageBatchDownloaded(mss, batch.Namespace, "payload1", b) + _, err := em.SharedStorageBatchDownloaded(mss, "payload1", b) assert.Regexp(t, "FF00154", err) mdi.AssertExpectations(t) @@ -98,7 +98,8 @@ func TestSharedStorageBatchDownloadedNSMismatch(t *testing.T) { mss := em.sharedstorage.(*sharedstoragemocks.Plugin) mss.On("Name").Return("utdx").Maybe() - _, err := em.SharedStorageBatchDownloaded(mss, "srong", "payload1", b) + em.namespace = "ns2" + _, err := em.SharedStorageBatchDownloaded(mss, "payload1", b) assert.NoError(t, err) mss.AssertExpectations(t) @@ -113,7 +114,7 @@ func TestSharedStorageBatchDownloadedBadData(t *testing.T) { mss := em.sharedstorage.(*sharedstoragemocks.Plugin) mss.On("Name").Return("utdx").Maybe() - _, err := em.SharedStorageBatchDownloaded(mss, "srong", "payload1", []byte("!json")) + _, err := em.SharedStorageBatchDownloaded(mss, "payload1", []byte("!json")) assert.NoError(t, err) mss.AssertExpectations(t) diff --git a/internal/events/token_pool_created.go b/internal/events/token_pool_created.go index 2ba21c771..3ebd5e7a6 100644 --- a/internal/events/token_pool_created.go +++ b/internal/events/token_pool_created.go @@ -195,7 +195,7 @@ func (em *eventManager) TokenPoolCreated(ti tokens.Plugin, pool *tokens.TokenPoo Pool: announcePool, } log.L(em.ctx).Infof("Announcing token pool, id=%s", announcePool.ID) - _, err = em.broadcast.BroadcastTokenPool(em.ctx, announcePool.Namespace, broadcast, false) + _, err = em.broadcast.BroadcastTokenPool(em.ctx, broadcast, false) } } diff --git a/internal/events/token_pool_created_test.go b/internal/events/token_pool_created_test.go index d46ff9369..d42d57c8a 100644 --- a/internal/events/token_pool_created_test.go +++ b/internal/events/token_pool_created_test.go @@ -505,7 +505,7 @@ func TestTokenPoolCreatedAnnounce(t *testing.T) { mdi.On("GetTokenPoolByLocator", em.ctx, "ns1", "erc1155", "123").Return(nil, nil).Times(2) mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(nil, nil, fmt.Errorf("pop")).Once() mdi.On("GetOperations", em.ctx, "ns1", mock.Anything).Return(operations, nil, nil).Once() - mbm.On("BroadcastTokenPool", em.ctx, "ns1", mock.MatchedBy(func(pool *core.TokenPoolAnnouncement) bool { + mbm.On("BroadcastTokenPool", em.ctx, mock.MatchedBy(func(pool *core.TokenPoolAnnouncement) bool { return pool.Pool.Namespace == "ns1" && pool.Pool.Name == "my-pool" && *pool.Pool.ID == *poolID }), false).Return(nil, nil) diff --git a/internal/namespace/manager.go b/internal/namespace/manager.go index 6c93271c4..9822eeaff 100644 --- a/internal/namespace/manager.go +++ b/internal/namespace/manager.go @@ -879,7 +879,7 @@ func (nm *namespaceManager) ResolveOperationByNamespacedID(ctx context.Context, if or == nil { return i18n.NewError(ctx, coremsgs.Msg404NotFound) } - return or.Operations().ResolveOperationByID(ctx, ns, u, op) + return or.Operations().ResolveOperationByID(ctx, u, op) } func (nm *namespaceManager) getEventPlugins(ctx context.Context) (plugins map[string]eventsPlugin, err error) { diff --git a/internal/namespace/manager_test.go b/internal/namespace/manager_test.go index f591ab5b0..a021aa485 100644 --- a/internal/namespace/manager_test.go +++ b/internal/namespace/manager_test.go @@ -1256,7 +1256,7 @@ func TestResolveOperationByNamespacedID(t *testing.T) { opID := fftypes.NewUUID() mo.On("Operations").Return(mom) - mom.On("ResolveOperationByID", context.Background(), "default", opID, mock.Anything).Return(nil) + mom.On("ResolveOperationByID", context.Background(), opID, mock.Anything).Return(nil) err := nm.ResolveOperationByNamespacedID(context.Background(), "default:"+opID.String(), &core.OperationUpdateDTO{}) assert.Nil(t, err) diff --git a/internal/networkmap/register_identity.go b/internal/networkmap/register_identity.go index ac7502989..3423ae7c7 100644 --- a/internal/networkmap/register_identity.go +++ b/internal/networkmap/register_identity.go @@ -113,7 +113,7 @@ func (nm *networkMap) RegisterIdentity(ctx context.Context, dto *core.IdentityCr func (nm *networkMap) sendIdentityRequest(ctx context.Context, identity *core.Identity, claimSigner *core.SignerRef, parentSigner *core.SignerRef) error { // Send the claim - we disable the check on the DID author here, as we are registering the identity so it will not exist - claimMsg, err := nm.broadcast.BroadcastIdentityClaim(ctx, identity.Namespace, &core.IdentityClaim{ + claimMsg, err := nm.broadcast.BroadcastIdentityClaim(ctx, &core.IdentityClaim{ Identity: identity, }, claimSigner, core.SystemTagIdentityClaim, false) if err != nil { @@ -123,7 +123,7 @@ func (nm *networkMap) sendIdentityRequest(ctx context.Context, identity *core.Id // Send the verification if one is required. if parentSigner != nil { - verifyMsg, err := nm.broadcast.BroadcastDefinition(ctx, identity.Namespace, &core.IdentityVerification{ + verifyMsg, err := nm.broadcast.BroadcastDefinition(ctx, &core.IdentityVerification{ Claim: core.MessageRef{ ID: claimMsg.Header.ID, Hash: claimMsg.Hash, diff --git a/internal/networkmap/register_identity_test.go b/internal/networkmap/register_identity_test.go index 90f087611..28dadc722 100644 --- a/internal/networkmap/register_identity_test.go +++ b/internal/networkmap/register_identity_test.go @@ -53,7 +53,6 @@ func TestRegisterIdentityOrgWithParentOk(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" @@ -61,7 +60,6 @@ func TestRegisterIdentityOrgWithParentOk(t *testing.T) { core.SystemTagIdentityClaim, false).Return(mockMsg1, nil) mbm.On("BroadcastDefinition", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityVerification"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x23456" @@ -113,7 +111,6 @@ func TestRegisterIdentityOrgWithParentWaitConfirmOk(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" @@ -121,7 +118,6 @@ func TestRegisterIdentityOrgWithParentWaitConfirmOk(t *testing.T) { core.SystemTagIdentityClaim, false).Return(mockMsg1, nil) mbm.On("BroadcastDefinition", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityVerification"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x23456" @@ -194,7 +190,6 @@ func TestRegisterIdentityCustomWithParentFail(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" @@ -202,7 +197,6 @@ func TestRegisterIdentityCustomWithParentFail(t *testing.T) { core.SystemTagIdentityClaim, false).Return(mockMsg, nil) mbm.On("BroadcastDefinition", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityVerification"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x23456" @@ -259,7 +253,6 @@ func TestRegisterIdentityRootBroadcastFail(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" diff --git a/internal/networkmap/register_node_test.go b/internal/networkmap/register_node_test.go index c6714fc56..484375ae7 100644 --- a/internal/networkmap/register_node_test.go +++ b/internal/networkmap/register_node_test.go @@ -61,7 +61,6 @@ func TestRegisterNodeOk(t *testing.T) { mockMsg := &core.Message{Header: core.MessageHeader{ID: fftypes.NewUUID()}} mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), signerRef, core.SystemTagIdentityClaim, false).Return(mockMsg, nil) diff --git a/internal/networkmap/register_org_test.go b/internal/networkmap/register_org_test.go index b0d6ec54c..f120e09c5 100644 --- a/internal/networkmap/register_org_test.go +++ b/internal/networkmap/register_org_test.go @@ -69,7 +69,6 @@ func TestRegisterNodeOrgOk(t *testing.T) { mockMsg := &core.Message{Header: core.MessageHeader{ID: fftypes.NewUUID()}} mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastIdentityClaim", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityClaim"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" diff --git a/internal/networkmap/update_identity.go b/internal/networkmap/update_identity.go index c48f80f02..6e0d6d741 100644 --- a/internal/networkmap/update_identity.go +++ b/internal/networkmap/update_identity.go @@ -56,7 +56,7 @@ func (nm *networkMap) updateIdentityID(ctx context.Context, id *fftypes.UUID, dt } // Send the update - updateMsg, err := nm.broadcast.BroadcastDefinition(ctx, identity.Namespace, &core.IdentityUpdate{ + updateMsg, err := nm.broadcast.BroadcastDefinition(ctx, &core.IdentityUpdate{ Identity: identity.IdentityBase, Updates: dto.IdentityProfile, }, updateSigner, core.SystemTagIdentityUpdate, waitConfirm) diff --git a/internal/networkmap/update_identity_test.go b/internal/networkmap/update_identity_test.go index 5b949654a..593bb1df6 100644 --- a/internal/networkmap/update_identity_test.go +++ b/internal/networkmap/update_identity_test.go @@ -44,7 +44,6 @@ func TestUpdateIdentityProfileOk(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastDefinition", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityUpdate"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" @@ -78,7 +77,6 @@ func TestUpdateIdentityProfileBroadcastFail(t *testing.T) { mbm := nm.broadcast.(*broadcastmocks.Manager) mbm.On("BroadcastDefinition", nm.ctx, - "ns1", mock.AnythingOfType("*core.IdentityUpdate"), mock.MatchedBy(func(sr *core.SignerRef) bool { return sr.Key == "0x12345" diff --git a/internal/operations/manager.go b/internal/operations/manager.go index a02208609..355e4a15c 100644 --- a/internal/operations/manager.go +++ b/internal/operations/manager.go @@ -41,11 +41,11 @@ type Manager interface { RegisterHandler(ctx context.Context, handler OperationHandler, ops []core.OpType) PrepareOperation(ctx context.Context, op *core.Operation) (*core.PreparedOperation, error) RunOperation(ctx context.Context, op *core.PreparedOperation, options ...RunOperationOption) (fftypes.JSONObject, error) - RetryOperation(ctx context.Context, ns string, opID *fftypes.UUID) (*core.Operation, error) + RetryOperation(ctx context.Context, opID *fftypes.UUID) (*core.Operation, error) AddOrReuseOperation(ctx context.Context, op *core.Operation) error SubmitOperationUpdate(plugin core.Named, update *OperationUpdate) TransferResult(dx dataexchange.Plugin, event dataexchange.DXEvent) - ResolveOperationByID(ctx context.Context, ns string, opID *fftypes.UUID, op *core.OperationUpdateDTO) error + ResolveOperationByID(ctx context.Context, opID *fftypes.UUID, op *core.OperationUpdateDTO) error Start() error WaitStop() } @@ -110,10 +110,10 @@ func (om *operationsManager) RunOperation(ctx context.Context, op *core.Prepared log.L(ctx).Tracef("Operation detail: %+v", op) outputs, complete, err := handler.RunOperation(ctx, op) if err != nil { - om.writeOperationFailure(ctx, op.Namespace, op.ID, outputs, err, failState) + om.writeOperationFailure(ctx, op.ID, outputs, err, failState) return nil, err } else if complete { - om.writeOperationSuccess(ctx, op.Namespace, op.ID, outputs) + om.writeOperationSuccess(ctx, op.ID, outputs) } return outputs, nil } @@ -129,7 +129,7 @@ func (om *operationsManager) findLatestRetry(ctx context.Context, opID *fftypes. return om.findLatestRetry(ctx, op.Retry) } -func (om *operationsManager) RetryOperation(ctx context.Context, ns string, opID *fftypes.UUID) (op *core.Operation, err error) { +func (om *operationsManager) RetryOperation(ctx context.Context, opID *fftypes.UUID) (op *core.Operation, err error) { var po *core.PreparedOperation err = om.database.RunAsGroup(ctx, func(ctx context.Context) error { op, err = om.findLatestRetry(ctx, opID) @@ -150,7 +150,7 @@ func (om *operationsManager) RetryOperation(ctx context.Context, ns string, opID // Update the old operation to point to the new one update := database.OperationQueryFactory.NewUpdate(ctx).Set("retry", op.ID) - if err = om.database.UpdateOperation(ctx, ns, opID, update); err != nil { + if err = om.database.UpdateOperation(ctx, om.namespace, opID, update); err != nil { return err } @@ -196,22 +196,22 @@ func (om *operationsManager) TransferResult(dx dataexchange.Plugin, event dataex om.SubmitOperationUpdate(dx, opUpdate) } -func (om *operationsManager) writeOperationSuccess(ctx context.Context, ns string, opID *fftypes.UUID, outputs fftypes.JSONObject) { +func (om *operationsManager) writeOperationSuccess(ctx context.Context, opID *fftypes.UUID, outputs fftypes.JSONObject) { emptyString := "" - if err := om.database.ResolveOperation(ctx, ns, opID, core.OpStatusSucceeded, &emptyString, outputs); err != nil { + if err := om.database.ResolveOperation(ctx, om.namespace, opID, core.OpStatusSucceeded, &emptyString, outputs); err != nil { log.L(ctx).Errorf("Failed to update operation %s: %s", opID, err) } } -func (om *operationsManager) writeOperationFailure(ctx context.Context, ns string, opID *fftypes.UUID, outputs fftypes.JSONObject, err error, newStatus core.OpStatus) { +func (om *operationsManager) writeOperationFailure(ctx context.Context, opID *fftypes.UUID, outputs fftypes.JSONObject, err error, newStatus core.OpStatus) { errMsg := err.Error() - if err := om.database.ResolveOperation(ctx, ns, opID, newStatus, &errMsg, outputs); err != nil { + if err := om.database.ResolveOperation(ctx, om.namespace, opID, newStatus, &errMsg, outputs); err != nil { log.L(ctx).Errorf("Failed to update operation %s: %s", opID, err) } } -func (om *operationsManager) ResolveOperationByID(ctx context.Context, ns string, opID *fftypes.UUID, op *core.OperationUpdateDTO) error { - return om.database.ResolveOperation(ctx, ns, opID, op.Status, op.Error, op.Output) +func (om *operationsManager) ResolveOperationByID(ctx context.Context, opID *fftypes.UUID, op *core.OperationUpdateDTO) error { + return om.database.ResolveOperation(ctx, om.namespace, opID, op.Status, op.Error, op.Output) } func (om *operationsManager) SubmitOperationUpdate(plugin core.Named, update *OperationUpdate) { diff --git a/internal/operations/manager_test.go b/internal/operations/manager_test.go index cafe25cf8..c53bd8db9 100644 --- a/internal/operations/manager_test.go +++ b/internal/operations/manager_test.go @@ -245,7 +245,7 @@ func TestRetryOperationSuccess(t *testing.T) { })).Return(nil) om.RegisterHandler(ctx, &mockHandler{Prepared: po}, []core.OpType{core.OpTypeBlockchainPinBatch}) - newOp, err := om.RetryOperation(ctx, "ns1", op.ID) + newOp, err := om.RetryOperation(ctx, op.ID) assert.NoError(t, err) assert.NotNil(t, newOp) @@ -274,7 +274,7 @@ func TestRetryOperationGetFail(t *testing.T) { mdi.On("GetOperationByID", ctx, "ns1", opID).Return(op, fmt.Errorf("pop")) om.RegisterHandler(ctx, &mockHandler{Prepared: po}, []core.OpType{core.OpTypeBlockchainPinBatch}) - _, err := om.RetryOperation(ctx, "ns1", op.ID) + _, err := om.RetryOperation(ctx, op.ID) assert.EqualError(t, err, "pop") @@ -312,7 +312,7 @@ func TestRetryTwiceOperationInsertFail(t *testing.T) { mdi.On("InsertOperation", ctx, mock.Anything).Return(fmt.Errorf("pop")) om.RegisterHandler(ctx, &mockHandler{Prepared: po}, []core.OpType{core.OpTypeBlockchainPinBatch}) - _, err := om.RetryOperation(ctx, "ns1", op.ID) + _, err := om.RetryOperation(ctx, op.ID) assert.EqualError(t, err, "pop") @@ -341,7 +341,7 @@ func TestRetryOperationInsertFail(t *testing.T) { mdi.On("InsertOperation", ctx, mock.Anything).Return(fmt.Errorf("pop")) om.RegisterHandler(ctx, &mockHandler{Prepared: po}, []core.OpType{core.OpTypeBlockchainPinBatch}) - _, err := om.RetryOperation(ctx, "ns1", op.ID) + _, err := om.RetryOperation(ctx, op.ID) assert.EqualError(t, err, "pop") @@ -372,7 +372,7 @@ func TestRetryOperationUpdateFail(t *testing.T) { mdi.On("UpdateOperation", ctx, "ns1", op.ID, mock.Anything).Return(fmt.Errorf("pop")) om.RegisterHandler(ctx, &mockHandler{Prepared: po}, []core.OpType{core.OpTypeBlockchainPinBatch}) - _, err := om.RetryOperation(ctx, "ns1", op.ID) + _, err := om.RetryOperation(ctx, op.ID) assert.EqualError(t, err, "pop") @@ -389,7 +389,7 @@ func TestWriteOperationSuccess(t *testing.T) { mdi := om.database.(*databasemocks.Plugin) mdi.On("ResolveOperation", ctx, "ns1", opID, core.OpStatusSucceeded, mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - om.writeOperationSuccess(ctx, "ns1", opID, nil) + om.writeOperationSuccess(ctx, opID, nil) mdi.AssertExpectations(t) } @@ -405,7 +405,7 @@ func TestWriteOperationFailure(t *testing.T) { errStr := "pop" mdi.On("ResolveOperation", ctx, "ns1", opID, core.OpStatusFailed, &errStr, mock.Anything).Return(fmt.Errorf("pop")) - om.writeOperationFailure(ctx, "ns1", opID, nil, fmt.Errorf("pop"), core.OpStatusFailed) + om.writeOperationFailure(ctx, opID, nil, fmt.Errorf("pop"), core.OpStatusFailed) mdi.AssertExpectations(t) } @@ -565,7 +565,7 @@ func TestResolveOperationByNamespacedIDOk(t *testing.T) { "my": "data", }).Return(nil) - err := om.ResolveOperationByID(ctx, "ns1", opID, opUpdate) + err := om.ResolveOperationByID(ctx, opID, opUpdate) assert.NoError(t, err) diff --git a/internal/orchestrator/bound_callbacks.go b/internal/orchestrator/bound_callbacks.go index 43fe81ba1..9acaa1701 100644 --- a/internal/orchestrator/bound_callbacks.go +++ b/internal/orchestrator/bound_callbacks.go @@ -89,8 +89,8 @@ func (bc *boundCallbacks) TokensApproved(plugin tokens.Plugin, approval *tokens. return bc.ei.TokensApproved(plugin, approval) } -func (bc *boundCallbacks) SharedStorageBatchDownloaded(ns, payloadRef string, data []byte) (*fftypes.UUID, error) { - return bc.ei.SharedStorageBatchDownloaded(bc.ss, ns, payloadRef, data) +func (bc *boundCallbacks) SharedStorageBatchDownloaded(payloadRef string, data []byte) (*fftypes.UUID, error) { + return bc.ei.SharedStorageBatchDownloaded(bc.ss, payloadRef, data) } func (bc *boundCallbacks) SharedStorageBlobDownloaded(hash fftypes.Bytes32, size int64, payloadRef string) { diff --git a/internal/orchestrator/bound_callbacks_test.go b/internal/orchestrator/bound_callbacks_test.go index dfdf3fc3b..187027a67 100644 --- a/internal/orchestrator/bound_callbacks_test.go +++ b/internal/orchestrator/bound_callbacks_test.go @@ -103,8 +103,8 @@ func TestBoundCallbacks(t *testing.T) { err = bc.BlockchainEvent(&blockchain.EventWithSubscription{}) assert.EqualError(t, err, "pop") - mei.On("SharedStorageBatchDownloaded", mss, "ns1", "payload1", []byte(`{}`)).Return(nil, fmt.Errorf("pop")) - _, err = bc.SharedStorageBatchDownloaded("ns1", "payload1", []byte(`{}`)) + mei.On("SharedStorageBatchDownloaded", mss, "payload1", []byte(`{}`)).Return(nil, fmt.Errorf("pop")) + _, err = bc.SharedStorageBatchDownloaded("payload1", []byte(`{}`)) assert.EqualError(t, err, "pop") mei.On("SharedStorageBlobDownloaded", mss, *hash, int64(12345), "payload1").Return() diff --git a/internal/orchestrator/chart.go b/internal/orchestrator/chart.go index dc47cc1a7..b5b34df13 100644 --- a/internal/orchestrator/chart.go +++ b/internal/orchestrator/chart.go @@ -39,7 +39,7 @@ func (or *orchestrator) getHistogramIntervals(startTime int64, endTime int64, nu return intervals } -func (or *orchestrator) GetChartHistogram(ctx context.Context, ns string, startTime int64, endTime int64, buckets int64, collection database.CollectionName) ([]*core.ChartHistogram, error) { +func (or *orchestrator) GetChartHistogram(ctx context.Context, startTime int64, endTime int64, buckets int64, collection database.CollectionName) ([]*core.ChartHistogram, error) { if buckets > core.ChartHistogramMaxBuckets || buckets < core.ChartHistogramMinBuckets { return nil, i18n.NewError(ctx, coremsgs.MsgInvalidNumberOfIntervals, core.ChartHistogramMinBuckets, core.ChartHistogramMaxBuckets) } @@ -49,7 +49,7 @@ func (or *orchestrator) GetChartHistogram(ctx context.Context, ns string, startT intervals := or.getHistogramIntervals(startTime, endTime, buckets) - histogram, err := or.database().GetChartHistogram(ctx, ns, intervals, collection) + histogram, err := or.database().GetChartHistogram(ctx, or.namespace, intervals, collection) if err != nil { return nil, err } diff --git a/internal/orchestrator/chart_test.go b/internal/orchestrator/chart_test.go index 0dac78071..06e3887bf 100644 --- a/internal/orchestrator/chart_test.go +++ b/internal/orchestrator/chart_test.go @@ -40,27 +40,27 @@ func makeTestIntervals(start int, numIntervals int) (intervals []core.ChartHisto func TestGetHistogramBadIntervalMin(t *testing.T) { or := newTestOrchestrator() - _, err := or.GetChartHistogram(context.Background(), "ns1", 1234567890, 9876543210, core.ChartHistogramMinBuckets-1, database.CollectionName("test")) + _, err := or.GetChartHistogram(context.Background(), 1234567890, 9876543210, core.ChartHistogramMinBuckets-1, database.CollectionName("test")) assert.Regexp(t, "FF10298", err) } func TestGetHistogramBadIntervalMax(t *testing.T) { or := newTestOrchestrator() - _, err := or.GetChartHistogram(context.Background(), "ns1", 1234567890, 9876543210, core.ChartHistogramMaxBuckets+1, database.CollectionName("test")) + _, err := or.GetChartHistogram(context.Background(), 1234567890, 9876543210, core.ChartHistogramMaxBuckets+1, database.CollectionName("test")) assert.Regexp(t, "FF10298", err) } func TestGetHistogramBadStartEndTimes(t *testing.T) { or := newTestOrchestrator() - _, err := or.GetChartHistogram(context.Background(), "ns1", 9876543210, 1234567890, 10, database.CollectionName("test")) + _, err := or.GetChartHistogram(context.Background(), 9876543210, 1234567890, 10, database.CollectionName("test")) assert.Regexp(t, "FF10300", err) } func TestGetHistogramFailDB(t *testing.T) { or := newTestOrchestrator() intervals := makeTestIntervals(1000000000, 10) - or.mdi.On("GetChartHistogram", mock.Anything, "ns1", intervals, database.CollectionName("test")).Return(nil, fmt.Errorf("pop")) - _, err := or.GetChartHistogram(context.Background(), "ns1", 1000000000, 1000000010, 10, database.CollectionName("test")) + or.mdi.On("GetChartHistogram", mock.Anything, "ns", intervals, database.CollectionName("test")).Return(nil, fmt.Errorf("pop")) + _, err := or.GetChartHistogram(context.Background(), 1000000000, 1000000010, 10, database.CollectionName("test")) assert.EqualError(t, err, "pop") } @@ -69,7 +69,7 @@ func TestGetHistogramSuccess(t *testing.T) { intervals := makeTestIntervals(1000000000, 10) mockHistogram := []*core.ChartHistogram{} - or.mdi.On("GetChartHistogram", mock.Anything, "ns1", intervals, database.CollectionName("test")).Return(mockHistogram, nil) - _, err := or.GetChartHistogram(context.Background(), "ns1", 1000000000, 1000000010, 10, database.CollectionName("test")) + or.mdi.On("GetChartHistogram", mock.Anything, "ns", intervals, database.CollectionName("test")).Return(mockHistogram, nil) + _, err := or.GetChartHistogram(context.Background(), 1000000000, 1000000010, 10, database.CollectionName("test")) assert.NoError(t, err) } diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index 7c5d20377..220830f7c 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -108,7 +108,7 @@ type Orchestrator interface { GetPins(ctx context.Context, filter database.AndFilter) ([]*core.Pin, *database.FilterResult, error) // Charts - GetChartHistogram(ctx context.Context, ns string, startTime int64, endTime int64, buckets int64, tableName database.CollectionName) ([]*core.ChartHistogram, error) + GetChartHistogram(ctx context.Context, startTime int64, endTime int64, buckets int64, tableName database.CollectionName) ([]*core.ChartHistogram, error) // Message Routing RequestReply(ctx context.Context, msg *core.MessageInOut) (reply *core.MessageInOut, err error) diff --git a/internal/shareddownload/download_manager.go b/internal/shareddownload/download_manager.go index 1b5b1bfd5..1ba1e7003 100644 --- a/internal/shareddownload/download_manager.go +++ b/internal/shareddownload/download_manager.go @@ -39,8 +39,8 @@ type Manager interface { Start() error WaitStop() - InitiateDownloadBatch(ctx context.Context, ns string, tx *fftypes.UUID, payloadRef string) error - InitiateDownloadBlob(ctx context.Context, ns string, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error + InitiateDownloadBatch(ctx context.Context, tx *fftypes.UUID, payloadRef string) error + InitiateDownloadBlob(ctx context.Context, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error } // downloadManager operates a number of workers that can perform downloads/retries. Each download @@ -75,7 +75,7 @@ type downloadWork struct { } type Callbacks interface { - SharedStorageBatchDownloaded(ns string, payloadRef string, data []byte) (batchID *fftypes.UUID, err error) + SharedStorageBatchDownloaded(payloadRef string, data []byte) (batchID *fftypes.UUID, err error) SharedStorageBlobDownloaded(hash fftypes.Bytes32, size int64, payloadRef string) } @@ -221,16 +221,16 @@ func (dm *downloadManager) waitAndRetryDownload(work *downloadWork) { dm.dispatchWork(work) } -func (dm *downloadManager) InitiateDownloadBatch(ctx context.Context, ns string, tx *fftypes.UUID, payloadRef string) error { - op := core.NewOperation(dm.sharedstorage, ns, tx, core.OpTypeSharedStorageDownloadBatch) - addDownloadBatchInputs(op, ns, payloadRef) - return dm.createAndDispatchOp(ctx, op, opDownloadBatch(op, ns, payloadRef)) +func (dm *downloadManager) InitiateDownloadBatch(ctx context.Context, tx *fftypes.UUID, payloadRef string) error { + op := core.NewOperation(dm.sharedstorage, dm.namespace, tx, core.OpTypeSharedStorageDownloadBatch) + addDownloadBatchInputs(op, payloadRef) + return dm.createAndDispatchOp(ctx, op, opDownloadBatch(op, payloadRef)) } -func (dm *downloadManager) InitiateDownloadBlob(ctx context.Context, ns string, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error { - op := core.NewOperation(dm.sharedstorage, ns, tx, core.OpTypeSharedStorageDownloadBlob) - addDownloadBlobInputs(op, ns, dataID, payloadRef) - return dm.createAndDispatchOp(ctx, op, opDownloadBlob(op, ns, dataID, payloadRef)) +func (dm *downloadManager) InitiateDownloadBlob(ctx context.Context, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error { + op := core.NewOperation(dm.sharedstorage, dm.namespace, tx, core.OpTypeSharedStorageDownloadBlob) + addDownloadBlobInputs(op, dataID, payloadRef) + return dm.createAndDispatchOp(ctx, op, opDownloadBlob(op, dataID, payloadRef)) } func (dm *downloadManager) createAndDispatchOp(ctx context.Context, op *core.Operation, preparedOp *core.PreparedOperation) error { diff --git a/internal/shareddownload/download_manager_test.go b/internal/shareddownload/download_manager_test.go index c980dc645..407f6a87e 100644 --- a/internal/shareddownload/download_manager_test.go +++ b/internal/shareddownload/download_manager_test.go @@ -97,9 +97,9 @@ func TestDownloadBatchE2EOk(t *testing.T) { }).Return(nil) mci := dm.callbacks.(*shareddownloadmocks.Callbacks) - mci.On("SharedStorageBatchDownloaded", "ns1", "ref1", []byte("some batch data")).Return(batchID, nil) + mci.On("SharedStorageBatchDownloaded", "ref1", []byte("some batch data")).Return(batchID, nil) - err := dm.InitiateDownloadBatch(dm.ctx, "ns1", txID, "ref1") + err := dm.InitiateDownloadBatch(dm.ctx, txID, "ref1") assert.NoError(t, err) <-called @@ -151,7 +151,7 @@ func TestDownloadBlobWithRetryOk(t *testing.T) { mci := dm.callbacks.(*shareddownloadmocks.Callbacks) mci.On("SharedStorageBlobDownloaded", *blobHash, int64(12345), "privateRef1").Return() - err := dm.InitiateDownloadBlob(dm.ctx, "ns1", txID, dataID, "ref1") + err := dm.InitiateDownloadBlob(dm.ctx, txID, dataID, "ref1") assert.NoError(t, err) <-called @@ -177,7 +177,7 @@ func TestDownloadBlobInsertOpFail(t *testing.T) { mdi := dm.database.(*databasemocks.Plugin) mdi.On("InsertOperation", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("pop")) - err := dm.InitiateDownloadBlob(dm.ctx, "ns1", txID, dataID, "ref1") + err := dm.InitiateDownloadBlob(dm.ctx, txID, dataID, "ref1") assert.Regexp(t, "pop", err) mdi.AssertExpectations(t) @@ -253,7 +253,7 @@ func TestDownloadManagerStartupRecoveryCombinations(t *testing.T) { }).Return(nil) mci := dm.callbacks.(*shareddownloadmocks.Callbacks) - mci.On("SharedStorageBatchDownloaded", "ns1", "ref2", []byte("some batch data")).Return(batchID, nil) + mci.On("SharedStorageBatchDownloaded", "ref2", []byte("some batch data")).Return(batchID, nil) err := dm.Start() assert.NoError(t, err) diff --git a/internal/shareddownload/operations.go b/internal/shareddownload/operations.go index 6eb8bfe88..8bf3c4654 100644 --- a/internal/shareddownload/operations.go +++ b/internal/shareddownload/operations.go @@ -31,19 +31,16 @@ import ( ) type downloadBatchData struct { - Namespace string `json:"namespace"` PayloadRef string `json:"payloadRef"` } type downloadBlobData struct { - Namespace string `json:"namespace"` DataID *fftypes.UUID `json:"dataId"` PayloadRef string `json:"payloadRef"` } -func addDownloadBatchInputs(op *core.Operation, ns, payloadRef string) { +func addDownloadBatchInputs(op *core.Operation, payloadRef string) { op.Input = fftypes.JSONObject{ - "namespace": ns, "payloadRef": payloadRef, } } @@ -54,9 +51,8 @@ func getDownloadBatchOutputs(batchID *fftypes.UUID) fftypes.JSONObject { } } -func addDownloadBlobInputs(op *core.Operation, ns string, dataID *fftypes.UUID, payloadRef string) { +func addDownloadBlobInputs(op *core.Operation, dataID *fftypes.UUID, payloadRef string) { op.Input = fftypes.JSONObject{ - "namespace": ns, "dataId": dataID.String(), "payloadRef": payloadRef, } @@ -70,16 +66,14 @@ func getDownloadBlobOutputs(hash *fftypes.Bytes32, size int64, dxPayloadRef stri } } -func retrieveDownloadBatchInputs(op *core.Operation) (string, string) { - return op.Input.GetString("namespace"), - op.Input.GetString("payloadRef") +func retrieveDownloadBatchInputs(op *core.Operation) (payloadRef string) { + return op.Input.GetString("payloadRef") } -func retrieveDownloadBlobInputs(ctx context.Context, op *core.Operation) (namespace string, dataID *fftypes.UUID, payloadRef string, err error) { - namespace = op.Input.GetString("namespace") +func retrieveDownloadBlobInputs(ctx context.Context, op *core.Operation) (dataID *fftypes.UUID, payloadRef string, err error) { dataID, err = fftypes.ParseUUID(ctx, op.Input.GetString("dataId")) if err != nil { - return "", nil, "", err + return nil, "", err } payloadRef = op.Input.GetString("payloadRef") return @@ -89,15 +83,15 @@ func (dm *downloadManager) PrepareOperation(ctx context.Context, op *core.Operat switch op.Type { case core.OpTypeSharedStorageDownloadBatch: - namespace, payloadRef := retrieveDownloadBatchInputs(op) - return opDownloadBatch(op, namespace, payloadRef), nil + payloadRef := retrieveDownloadBatchInputs(op) + return opDownloadBatch(op, payloadRef), nil case core.OpTypeSharedStorageDownloadBlob: - namespace, dataID, payloadRef, err := retrieveDownloadBlobInputs(ctx, op) + dataID, payloadRef, err := retrieveDownloadBlobInputs(ctx, op) if err != nil { return nil, err } - return opDownloadBlob(op, namespace, dataID, payloadRef), nil + return opDownloadBlob(op, dataID, payloadRef), nil default: return nil, i18n.NewError(ctx, coremsgs.MsgOperationNotSupported, op.Type) @@ -138,7 +132,7 @@ func (dm *downloadManager) downloadBatch(ctx context.Context, data downloadBatch } // Parse and store the batch - batchID, err := dm.callbacks.SharedStorageBatchDownloaded(data.Namespace, data.PayloadRef, batchBytes) + batchID, err := dm.callbacks.SharedStorageBatchDownloaded(data.PayloadRef, batchBytes) if err != nil { return nil, false, err } @@ -155,7 +149,7 @@ func (dm *downloadManager) downloadBlob(ctx context.Context, data downloadBlobDa defer reader.Close() // ... to data exchange - dxPayloadRef, hash, blobSize, err := dm.dataexchange.UploadBlob(ctx, data.Namespace, *data.DataID, reader) + dxPayloadRef, hash, blobSize, err := dm.dataexchange.UploadBlob(ctx, dm.namespace, *data.DataID, reader) if err != nil { return nil, false, i18n.WrapError(ctx, err, coremsgs.MsgDownloadSharedFailed, data.PayloadRef) } @@ -171,25 +165,23 @@ func (dm *downloadManager) OnOperationUpdate(ctx context.Context, op *core.Opera return nil } -func opDownloadBatch(op *core.Operation, ns string, payloadRef string) *core.PreparedOperation { +func opDownloadBatch(op *core.Operation, payloadRef string) *core.PreparedOperation { return &core.PreparedOperation{ ID: op.ID, Namespace: op.Namespace, Type: op.Type, Data: downloadBatchData{ - Namespace: ns, PayloadRef: payloadRef, }, } } -func opDownloadBlob(op *core.Operation, ns string, dataID *fftypes.UUID, payloadRef string) *core.PreparedOperation { +func opDownloadBlob(op *core.Operation, dataID *fftypes.UUID, payloadRef string) *core.PreparedOperation { return &core.PreparedOperation{ ID: op.ID, Namespace: op.Namespace, Type: op.Type, Data: downloadBlobData{ - Namespace: ns, DataID: dataID, PayloadRef: payloadRef, }, diff --git a/internal/shareddownload/operations_test.go b/internal/shareddownload/operations_test.go index 793a5bba4..f6c71e1d6 100644 --- a/internal/shareddownload/operations_test.go +++ b/internal/shareddownload/operations_test.go @@ -42,7 +42,6 @@ func TestDownloadBatchDownloadDataFail(t *testing.T) { mss.On("DownloadData", mock.Anything, "ref1").Return(nil, fmt.Errorf("pop")) _, _, err := dm.downloadBatch(dm.ctx, downloadBatchData{ - Namespace: "ns1", PayloadRef: "ref1", }) assert.Regexp(t, "FF10376", err) @@ -61,7 +60,6 @@ func TestDownloadBatchDownloadDataReadFail(t *testing.T) { mss.On("DownloadData", mock.Anything, "ref1").Return(reader, nil) _, _, err := dm.downloadBatch(dm.ctx, downloadBatchData{ - Namespace: "ns1", PayloadRef: "ref1", }) assert.Regexp(t, "FF10376", err) @@ -81,7 +79,6 @@ func TestDownloadBatchDownloadDataReadMaxedOut(t *testing.T) { mss.On("DownloadData", mock.Anything, "ref1").Return(reader, nil) _, _, err := dm.downloadBatch(dm.ctx, downloadBatchData{ - Namespace: "ns1", PayloadRef: "ref1", }) assert.Regexp(t, "FF10377", err) @@ -100,10 +97,9 @@ func TestDownloadBatchDownloadCallbackFailed(t *testing.T) { mss.On("DownloadData", mock.Anything, "ref1").Return(reader, nil) mci := dm.callbacks.(*shareddownloadmocks.Callbacks) - mci.On("SharedStorageBatchDownloaded", "ns1", "ref1", []byte("some batch data")).Return(nil, fmt.Errorf("pop")) + mci.On("SharedStorageBatchDownloaded", "ref1", []byte("some batch data")).Return(nil, fmt.Errorf("pop")) _, _, err := dm.downloadBatch(dm.ctx, downloadBatchData{ - Namespace: "ns1", PayloadRef: "ref1", }) assert.Regexp(t, "pop", err) @@ -126,7 +122,6 @@ func TestDownloadBlobDownloadDataReadFail(t *testing.T) { mdx.On("UploadBlob", mock.Anything, "ns1", mock.Anything, reader).Return("", nil, int64(-1), fmt.Errorf("pop")) _, _, err := dm.downloadBlob(dm.ctx, downloadBlobData{ - Namespace: "ns1", PayloadRef: "ref1", DataID: fftypes.NewUUID(), }) diff --git a/mocks/blockchainmocks/plugin.go b/mocks/blockchainmocks/plugin.go index bc6cf7a69..e90c53882 100644 --- a/mocks/blockchainmocks/plugin.go +++ b/mocks/blockchainmocks/plugin.go @@ -312,11 +312,6 @@ func (_m *Plugin) QueryContract(ctx context.Context, location *fftypes.JSONAny, return r0, r1 } -// SetHandler provides a mock function with given fields: handler -func (_m *Plugin) SetHandler(handler blockchain.Callbacks) { - _m.Called(handler) -} - // RemoveFireflySubscription provides a mock function with given fields: ctx, subID func (_m *Plugin) RemoveFireflySubscription(ctx context.Context, subID string) error { ret := _m.Called(ctx, subID) @@ -331,6 +326,11 @@ func (_m *Plugin) RemoveFireflySubscription(ctx context.Context, subID string) e return r0 } +// SetHandler provides a mock function with given fields: handler +func (_m *Plugin) SetHandler(handler blockchain.Callbacks) { + _m.Called(handler) +} + // Start provides a mock function with given fields: func (_m *Plugin) Start() error { ret := _m.Called() diff --git a/mocks/broadcastmocks/manager.go b/mocks/broadcastmocks/manager.go index 94d171d20..522802d6e 100644 --- a/mocks/broadcastmocks/manager.go +++ b/mocks/broadcastmocks/manager.go @@ -18,13 +18,13 @@ type Manager struct { mock.Mock } -// BroadcastDatatype provides a mock function with given fields: ctx, ns, datatype, waitConfirm -func (_m *Manager) BroadcastDatatype(ctx context.Context, ns string, datatype *core.Datatype, waitConfirm bool) (*core.Message, error) { - ret := _m.Called(ctx, ns, datatype, waitConfirm) +// BroadcastDatatype provides a mock function with given fields: ctx, datatype, waitConfirm +func (_m *Manager) BroadcastDatatype(ctx context.Context, datatype *core.Datatype, waitConfirm bool) (*core.Message, error) { + ret := _m.Called(ctx, datatype, waitConfirm) var r0 *core.Message - if rf, ok := ret.Get(0).(func(context.Context, string, *core.Datatype, bool) *core.Message); ok { - r0 = rf(ctx, ns, datatype, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.Datatype, bool) *core.Message); ok { + r0 = rf(ctx, datatype, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Message) @@ -32,8 +32,8 @@ func (_m *Manager) BroadcastDatatype(ctx context.Context, ns string, datatype *c } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.Datatype, bool) error); ok { - r1 = rf(ctx, ns, datatype, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.Datatype, bool) error); ok { + r1 = rf(ctx, datatype, waitConfirm) } else { r1 = ret.Error(1) } @@ -41,13 +41,13 @@ func (_m *Manager) BroadcastDatatype(ctx context.Context, ns string, datatype *c return r0, r1 } -// BroadcastDefinition provides a mock function with given fields: ctx, ns, def, signingIdentity, tag, waitConfirm -func (_m *Manager) BroadcastDefinition(ctx context.Context, ns string, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { - ret := _m.Called(ctx, ns, def, signingIdentity, tag, waitConfirm) +// BroadcastDefinition provides a mock function with given fields: ctx, def, signingIdentity, tag, waitConfirm +func (_m *Manager) BroadcastDefinition(ctx context.Context, def core.Definition, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { + ret := _m.Called(ctx, def, signingIdentity, tag, waitConfirm) var r0 *core.Message - if rf, ok := ret.Get(0).(func(context.Context, string, core.Definition, *core.SignerRef, string, bool) *core.Message); ok { - r0 = rf(ctx, ns, def, signingIdentity, tag, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, core.Definition, *core.SignerRef, string, bool) *core.Message); ok { + r0 = rf(ctx, def, signingIdentity, tag, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Message) @@ -55,8 +55,8 @@ func (_m *Manager) BroadcastDefinition(ctx context.Context, ns string, def core. } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, core.Definition, *core.SignerRef, string, bool) error); ok { - r1 = rf(ctx, ns, def, signingIdentity, tag, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, core.Definition, *core.SignerRef, string, bool) error); ok { + r1 = rf(ctx, def, signingIdentity, tag, waitConfirm) } else { r1 = ret.Error(1) } @@ -64,13 +64,13 @@ func (_m *Manager) BroadcastDefinition(ctx context.Context, ns string, def core. return r0, r1 } -// BroadcastDefinitionAsNode provides a mock function with given fields: ctx, ns, def, tag, waitConfirm -func (_m *Manager) BroadcastDefinitionAsNode(ctx context.Context, ns string, def core.Definition, tag string, waitConfirm bool) (*core.Message, error) { - ret := _m.Called(ctx, ns, def, tag, waitConfirm) +// BroadcastDefinitionAsNode provides a mock function with given fields: ctx, def, tag, waitConfirm +func (_m *Manager) BroadcastDefinitionAsNode(ctx context.Context, def core.Definition, tag string, waitConfirm bool) (*core.Message, error) { + ret := _m.Called(ctx, def, tag, waitConfirm) var r0 *core.Message - if rf, ok := ret.Get(0).(func(context.Context, string, core.Definition, string, bool) *core.Message); ok { - r0 = rf(ctx, ns, def, tag, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, core.Definition, string, bool) *core.Message); ok { + r0 = rf(ctx, def, tag, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Message) @@ -78,8 +78,8 @@ func (_m *Manager) BroadcastDefinitionAsNode(ctx context.Context, ns string, def } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, core.Definition, string, bool) error); ok { - r1 = rf(ctx, ns, def, tag, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, core.Definition, string, bool) error); ok { + r1 = rf(ctx, def, tag, waitConfirm) } else { r1 = ret.Error(1) } @@ -87,13 +87,13 @@ func (_m *Manager) BroadcastDefinitionAsNode(ctx context.Context, ns string, def return r0, r1 } -// BroadcastIdentityClaim provides a mock function with given fields: ctx, ns, def, signingIdentity, tag, waitConfirm -func (_m *Manager) BroadcastIdentityClaim(ctx context.Context, ns string, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { - ret := _m.Called(ctx, ns, def, signingIdentity, tag, waitConfirm) +// BroadcastIdentityClaim provides a mock function with given fields: ctx, def, signingIdentity, tag, waitConfirm +func (_m *Manager) BroadcastIdentityClaim(ctx context.Context, def *core.IdentityClaim, signingIdentity *core.SignerRef, tag string, waitConfirm bool) (*core.Message, error) { + ret := _m.Called(ctx, def, signingIdentity, tag, waitConfirm) var r0 *core.Message - if rf, ok := ret.Get(0).(func(context.Context, string, *core.IdentityClaim, *core.SignerRef, string, bool) *core.Message); ok { - r0 = rf(ctx, ns, def, signingIdentity, tag, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.IdentityClaim, *core.SignerRef, string, bool) *core.Message); ok { + r0 = rf(ctx, def, signingIdentity, tag, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Message) @@ -101,8 +101,8 @@ func (_m *Manager) BroadcastIdentityClaim(ctx context.Context, ns string, def *c } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.IdentityClaim, *core.SignerRef, string, bool) error); ok { - r1 = rf(ctx, ns, def, signingIdentity, tag, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.IdentityClaim, *core.SignerRef, string, bool) error); ok { + r1 = rf(ctx, def, signingIdentity, tag, waitConfirm) } else { r1 = ret.Error(1) } @@ -133,13 +133,13 @@ func (_m *Manager) BroadcastMessage(ctx context.Context, in *core.MessageInOut, return r0, r1 } -// BroadcastTokenPool provides a mock function with given fields: ctx, ns, pool, waitConfirm -func (_m *Manager) BroadcastTokenPool(ctx context.Context, ns string, pool *core.TokenPoolAnnouncement, waitConfirm bool) (*core.Message, error) { - ret := _m.Called(ctx, ns, pool, waitConfirm) +// BroadcastTokenPool provides a mock function with given fields: ctx, pool, waitConfirm +func (_m *Manager) BroadcastTokenPool(ctx context.Context, pool *core.TokenPoolAnnouncement, waitConfirm bool) (*core.Message, error) { + ret := _m.Called(ctx, pool, waitConfirm) var r0 *core.Message - if rf, ok := ret.Get(0).(func(context.Context, string, *core.TokenPoolAnnouncement, bool) *core.Message); ok { - r0 = rf(ctx, ns, pool, waitConfirm) + if rf, ok := ret.Get(0).(func(context.Context, *core.TokenPoolAnnouncement, bool) *core.Message); ok { + r0 = rf(ctx, pool, waitConfirm) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Message) @@ -147,8 +147,8 @@ func (_m *Manager) BroadcastTokenPool(ctx context.Context, ns string, pool *core } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.TokenPoolAnnouncement, bool) error); ok { - r1 = rf(ctx, ns, pool, waitConfirm) + if rf, ok := ret.Get(1).(func(context.Context, *core.TokenPoolAnnouncement, bool) error); ok { + r1 = rf(ctx, pool, waitConfirm) } else { r1 = ret.Error(1) } diff --git a/mocks/datamocks/manager.go b/mocks/datamocks/manager.go index ea2294476..d74159d3e 100644 --- a/mocks/datamocks/manager.go +++ b/mocks/datamocks/manager.go @@ -22,13 +22,13 @@ type Manager struct { mock.Mock } -// CheckDatatype provides a mock function with given fields: ctx, ns, datatype -func (_m *Manager) CheckDatatype(ctx context.Context, ns string, datatype *core.Datatype) error { - ret := _m.Called(ctx, ns, datatype) +// CheckDatatype provides a mock function with given fields: ctx, datatype +func (_m *Manager) CheckDatatype(ctx context.Context, datatype *core.Datatype) error { + ret := _m.Called(ctx, datatype) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *core.Datatype) error); ok { - r0 = rf(ctx, ns, datatype) + if rf, ok := ret.Get(0).(func(context.Context, *core.Datatype) error); ok { + r0 = rf(ctx, datatype) } else { r0 = ret.Error(0) } @@ -36,13 +36,13 @@ func (_m *Manager) CheckDatatype(ctx context.Context, ns string, datatype *core. return r0 } -// DownloadBlob provides a mock function with given fields: ctx, ns, dataID -func (_m *Manager) DownloadBlob(ctx context.Context, ns string, dataID string) (*core.Blob, io.ReadCloser, error) { - ret := _m.Called(ctx, ns, dataID) +// DownloadBlob provides a mock function with given fields: ctx, dataID +func (_m *Manager) DownloadBlob(ctx context.Context, dataID string) (*core.Blob, io.ReadCloser, error) { + ret := _m.Called(ctx, dataID) var r0 *core.Blob - if rf, ok := ret.Get(0).(func(context.Context, string, string) *core.Blob); ok { - r0 = rf(ctx, ns, dataID) + if rf, ok := ret.Get(0).(func(context.Context, string) *core.Blob); ok { + r0 = rf(ctx, dataID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Blob) @@ -50,8 +50,8 @@ func (_m *Manager) DownloadBlob(ctx context.Context, ns string, dataID string) ( } var r1 io.ReadCloser - if rf, ok := ret.Get(1).(func(context.Context, string, string) io.ReadCloser); ok { - r1 = rf(ctx, ns, dataID) + if rf, ok := ret.Get(1).(func(context.Context, string) io.ReadCloser); ok { + r1 = rf(ctx, dataID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(io.ReadCloser) @@ -59,8 +59,8 @@ func (_m *Manager) DownloadBlob(ctx context.Context, ns string, dataID string) ( } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, string) error); ok { - r2 = rf(ctx, ns, dataID) + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, dataID) } else { r2 = ret.Error(2) } @@ -235,13 +235,13 @@ func (_m *Manager) UpdateMessageStateIfCached(ctx context.Context, id *fftypes.U _m.Called(ctx, id, state, confirmed) } -// UploadBlob provides a mock function with given fields: ctx, ns, inData, blob, autoMeta -func (_m *Manager) UploadBlob(ctx context.Context, ns string, inData *core.DataRefOrValue, blob *ffapi.Multipart, autoMeta bool) (*core.Data, error) { - ret := _m.Called(ctx, ns, inData, blob, autoMeta) +// UploadBlob provides a mock function with given fields: ctx, inData, blob, autoMeta +func (_m *Manager) UploadBlob(ctx context.Context, inData *core.DataRefOrValue, blob *ffapi.Multipart, autoMeta bool) (*core.Data, error) { + ret := _m.Called(ctx, inData, blob, autoMeta) var r0 *core.Data - if rf, ok := ret.Get(0).(func(context.Context, string, *core.DataRefOrValue, *ffapi.Multipart, bool) *core.Data); ok { - r0 = rf(ctx, ns, inData, blob, autoMeta) + if rf, ok := ret.Get(0).(func(context.Context, *core.DataRefOrValue, *ffapi.Multipart, bool) *core.Data); ok { + r0 = rf(ctx, inData, blob, autoMeta) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Data) @@ -249,8 +249,8 @@ func (_m *Manager) UploadBlob(ctx context.Context, ns string, inData *core.DataR } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.DataRefOrValue, *ffapi.Multipart, bool) error); ok { - r1 = rf(ctx, ns, inData, blob, autoMeta) + if rf, ok := ret.Get(1).(func(context.Context, *core.DataRefOrValue, *ffapi.Multipart, bool) error); ok { + r1 = rf(ctx, inData, blob, autoMeta) } else { r1 = ret.Error(1) } @@ -258,13 +258,13 @@ func (_m *Manager) UploadBlob(ctx context.Context, ns string, inData *core.DataR return r0, r1 } -// UploadJSON provides a mock function with given fields: ctx, ns, inData -func (_m *Manager) UploadJSON(ctx context.Context, ns string, inData *core.DataRefOrValue) (*core.Data, error) { - ret := _m.Called(ctx, ns, inData) +// UploadJSON provides a mock function with given fields: ctx, inData +func (_m *Manager) UploadJSON(ctx context.Context, inData *core.DataRefOrValue) (*core.Data, error) { + ret := _m.Called(ctx, inData) var r0 *core.Data - if rf, ok := ret.Get(0).(func(context.Context, string, *core.DataRefOrValue) *core.Data); ok { - r0 = rf(ctx, ns, inData) + if rf, ok := ret.Get(0).(func(context.Context, *core.DataRefOrValue) *core.Data); ok { + r0 = rf(ctx, inData) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Data) @@ -272,8 +272,8 @@ func (_m *Manager) UploadJSON(ctx context.Context, ns string, inData *core.DataR } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *core.DataRefOrValue) error); ok { - r1 = rf(ctx, ns, inData) + if rf, ok := ret.Get(1).(func(context.Context, *core.DataRefOrValue) error); ok { + r1 = rf(ctx, inData) } else { r1 = ret.Error(1) } diff --git a/mocks/eventmocks/event_manager.go b/mocks/eventmocks/event_manager.go index f309d3451..da4e2f52b 100644 --- a/mocks/eventmocks/event_manager.go +++ b/mocks/eventmocks/event_manager.go @@ -196,13 +196,13 @@ func (_m *EventManager) NewSubscriptions() chan<- *fftypes.UUID { return r0 } -// SharedStorageBatchDownloaded provides a mock function with given fields: ss, ns, payloadRef, data -func (_m *EventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, ns string, payloadRef string, data []byte) (*fftypes.UUID, error) { - ret := _m.Called(ss, ns, payloadRef, data) +// SharedStorageBatchDownloaded provides a mock function with given fields: ss, payloadRef, data +func (_m *EventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, payloadRef string, data []byte) (*fftypes.UUID, error) { + ret := _m.Called(ss, payloadRef, data) var r0 *fftypes.UUID - if rf, ok := ret.Get(0).(func(sharedstorage.Plugin, string, string, []byte) *fftypes.UUID); ok { - r0 = rf(ss, ns, payloadRef, data) + if rf, ok := ret.Get(0).(func(sharedstorage.Plugin, string, []byte) *fftypes.UUID); ok { + r0 = rf(ss, payloadRef, data) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*fftypes.UUID) @@ -210,8 +210,8 @@ func (_m *EventManager) SharedStorageBatchDownloaded(ss sharedstorage.Plugin, ns } var r1 error - if rf, ok := ret.Get(1).(func(sharedstorage.Plugin, string, string, []byte) error); ok { - r1 = rf(ss, ns, payloadRef, data) + if rf, ok := ret.Get(1).(func(sharedstorage.Plugin, string, []byte) error); ok { + r1 = rf(ss, payloadRef, data) } else { r1 = ret.Error(1) } diff --git a/mocks/operationmocks/manager.go b/mocks/operationmocks/manager.go index 635d59e68..f0642aab4 100644 --- a/mocks/operationmocks/manager.go +++ b/mocks/operationmocks/manager.go @@ -62,13 +62,13 @@ func (_m *Manager) RegisterHandler(ctx context.Context, handler operations.Opera _m.Called(ctx, handler, ops) } -// ResolveOperationByID provides a mock function with given fields: ctx, ns, opID, op -func (_m *Manager) ResolveOperationByID(ctx context.Context, ns string, opID *fftypes.UUID, op *core.OperationUpdateDTO) error { - ret := _m.Called(ctx, ns, opID, op) +// ResolveOperationByID provides a mock function with given fields: ctx, opID, op +func (_m *Manager) ResolveOperationByID(ctx context.Context, opID *fftypes.UUID, op *core.OperationUpdateDTO) error { + ret := _m.Called(ctx, opID, op) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID, *core.OperationUpdateDTO) error); ok { - r0 = rf(ctx, ns, opID, op) + if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID, *core.OperationUpdateDTO) error); ok { + r0 = rf(ctx, opID, op) } else { r0 = ret.Error(0) } @@ -76,13 +76,13 @@ func (_m *Manager) ResolveOperationByID(ctx context.Context, ns string, opID *ff return r0 } -// RetryOperation provides a mock function with given fields: ctx, ns, opID -func (_m *Manager) RetryOperation(ctx context.Context, ns string, opID *fftypes.UUID) (*core.Operation, error) { - ret := _m.Called(ctx, ns, opID) +// RetryOperation provides a mock function with given fields: ctx, opID +func (_m *Manager) RetryOperation(ctx context.Context, opID *fftypes.UUID) (*core.Operation, error) { + ret := _m.Called(ctx, opID) var r0 *core.Operation - if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID) *core.Operation); ok { - r0 = rf(ctx, ns, opID) + if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID) *core.Operation); ok { + r0 = rf(ctx, opID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*core.Operation) @@ -90,8 +90,8 @@ func (_m *Manager) RetryOperation(ctx context.Context, ns string, opID *fftypes. } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, *fftypes.UUID) error); ok { - r1 = rf(ctx, ns, opID) + if rf, ok := ret.Get(1).(func(context.Context, *fftypes.UUID) error); ok { + r1 = rf(ctx, opID) } else { r1 = ret.Error(1) } diff --git a/mocks/orchestratormocks/orchestrator.go b/mocks/orchestratormocks/orchestrator.go index fce2e6f54..c31c137a2 100644 --- a/mocks/orchestratormocks/orchestrator.go +++ b/mocks/orchestratormocks/orchestrator.go @@ -302,13 +302,13 @@ func (_m *Orchestrator) GetBlockchainEvents(ctx context.Context, filter database return r0, r1, r2 } -// GetChartHistogram provides a mock function with given fields: ctx, ns, startTime, endTime, buckets, tableName -func (_m *Orchestrator) GetChartHistogram(ctx context.Context, ns string, startTime int64, endTime int64, buckets int64, tableName database.CollectionName) ([]*core.ChartHistogram, error) { - ret := _m.Called(ctx, ns, startTime, endTime, buckets, tableName) +// GetChartHistogram provides a mock function with given fields: ctx, startTime, endTime, buckets, tableName +func (_m *Orchestrator) GetChartHistogram(ctx context.Context, startTime int64, endTime int64, buckets int64, tableName database.CollectionName) ([]*core.ChartHistogram, error) { + ret := _m.Called(ctx, startTime, endTime, buckets, tableName) var r0 []*core.ChartHistogram - if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64, int64, database.CollectionName) []*core.ChartHistogram); ok { - r0 = rf(ctx, ns, startTime, endTime, buckets, tableName) + if rf, ok := ret.Get(0).(func(context.Context, int64, int64, int64, database.CollectionName) []*core.ChartHistogram); ok { + r0 = rf(ctx, startTime, endTime, buckets, tableName) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*core.ChartHistogram) @@ -316,8 +316,8 @@ func (_m *Orchestrator) GetChartHistogram(ctx context.Context, ns string, startT } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, int64, int64, int64, database.CollectionName) error); ok { - r1 = rf(ctx, ns, startTime, endTime, buckets, tableName) + if rf, ok := ret.Get(1).(func(context.Context, int64, int64, int64, database.CollectionName) error); ok { + r1 = rf(ctx, startTime, endTime, buckets, tableName) } else { r1 = ret.Error(1) } diff --git a/mocks/shareddownloadmocks/callbacks.go b/mocks/shareddownloadmocks/callbacks.go index d29a924a4..8bb219a38 100644 --- a/mocks/shareddownloadmocks/callbacks.go +++ b/mocks/shareddownloadmocks/callbacks.go @@ -12,13 +12,13 @@ type Callbacks struct { mock.Mock } -// SharedStorageBatchDownloaded provides a mock function with given fields: ns, payloadRef, data -func (_m *Callbacks) SharedStorageBatchDownloaded(ns string, payloadRef string, data []byte) (*fftypes.UUID, error) { - ret := _m.Called(ns, payloadRef, data) +// SharedStorageBatchDownloaded provides a mock function with given fields: payloadRef, data +func (_m *Callbacks) SharedStorageBatchDownloaded(payloadRef string, data []byte) (*fftypes.UUID, error) { + ret := _m.Called(payloadRef, data) var r0 *fftypes.UUID - if rf, ok := ret.Get(0).(func(string, string, []byte) *fftypes.UUID); ok { - r0 = rf(ns, payloadRef, data) + if rf, ok := ret.Get(0).(func(string, []byte) *fftypes.UUID); ok { + r0 = rf(payloadRef, data) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*fftypes.UUID) @@ -26,8 +26,8 @@ func (_m *Callbacks) SharedStorageBatchDownloaded(ns string, payloadRef string, } var r1 error - if rf, ok := ret.Get(1).(func(string, string, []byte) error); ok { - r1 = rf(ns, payloadRef, data) + if rf, ok := ret.Get(1).(func(string, []byte) error); ok { + r1 = rf(payloadRef, data) } else { r1 = ret.Error(1) } diff --git a/mocks/shareddownloadmocks/manager.go b/mocks/shareddownloadmocks/manager.go index 582168806..dc90f998b 100644 --- a/mocks/shareddownloadmocks/manager.go +++ b/mocks/shareddownloadmocks/manager.go @@ -14,13 +14,13 @@ type Manager struct { mock.Mock } -// InitiateDownloadBatch provides a mock function with given fields: ctx, ns, tx, payloadRef -func (_m *Manager) InitiateDownloadBatch(ctx context.Context, ns string, tx *fftypes.UUID, payloadRef string) error { - ret := _m.Called(ctx, ns, tx, payloadRef) +// InitiateDownloadBatch provides a mock function with given fields: ctx, tx, payloadRef +func (_m *Manager) InitiateDownloadBatch(ctx context.Context, tx *fftypes.UUID, payloadRef string) error { + ret := _m.Called(ctx, tx, payloadRef) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID, string) error); ok { - r0 = rf(ctx, ns, tx, payloadRef) + if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID, string) error); ok { + r0 = rf(ctx, tx, payloadRef) } else { r0 = ret.Error(0) } @@ -28,13 +28,13 @@ func (_m *Manager) InitiateDownloadBatch(ctx context.Context, ns string, tx *fft return r0 } -// InitiateDownloadBlob provides a mock function with given fields: ctx, ns, tx, dataID, payloadRef -func (_m *Manager) InitiateDownloadBlob(ctx context.Context, ns string, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error { - ret := _m.Called(ctx, ns, tx, dataID, payloadRef) +// InitiateDownloadBlob provides a mock function with given fields: ctx, tx, dataID, payloadRef +func (_m *Manager) InitiateDownloadBlob(ctx context.Context, tx *fftypes.UUID, dataID *fftypes.UUID, payloadRef string) error { + ret := _m.Called(ctx, tx, dataID, payloadRef) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *fftypes.UUID, *fftypes.UUID, string) error); ok { - r0 = rf(ctx, ns, tx, dataID, payloadRef) + if rf, ok := ret.Get(0).(func(context.Context, *fftypes.UUID, *fftypes.UUID, string) error); ok { + r0 = rf(ctx, tx, dataID, payloadRef) } else { r0 = ret.Error(0) }