diff --git a/hrpc/call.go b/hrpc/call.go index 4cd8d89c..8e84d56e 100644 --- a/hrpc/call.go +++ b/hrpc/call.go @@ -43,6 +43,7 @@ type RegionClient interface { Close() Addr() string QueueRPC(Call) + QueueBatch(context.Context, []Call) String() string } diff --git a/rpc.go b/rpc.go index 6f5d1216..ac6f426f 100644 --- a/rpc.go +++ b/rpc.go @@ -250,14 +250,9 @@ func (c *client) SendBatch(ctx context.Context, batch []hrpc.Call) ( ) wg.Add(len(rpcByClient)) for client, rpcs := range rpcByClient { - // TODO: Move this to the RegionClient interface so we don't - // need to type assert here - qb := client.(interface { - QueueBatch(ctx context.Context, rpcs []hrpc.Call) - }) go func(client hrpc.RegionClient, rpcs []hrpc.Call) { defer wg.Done() - qb.QueueBatch(ctx, rpcs) + client.QueueBatch(ctx, rpcs) ctx, sp := observability.StartSpan(ctx, "waitForResult") defer sp.End() ok := c.waitForCompletion(ctx, client, rpcs, res, rpcToRes) diff --git a/test/mock/region/client.go b/test/mock/region/client.go index 501d0a99..33bb4ce7 100644 --- a/test/mock/region/client.go +++ b/test/mock/region/client.go @@ -75,6 +75,18 @@ func (mr *MockRegionClientMockRecorder) Dial(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockRegionClient)(nil).Dial), arg0) } +// QueueBatch mocks base method. +func (m *MockRegionClient) QueueBatch(arg0 context.Context, arg1 []hrpc.Call) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "QueueBatch", arg0, arg1) +} + +// QueueBatch indicates an expected call of QueueBatch. +func (mr *MockRegionClientMockRecorder) QueueBatch(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueBatch", reflect.TypeOf((*MockRegionClient)(nil).QueueBatch), arg0, arg1) +} + // QueueRPC mocks base method. func (m *MockRegionClient) QueueRPC(arg0 hrpc.Call) { m.ctrl.T.Helper()