From da105f533d71bb162d925a7a785edf395e1d97ce Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Thu, 23 Mar 2023 08:52:25 +0800 Subject: [PATCH 1/4] =?UTF-8?q?merger:=20=E5=88=86=E9=A1=B5=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/merger/limitmerger/merger.go | 130 ++++++ internal/merger/limitmerger/merger_test.go | 502 +++++++++++++++++++++ 2 files changed, 632 insertions(+) create mode 100644 internal/merger/limitmerger/merger.go create mode 100644 internal/merger/limitmerger/merger_test.go diff --git a/internal/merger/limitmerger/merger.go b/internal/merger/limitmerger/merger.go new file mode 100644 index 0000000..a146c1c --- /dev/null +++ b/internal/merger/limitmerger/merger.go @@ -0,0 +1,130 @@ +package limitmerger + +import ( + "context" + "database/sql" + "sync" + + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" +) + +type Merger struct { + m merger.Merger + limit int + offset int +} + +func NewMerger(m merger.Merger, offset int, limit int) *Merger { + return &Merger{ + m: m, + limit: limit, + offset: offset, + } +} + +func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { + rows, err := m.m.Merge(ctx, results) + if err != nil { + return nil, err + } + err = m.nextOffset(ctx, rows) + if err != nil { + return nil, err + } + return &Rows{ + rows: rows, + mu: &sync.RWMutex{}, + limit: m.limit, + }, nil +} + +func (m *Merger) nextOffset(ctx context.Context, rows merger.Rows) error { + offset := m.offset + for i := 0; i < offset; i++ { + if ctx.Err() != nil { + return ctx.Err() + } + // 如果偏移量超过rows结果集返回的行数,不会报错。用户最终查到0行 + if !rows.Next() { + if rows.Err() != nil { + return rows.Err() + } + break + } + } + return nil +} + +type Rows struct { + rows merger.Rows + limit int + cnt int + lastErr error + closed bool + mu *sync.RWMutex +} + +func (r *Rows) Next() bool { + r.mu.Lock() + if r.closed { + r.mu.Unlock() + return false + } + if r.cnt >= r.limit || r.lastErr != nil { + r.mu.Unlock() + _ = r.Close() + return false + } + canNext, err := r.next() + if err != nil { + r.lastErr = err + r.mu.Unlock() + _ = r.Close() + return false + } + if !canNext { + r.mu.Unlock() + _ = r.Close() + return canNext + } + r.cnt++ + r.mu.Unlock() + return canNext +} +func (r *Rows) next() (bool, error) { + if r.rows.Next() { + return true, nil + } + if r.rows.Err() != nil { + return false, r.rows.Err() + } + return false, nil +} + +func (r *Rows) Scan(dest ...any) error { + r.mu.RLock() + defer r.mu.RUnlock() + if r.lastErr != nil { + return r.lastErr + } + if r.closed { + return errs.ErrMergerRowsClosed + } + return r.rows.Scan(dest...) +} + +func (r *Rows) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + r.closed = true + return r.rows.Close() +} + +func (r *Rows) Columns() ([]string, error) { + return r.rows.Columns() +} + +func (r *Rows) Err() error { + return r.lastErr +} diff --git a/internal/merger/limitmerger/merger_test.go b/internal/merger/limitmerger/merger_test.go new file mode 100644 index 0000000..1ca8cc6 --- /dev/null +++ b/internal/merger/limitmerger/merger_test.go @@ -0,0 +1,502 @@ +package limitmerger + +import ( + "context" + "database/sql" + "errors" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" + "github.com/ecodeclub/eorm/internal/merger/sortmerger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/multierr" +) + +var ( + offsetMockErr error = errors.New("rows: MockOffsetErr") + limitMockErr error = errors.New("rows: MockLimitErr") +) + +func newCloseMockErr(dbName string) error { + return fmt.Errorf("rows: %s MockCloseErr", dbName) +} + +type MergerSuite struct { + suite.Suite + mockDB01 *sql.DB + mock01 sqlmock.Sqlmock + mockDB02 *sql.DB + mock02 sqlmock.Sqlmock + mockDB03 *sql.DB + mock03 sqlmock.Sqlmock + mockDB04 *sql.DB + mock04 sqlmock.Sqlmock +} + +func (ms *MergerSuite) SetupTest() { + t := ms.T() + ms.initMock(t) +} + +func (ms *MergerSuite) TearDownTest() { + _ = ms.mockDB01.Close() + _ = ms.mockDB02.Close() + _ = ms.mockDB03.Close() + _ = ms.mockDB04.Close() +} + +func (ms *MergerSuite) initMock(t *testing.T) { + var err error + ms.mockDB01, ms.mock01, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB02, ms.mock02, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB03, ms.mock03, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB04, ms.mock04, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } +} + +func (ms *MergerSuite) TestMerger_Merge() { + testcases := []struct { + name string + getMerger func() (merger.Merger, error) + GetRowsList func() []*sql.Rows + wantErr error + ctx func() (context.Context, context.CancelFunc) + limit int + offset int + }{ + { + name: "limitMerger里的Merger的Merge出错", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + return []*sql.Rows{} + }, + wantErr: errs.ErrMergerEmptyRows, + ctx: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + limit: 0, + offset: 0, + }, + { + name: "Next offset个值时遇到报错", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn").RowError(1, offsetMockErr)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantErr: offsetMockErr, + ctx: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + limit: 10, + offset: 5, + }, + { + name: "offset的值超过返回的数据行数", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + ctx: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + limit: 10, + offset: 10, + }, + { + name: "超时", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(2, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(5, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 0) + }, + wantErr: context.DeadlineExceeded, + limit: 5, + offset: 0, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger, err := tc.getMerger() + limitMerger := NewMerger(merger, tc.offset, tc.limit) + require.NoError(t, err) + ctx, cancel := tc.ctx() + rows, err := limitMerger.Merge(ctx, tc.GetRowsList()) + cancel() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + require.NotNil(t, rows) + + }) + } +} + +func (ms *MergerSuite) TestMerger_NextAndScan() { + testcases := []struct { + name string + getMerger func() (merger.Merger, error) + GetRowsList func() []*sql.Rows + wantVal []TestModel + limit int + offset int + }{ + { + name: "limit的行数超过了返回的总行数,", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantVal: []TestModel{ + { + Id: 2, + Name: "a", + Address: "cn", + }, + { + Id: 3, + Name: "alex", + Address: "cn", + }, + { + Id: 4, + Name: "x", + Address: "cn", + }, + { + Id: 5, + Name: "bruce", + Address: "cn", + }, + { + Id: 7, + Name: "b", + Address: "cn", + }, + }, + limit: 100, + offset: 1, + }, + { + name: "limit 行数小于返回的总行数", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantVal: []TestModel{ + { + Id: 2, + Name: "a", + Address: "cn", + }, + { + Id: 3, + Name: "alex", + Address: "cn", + }, + }, + limit: 2, + offset: 1, + }, + { + name: "offset超过sqlRows列表返回的总行数", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantVal: []TestModel{}, + limit: 2, + offset: 100, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger, err := tc.getMerger() + require.NoError(t, err) + limitMerger := NewMerger(merger, tc.offset, tc.limit) + rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) + require.NoError(t, err) + res := make([]TestModel, 0, len(tc.wantVal)) + for rows.Next() { + var model TestModel + err = rows.Scan(&model.Id, &model.Name, &model.Address) + require.NoError(t, err) + res = append(res, model) + } + require.True(t, rows.(*Rows).closed) + require.NoError(t, rows.Err()) + assert.Equal(t, tc.wantVal, res) + }) + } +} + +func (ms *MergerSuite) TestRows_NextAndErr() { + testcases := []struct { + name string + getMerger func() (merger.Merger, error) + GetRowsList func() []*sql.Rows + wantErr error + limit int + offset int + }{ + { + name: "有sql.Rows返回错误", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn").RowError(1, limitMockErr)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + limit: 10, + offset: 1, + wantErr: limitMockErr, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger, err := tc.getMerger() + require.NoError(t, err) + limitMerger := NewMerger(merger, tc.offset, tc.limit) + rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) + require.NoError(t, err) + for rows.Next() { + } + require.True(t, rows.(*Rows).closed) + assert.Equal(t, tc.wantErr, rows.Err()) + }) + } +} + +func (ms *MergerSuite) TestRows_ScanAndErr() { + ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { + cols := []string{"id"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(5)) + r, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(t, err) + rowsList := []*sql.Rows{r} + merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + require.NoError(t, err) + limitMerger := NewMerger(merger, 0, 1) + rows, err := limitMerger.Merge(context.Background(), rowsList) + require.NoError(t, err) + id := 0 + err = rows.Scan(&id) + require.Error(t, err) + }) + ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { + cols := []string{"id"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1).AddRow(2).RowError(1, limitMockErr)) + r, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(t, err) + rowsList := []*sql.Rows{r} + merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + require.NoError(t, err) + limitMerger := NewMerger(merger, 0, 1) + rows, err := limitMerger.Merge(context.Background(), rowsList) + require.NoError(t, err) + for rows.Next() { + } + id := 0 + err = rows.Scan(&id) + assert.Equal(t, limitMockErr, err) + }) +} + +func (ms *MergerSuite) TestRows_Close() { + cols := []string{"id"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("5").CloseError(newCloseMockErr("db02"))) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) + merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + require.NoError(ms.T(), err) + limitMerger := NewMerger(merger, 1, 6) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + rows, err := limitMerger.Merge(context.Background(), rowsList) + require.NoError(ms.T(), err) + // 判断当前是可以正常读取的 + require.True(ms.T(), rows.Next()) + var id int + err = rows.Scan(&id) + require.NoError(ms.T(), err) + err = rows.Close() + ms.T().Run("close返回error", func(t *testing.T) { + assert.Equal(ms.T(), multierr.Combine(newCloseMockErr("db02"), newCloseMockErr("db03")), err) + }) + ms.T().Run("close之后Next返回false", func(t *testing.T) { + for i := 0; i < len(rowsList); i++ { + require.False(ms.T(), rowsList[i].Next()) + } + require.False(ms.T(), rows.Next()) + }) + ms.T().Run("close之后Scan返回迭代过程中的错误", func(t *testing.T) { + var id int + err := rows.Scan(&id) + assert.Equal(t, errs.ErrMergerRowsClosed, err) + }) + ms.T().Run("close之后调用Columns方法返回错误", func(t *testing.T) { + _, err := rows.Columns() + require.Error(t, err) + }) + ms.T().Run("close多次是等效的", func(t *testing.T) { + for i := 0; i < 4; i++ { + err = rows.Close() + require.NoError(t, err) + } + }) +} + +func (ms *MergerSuite) TestRows_Columns() { + cols := []string{"id"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) + merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + limitMerger := NewMerger(merger, 0, 10) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + rows, err := limitMerger.Merge(context.Background(), rowsList) + columns, err := rows.Columns() + require.NoError(ms.T(), err) + assert.Equal(ms.T(), cols, columns) +} + +func TestMerger(t *testing.T) { + suite.Run(t, &MergerSuite{}) +} + +type TestModel struct { + Id int + Name string + Address string +} From 43d258d616c38d08e4d4741396b89f20fcd878c6 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sat, 25 Mar 2023 09:57:26 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + internal/merger/limitmerger/merger.go | 18 ++++++++++++++++-- internal/merger/limitmerger/merger_test.go | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 01ff2f2..35db516 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -16,6 +16,7 @@ - [eorm: 分库分表: Merger抽象与批量查询实现](https://github.com/ecodeclub/eorm/pull/160) - [eorm: 增强的 ShardingAlgorithm 设计与实现](https://github.com/ecodeclub/eorm/pull/161) - [eorm: 分库分表: Merger排序实现](https://github.com/ecodeclub/eorm/pull/166) +- [eorm: 分库分表: Merger分页实现](https://github.com/ecodeclub/eorm/pull/175) - [eorm: BasicTypeValue重命名](https://github.com/ecodeclub/eorm/pull/177) ## v0.0.1: diff --git a/internal/merger/limitmerger/merger.go b/internal/merger/limitmerger/merger.go index a146c1c..8517205 100644 --- a/internal/merger/limitmerger/merger.go +++ b/internal/merger/limitmerger/merger.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package limitmerger import ( @@ -76,7 +90,7 @@ func (r *Rows) Next() bool { _ = r.Close() return false } - canNext, err := r.next() + canNext, err := r.nextVal() if err != nil { r.lastErr = err r.mu.Unlock() @@ -92,7 +106,7 @@ func (r *Rows) Next() bool { r.mu.Unlock() return canNext } -func (r *Rows) next() (bool, error) { +func (r *Rows) nextVal() (bool, error) { if r.rows.Next() { return true, nil } diff --git a/internal/merger/limitmerger/merger_test.go b/internal/merger/limitmerger/merger_test.go index 1ca8cc6..365bf76 100644 --- a/internal/merger/limitmerger/merger_test.go +++ b/internal/merger/limitmerger/merger_test.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package limitmerger import ( @@ -477,6 +491,7 @@ func (ms *MergerSuite) TestRows_Columns() { ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + require.NoError(ms.T(), err) limitMerger := NewMerger(merger, 0, 10) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]*sql.Rows, 0, len(dbs)) @@ -486,6 +501,7 @@ func (ms *MergerSuite) TestRows_Columns() { rowsList = append(rowsList, row) } rows, err := limitMerger.Merge(context.Background(), rowsList) + require.NoError(ms.T(), err) columns, err := rows.Columns() require.NoError(ms.T(), err) assert.Equal(ms.T(), cols, columns) From a762f13e0be5cb043c884c1989ecfe7c7a681a05 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Fri, 24 Mar 2023 12:35:58 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/merger/internal/errs/error.go | 13 +- internal/merger/limitmerger/merger.go | 16 ++- internal/merger/limitmerger/merger_test.go | 141 +++++++++++++++++++-- 3 files changed, 150 insertions(+), 20 deletions(-) diff --git a/internal/merger/internal/errs/error.go b/internal/merger/internal/errs/error.go index 7d6b475..862df7c 100644 --- a/internal/merger/internal/errs/error.go +++ b/internal/merger/internal/errs/error.go @@ -20,12 +20,13 @@ import ( ) var ( - ErrEmptySortColumns = errors.New("merger: 排序列为空") - ErrMergerEmptyRows = errors.New("merger: sql.Rows列表为空") - ErrMergerRowsIsNull = errors.New("merger: sql.Rows列表中有元素为nil") - ErrMergerScanNotNext = errors.New("merger: Scan之前没有调用Next方法") - ErrMergerRowsClosed = errors.New("merger: Rows已经关闭") - ErrMergerRowsDiff = errors.New("merger: sql.Rows列表中的字段不同") + ErrEmptySortColumns = errors.New("merger: 排序列为空") + ErrMergerEmptyRows = errors.New("merger: sql.Rows列表为空") + ErrMergerRowsIsNull = errors.New("merger: sql.Rows列表中有元素为nil") + ErrMergerScanNotNext = errors.New("merger: Scan之前没有调用Next方法") + ErrMergerRowsClosed = errors.New("merger: Rows已经关闭") + ErrMergerRowsDiff = errors.New("merger: sql.Rows列表中的字段不同") + ErrMergerInvalidLimitOrOffset = errors.New("merger: offset或limit小于0") ) func NewRepeatSortColumn(column string) error { diff --git a/internal/merger/limitmerger/merger.go b/internal/merger/limitmerger/merger.go index 8517205..346261c 100644 --- a/internal/merger/limitmerger/merger.go +++ b/internal/merger/limitmerger/merger.go @@ -29,12 +29,16 @@ type Merger struct { offset int } -func NewMerger(m merger.Merger, offset int, limit int) *Merger { +func NewMerger(m merger.Merger, offset int, limit int) (*Merger, error) { + if offset < 0 || limit < 0 { + return nil, errs.ErrMergerInvalidLimitOrOffset + } + return &Merger{ m: m, limit: limit, offset: offset, - } + }, nil } func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { @@ -53,6 +57,7 @@ func (m *Merger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, e }, nil } +// nextOffset 会把游标挪到 offset 所指定的位置。 func (m *Merger) nextOffset(ctx context.Context, rows merger.Rows) error { offset := m.offset for i := 0; i < offset; i++ { @@ -61,10 +66,7 @@ func (m *Merger) nextOffset(ctx context.Context, rows merger.Rows) error { } // 如果偏移量超过rows结果集返回的行数,不会报错。用户最终查到0行 if !rows.Next() { - if rows.Err() != nil { - return rows.Err() - } - break + return rows.Err() } } return nil @@ -140,5 +142,7 @@ func (r *Rows) Columns() ([]string, error) { } func (r *Rows) Err() error { + r.mu.RLock() + defer r.mu.RUnlock() return r.lastErr } diff --git a/internal/merger/limitmerger/merger_test.go b/internal/merger/limitmerger/merger_test.go index 365bf76..9068cc9 100644 --- a/internal/merger/limitmerger/merger_test.go +++ b/internal/merger/limitmerger/merger_test.go @@ -83,6 +83,44 @@ func (ms *MergerSuite) initMock(t *testing.T) { t.Fatal(err) } } +func (ms *MergerSuite) TestMerger_New() { + testcases := []struct { + name string + limit int + offset int + wantErr error + }{ + { + name: "limit 小于0", + limit: -1, + offset: 10, + wantErr: errs.ErrMergerInvalidLimitOrOffset, + }, + { + name: "offset 小于0", + limit: 0, + offset: -1, + wantErr: errs.ErrMergerInvalidLimitOrOffset, + }, + { + name: "limit 大于等于0,offset大于等于0", + limit: 10, + offset: 10, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + m, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + require.NoError(t, err) + limitMerger, err := NewMerger(m, tc.offset, tc.limit) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + require.NotNil(t, limitMerger) + }) + } +} func (ms *MergerSuite) TestMerger_Merge() { testcases := []struct { @@ -110,7 +148,7 @@ func (ms *MergerSuite) TestMerger_Merge() { offset: 0, }, { - name: "Next offset个值时遇到报错", + name: "初始化游标出错", getMerger: func() (merger.Merger, error) { return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) }, @@ -193,7 +231,9 @@ func (ms *MergerSuite) TestMerger_Merge() { for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { merger, err := tc.getMerger() - limitMerger := NewMerger(merger, tc.offset, tc.limit) + require.NoError(t, err) + limitMerger, err := NewMerger(merger, tc.offset, tc.limit) + require.NoError(t, err) require.NoError(t, err) ctx, cancel := tc.ctx() rows, err := limitMerger.Merge(ctx, tc.GetRowsList()) @@ -326,12 +366,92 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { limit: 2, offset: 100, }, + { + name: "limit 的值为0", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantVal: []TestModel{}, + limit: 0, + offset: 0, + }, + { + name: "offset 的值为0", + getMerger: func() (merger.Merger, error) { + return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) + }, + GetRowsList: func() []*sql.Rows { + cols := []string{"id", "name", "address"} + query := "SELECT * FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + wantVal: []TestModel{ + { + Id: 1, + Name: "abex", + Address: "cn", + }, + { + Id: 2, + Name: "a", + Address: "cn", + }, + { + Id: 3, + Name: "alex", + Address: "cn", + }, + { + Id: 4, + Name: "x", + Address: "cn", + }, + { + Id: 5, + Name: "bruce", + Address: "cn", + }, + { + Id: 7, + Name: "b", + Address: "cn", + }, + }, + limit: 10, + offset: 0, + }, } for _, tc := range testcases { ms.T().Run(tc.name, func(t *testing.T) { merger, err := tc.getMerger() require.NoError(t, err) - limitMerger := NewMerger(merger, tc.offset, tc.limit) + limitMerger, err := NewMerger(merger, tc.offset, tc.limit) + require.NoError(t, err) rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) require.NoError(t, err) res := make([]TestModel, 0, len(tc.wantVal)) @@ -386,7 +506,8 @@ func (ms *MergerSuite) TestRows_NextAndErr() { ms.T().Run(tc.name, func(t *testing.T) { merger, err := tc.getMerger() require.NoError(t, err) - limitMerger := NewMerger(merger, tc.offset, tc.limit) + limitMerger, err := NewMerger(merger, tc.offset, tc.limit) + require.NoError(t, err) rows, err := limitMerger.Merge(context.Background(), tc.GetRowsList()) require.NoError(t, err) for rows.Next() { @@ -407,7 +528,8 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { rowsList := []*sql.Rows{r} merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) require.NoError(t, err) - limitMerger := NewMerger(merger, 0, 1) + limitMerger, err := NewMerger(merger, 0, 1) + require.NoError(t, err) rows, err := limitMerger.Merge(context.Background(), rowsList) require.NoError(t, err) id := 0 @@ -423,7 +545,8 @@ func (ms *MergerSuite) TestRows_ScanAndErr() { rowsList := []*sql.Rows{r} merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) require.NoError(t, err) - limitMerger := NewMerger(merger, 0, 1) + limitMerger, err := NewMerger(merger, 0, 1) + require.NoError(t, err) rows, err := limitMerger.Merge(context.Background(), rowsList) require.NoError(t, err) for rows.Next() { @@ -442,7 +565,8 @@ func (ms *MergerSuite) TestRows_Close() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03"))) merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) require.NoError(ms.T(), err) - limitMerger := NewMerger(merger, 1, 6) + limitMerger, err := NewMerger(merger, 1, 6) + require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]*sql.Rows, 0, len(dbs)) for _, db := range dbs { @@ -492,7 +616,8 @@ func (ms *MergerSuite) TestRows_Columns() { ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4")) merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) require.NoError(ms.T(), err) - limitMerger := NewMerger(merger, 0, 10) + limitMerger, err := NewMerger(merger, 0, 10) + require.NoError(ms.T(), err) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]*sql.Rows, 0, len(dbs)) for _, db := range dbs { From c99fa8c369137783f26c20cd6e972f9cd6303389 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sun, 26 Mar 2023 20:35:06 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9limit=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E5=A4=A7=E4=BA=8E0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../{limitmerger => pagedmerger}/merger.go | 4 +-- .../merger_test.go | 34 +++++-------------- 2 files changed, 10 insertions(+), 28 deletions(-) rename internal/merger/{limitmerger => pagedmerger}/merger.go (98%) rename internal/merger/{limitmerger => pagedmerger}/merger_test.go (95%) diff --git a/internal/merger/limitmerger/merger.go b/internal/merger/pagedmerger/merger.go similarity index 98% rename from internal/merger/limitmerger/merger.go rename to internal/merger/pagedmerger/merger.go index 346261c..390bd66 100644 --- a/internal/merger/limitmerger/merger.go +++ b/internal/merger/pagedmerger/merger.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package limitmerger +package pagedmerger import ( "context" @@ -30,7 +30,7 @@ type Merger struct { } func NewMerger(m merger.Merger, offset int, limit int) (*Merger, error) { - if offset < 0 || limit < 0 { + if offset < 0 || limit <= 0 { return nil, errs.ErrMergerInvalidLimitOrOffset } diff --git a/internal/merger/limitmerger/merger_test.go b/internal/merger/pagedmerger/merger_test.go similarity index 95% rename from internal/merger/limitmerger/merger_test.go rename to internal/merger/pagedmerger/merger_test.go index 9068cc9..6055605 100644 --- a/internal/merger/limitmerger/merger_test.go +++ b/internal/merger/pagedmerger/merger_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package limitmerger +package pagedmerger import ( "context" @@ -96,6 +96,12 @@ func (ms *MergerSuite) TestMerger_New() { offset: 10, wantErr: errs.ErrMergerInvalidLimitOrOffset, }, + { + name: "limit 等于0", + limit: 0, + offset: 10, + wantErr: errs.ErrMergerInvalidLimitOrOffset, + }, { name: "offset 小于0", limit: 0, @@ -144,7 +150,7 @@ func (ms *MergerSuite) TestMerger_Merge() { ctx: func() (context.Context, context.CancelFunc) { return context.WithCancel(context.Background()) }, - limit: 0, + limit: 1, offset: 0, }, { @@ -366,30 +372,6 @@ func (ms *MergerSuite) TestMerger_NextAndScan() { limit: 2, offset: 100, }, - { - name: "limit 的值为0", - getMerger: func() (merger.Merger, error) { - return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC)) - }, - GetRowsList: func() []*sql.Rows { - cols := []string{"id", "name", "address"} - query := "SELECT * FROM `t1`" - ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, "abex", "cn").AddRow(5, "bruce", "cn")) - ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, "alex", "cn").AddRow(4, "x", "cn")) - ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, "a", "cn").AddRow(7, "b", "cn")) - dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} - rowsList := make([]*sql.Rows, 0, len(dbs)) - for _, db := range dbs { - row, err := db.QueryContext(context.Background(), query) - require.NoError(ms.T(), err) - rowsList = append(rowsList, row) - } - return rowsList - }, - wantVal: []TestModel{}, - limit: 0, - offset: 0, - }, { name: "offset 的值为0", getMerger: func() (merger.Merger, error) {