diff --git a/.CHANGELOG.md b/.CHANGELOG.md index a6aad4d..9116949 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -25,6 +25,7 @@ - [eorm: 修复单条查询时连接泄露问题](https://github.com/ecodeclub/eorm/pull/188) - [eorm: 分库分表: 结果集处理--聚合函数(含GroupBy子句)](https://github.com/ecodeclub/eorm/pull/193) - [eorm: 分库分表: NOT 支持](https://github.com/ecodeclub/eorm/pull/191) +- [eorm: 分库分表: Merger NullAble类型数据的支持(sortMerger)](https://github.com/ecodeclub/eorm/pull/195) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) diff --git a/internal/merger/internal/errs/error.go b/internal/merger/internal/errs/error.go index 9dc7e5f..9982b03 100644 --- a/internal/merger/internal/errs/error.go +++ b/internal/merger/internal/errs/error.go @@ -30,6 +30,7 @@ var ( ErrMergerAggregateHasEmptyRows = errors.New("merger: 聚合函数计算时rowsList有一个或多个为空") ErrMergerInvalidAggregateColumnIndex = errors.New("merger: ColumnInfo的index不合法") ErrMergerAggregateFuncNotFound = errors.New("merger: 聚合函数方法未找到") + ErrMergerNullable = errors.New("merger: 接收数据的类型需要为sql.Nullable") ) func NewRepeatSortColumn(column string) error { diff --git a/internal/merger/sortmerger/heap.go b/internal/merger/sortmerger/heap.go index c073a6a..fc3a095 100644 --- a/internal/merger/sortmerger/heap.go +++ b/internal/merger/sortmerger/heap.go @@ -14,7 +14,11 @@ package sortmerger -import "reflect" +import ( + "database/sql/driver" + "reflect" + "time" +) var compareFuncMapping = map[reflect.Kind]func(any, any, Order) int{ reflect.Int: compare[int], @@ -45,8 +49,14 @@ func (h *Heap) Less(i, j int) bool { for k := 0; k < h.sortColumns.Len(); k++ { valueI := h.h[i].sortCols[k] valueJ := h.h[j].sortCols[k] - kind := reflect.TypeOf(valueI).Kind() - cp := compareFuncMapping[kind] + _, ok := valueJ.(driver.Valuer) + var cp func(any, any, Order) int + if ok { + cp = compareNullable + } else { + kind := reflect.TypeOf(valueI).Kind() + cp = compareFuncMapping[kind] + } res := cp(valueI, valueJ, h.sortColumns.Get(k).order) if res == 0 { continue @@ -84,11 +94,35 @@ type node struct { func compare[T Ordered](ii any, jj any, order Order) int { i, j := ii.(T), jj.(T) - if i < j && order || i > j && !order { + if i < j && order == ASC || i > j && order == DESC { return -1 - } else if i > j && order || i < j && !order { + } else if i > j && order == ASC || i < j && order == DESC { return 1 } else { return 0 } } + +func compareNullable(ii, jj any, order Order) int { + i := ii.(driver.Valuer) + j := jj.(driver.Valuer) + iVal, _ := i.Value() + jVal, _ := j.Value() + // 如果i,j都为空返回0 + // 如果val返回为空永远是最小值 + if iVal == nil && jVal == nil { + return 0 + } else if iVal == nil && order == ASC || jVal == nil && order == DESC { + return -1 + } else if iVal == nil && order == DESC || jVal == nil && order == ASC { + return 1 + } + + vali, ok := iVal.(time.Time) + if ok { + valj := jVal.(time.Time) + return compare[int64](vali.UnixMilli(), valj.UnixMilli(), order) + } + kind := reflect.TypeOf(iVal).Kind() + return compareFuncMapping[kind](iVal, jVal, order) +} diff --git a/internal/merger/sortmerger/heap_test.go b/internal/merger/sortmerger/heap_test.go index 0c49313..2ae8e71 100644 --- a/internal/merger/sortmerger/heap_test.go +++ b/internal/merger/sortmerger/heap_test.go @@ -16,8 +16,10 @@ package sortmerger import ( "container/heap" + "database/sql" "reflect" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -393,6 +395,471 @@ func TestHeap(t *testing.T) { } +func TestHeap_Nullable(t *testing.T) { + testcases := []struct { + name string + nodes func() []*node + wantNodes func() []*node + sortCols func() sortColumns + }{ + { + name: "sql.NullInt64 asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt64{Int64: 5, Valid: true}}, + {sql.NullInt64{Int64: 1, Valid: true}}, + {sql.NullInt64{Int64: 3, Valid: true}}, + {sql.NullInt64{Int64: 2, Valid: true}}, + {sql.NullInt64{Int64: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt64{Int64: 10, Valid: false}}, + {sql.NullInt64{Int64: 1, Valid: true}}, + {sql.NullInt64{Int64: 2, Valid: true}}, + {sql.NullInt64{Int64: 3, Valid: true}}, + {sql.NullInt64{Int64: 5, Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullInt64 desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt64{Int64: 5, Valid: true}}, + {sql.NullInt64{Int64: 1, Valid: true}}, + {sql.NullInt64{Int64: 3, Valid: true}}, + {sql.NullInt64{Int64: 2, Valid: true}}, + {sql.NullInt64{Int64: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt64{Int64: 5, Valid: true}}, + {sql.NullInt64{Int64: 3, Valid: true}}, + {sql.NullInt64{Int64: 2, Valid: true}}, + {sql.NullInt64{Int64: 1, Valid: true}}, + {sql.NullInt64{Int64: 10, Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullString asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullString{String: "ab", Valid: true}}, + {sql.NullString{String: "cd", Valid: true}}, + {sql.NullString{String: "bc", Valid: true}}, + {sql.NullString{String: "ba", Valid: true}}, + {sql.NullString{String: "z", Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullString{String: "z", Valid: false}}, + {sql.NullString{String: "ab", Valid: true}}, + {sql.NullString{String: "ba", Valid: true}}, + {sql.NullString{String: "bc", Valid: true}}, + {sql.NullString{String: "cd", Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("name", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullString desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullString{String: "ab", Valid: true}}, + {sql.NullString{String: "cd", Valid: true}}, + {sql.NullString{String: "bc", Valid: true}}, + {sql.NullString{String: "z", Valid: false}}, + {sql.NullString{String: "ba", Valid: true}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullString{String: "cd", Valid: true}}, + {sql.NullString{String: "bc", Valid: true}}, + {sql.NullString{String: "ba", Valid: true}}, + {sql.NullString{String: "ab", Valid: true}}, + {sql.NullString{String: "z", Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("name", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullInt16 asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt16{Int16: 5, Valid: true}}, + {sql.NullInt16{Int16: 1, Valid: true}}, + {sql.NullInt16{Int16: 3, Valid: true}}, + {sql.NullInt16{Int16: 2, Valid: true}}, + {sql.NullInt16{Int16: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt16{Int16: 10, Valid: false}}, + {sql.NullInt16{Int16: 1, Valid: true}}, + {sql.NullInt16{Int16: 2, Valid: true}}, + {sql.NullInt16{Int16: 3, Valid: true}}, + {sql.NullInt16{Int16: 5, Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullInt16 desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt16{Int16: 5, Valid: true}}, + {sql.NullInt16{Int16: 1, Valid: true}}, + {sql.NullInt16{Int16: 3, Valid: true}}, + {sql.NullInt16{Int16: 2, Valid: true}}, + {sql.NullInt16{Int16: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt16{Int16: 5, Valid: true}}, + {sql.NullInt16{Int16: 3, Valid: true}}, + {sql.NullInt16{Int16: 2, Valid: true}}, + {sql.NullInt16{Int16: 1, Valid: true}}, + {sql.NullInt16{Int16: 10, Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullInt32 asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt32{Int32: 5, Valid: true}}, + {sql.NullInt32{Int32: 1, Valid: true}}, + {sql.NullInt32{Int32: 3, Valid: true}}, + {sql.NullInt32{Int32: 2, Valid: true}}, + {sql.NullInt32{Int32: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt32{Int32: 10, Valid: false}}, + {sql.NullInt32{Int32: 1, Valid: true}}, + {sql.NullInt32{Int32: 2, Valid: true}}, + {sql.NullInt32{Int32: 3, Valid: true}}, + {sql.NullInt32{Int32: 5, Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullInt32 desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt32{Int32: 5, Valid: true}}, + {sql.NullInt32{Int32: 1, Valid: true}}, + {sql.NullInt32{Int32: 3, Valid: true}}, + {sql.NullInt32{Int32: 2, Valid: true}}, + {sql.NullInt32{Int32: 10, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullInt32{Int32: 5, Valid: true}}, + {sql.NullInt32{Int32: 3, Valid: true}}, + {sql.NullInt32{Int32: 2, Valid: true}}, + {sql.NullInt32{Int32: 1, Valid: true}}, + {sql.NullInt32{Int32: 10, Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullFloat64 asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullFloat64{Float64: 5.0, Valid: true}}, + {sql.NullFloat64{Float64: 1.0, Valid: true}}, + {sql.NullFloat64{Float64: 3.0, Valid: true}}, + {sql.NullFloat64{Float64: 2.0, Valid: true}}, + {sql.NullFloat64{Float64: 10.0, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullFloat64{Float64: 10.0, Valid: false}}, + {sql.NullFloat64{Float64: 1.0, Valid: true}}, + {sql.NullFloat64{Float64: 2.0, Valid: true}}, + {sql.NullFloat64{Float64: 3.0, Valid: true}}, + {sql.NullFloat64{Float64: 5.0, Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullFloat64 desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullFloat64{Float64: 5.0, Valid: true}}, + {sql.NullFloat64{Float64: 1.0, Valid: true}}, + {sql.NullFloat64{Float64: 3.0, Valid: true}}, + {sql.NullFloat64{Float64: 2.0, Valid: true}}, + {sql.NullFloat64{Float64: 10.0, Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullFloat64{Float64: 5.0, Valid: true}}, + {sql.NullFloat64{Float64: 3.0, Valid: true}}, + {sql.NullFloat64{Float64: 2.0, Valid: true}}, + {sql.NullFloat64{Float64: 1.0, Valid: true}}, + {sql.NullFloat64{Float64: 10.0, Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("id", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + + { + name: "sql.NullTime asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-02 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-09 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 11:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-20 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-20 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: false}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 11:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-02 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-09 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("time", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullTime desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-02 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-09 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 11:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-20 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-09 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-02 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 11:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: true}}, + {sql.NullTime{Time: func() time.Time { + time, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-20 12:00:00", time.Local) + require.NoError(t, err) + return time + }(), Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("time", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullByte asc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullByte{Byte: 'a', Valid: true}}, + {sql.NullByte{Byte: 'c', Valid: true}}, + {sql.NullByte{Byte: 'b', Valid: true}}, + {sql.NullByte{Byte: 'k', Valid: true}}, + {sql.NullByte{Byte: 'z', Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullByte{Byte: 'z', Valid: false}}, + {sql.NullByte{Byte: 'a', Valid: true}}, + {sql.NullByte{Byte: 'b', Valid: true}}, + {sql.NullByte{Byte: 'c', Valid: true}}, + {sql.NullByte{Byte: 'k', Valid: true}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("byte", ASC)) + require.NoError(t, err) + return sortCols + }, + }, + { + name: "sql.NullByte desc", + nodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullByte{Byte: 'a', Valid: true}}, + {sql.NullByte{Byte: 'c', Valid: true}}, + {sql.NullByte{Byte: 'b', Valid: true}}, + {sql.NullByte{Byte: 'k', Valid: true}}, + {sql.NullByte{Byte: 'z', Valid: false}}, + }) + }, + wantNodes: func() []*node { + return newTestNodes([][]any{ + {sql.NullByte{Byte: 'k', Valid: true}}, + {sql.NullByte{Byte: 'c', Valid: true}}, + {sql.NullByte{Byte: 'b', Valid: true}}, + {sql.NullByte{Byte: 'a', Valid: true}}, + {sql.NullByte{Byte: 'z', Valid: false}}, + }) + }, + sortCols: func() sortColumns { + sortCols, err := newSortColumns(NewSortColumn("byte", DESC)) + require.NoError(t, err) + return sortCols + }, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + h := newTestHp(tc.nodes(), tc.sortCols()) + res := make([]*node, 0, h.Len()) + for h.Len() > 0 { + res = append(res, heap.Pop(h).(*node)) + } + assert.Equal(t, tc.wantNodes(), res) + }) + } +} + func (ms *MergerSuite) TestCompare() { testcases := []struct { name string diff --git a/internal/merger/sortmerger/merger.go b/internal/merger/sortmerger/merger.go index 96b2295..536ec41 100644 --- a/internal/merger/sortmerger/merger.go +++ b/internal/merger/sortmerger/merger.go @@ -191,18 +191,16 @@ func newNode(row *sql.Rows, sortCols sortColumns, index int) (*node, error) { sortColumns := make([]any, sortCols.Len()) for _, colInfo := range colsInfo { colName := colInfo.Name() + colType := colInfo.ScanType() + for colType.Kind() == reflect.Ptr { + colType = colType.Elem() + } + column := reflect.New(colType).Interface() if sortCols.Has(colName) { sortIndex := sortCols.Find(colName) - colType := colInfo.ScanType() - for colType.Kind() == reflect.Ptr { - colType = colType.Elem() - } - sortColumn := reflect.New(colType).Interface() - sortColumns[sortIndex] = sortColumn - columns = append(columns, sortColumn) - } else { - columns = append(columns, &[]byte{}) + sortColumns[sortIndex] = column } + columns = append(columns, column) } err = row.Scan(columns...) if err != nil { @@ -281,9 +279,9 @@ func (r *Rows) Scan(dest ...any) error { if r.cur == nil { return errs.ErrMergerScanNotNext } - + var err error for i := 0; i < len(dest); i++ { - err := utils.ConvertAssign(dest[i], r.cur.columns[i]) + err = utils.ConvertAssign(dest[i], r.cur.columns[i]) if err != nil { return err } diff --git a/internal/merger/sortmerger/merger_test.go b/internal/merger/sortmerger/merger_test.go index 0dffb48..399f103 100644 --- a/internal/merger/sortmerger/merger_test.go +++ b/internal/merger/sortmerger/merger_test.go @@ -22,12 +22,12 @@ import ( "testing" "github.com/DATA-DOG/go-sqlmock" - "go.uber.org/multierr" - "github.com/ecodeclub/eorm/internal/merger/internal/errs" + _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "go.uber.org/multierr" ) var ( @@ -1216,4 +1216,188 @@ type TestModel struct { func TestMerger(t *testing.T) { suite.Run(t, &MergerSuite{}) + suite.Run(t, &NullableMergerSuite{}) +} + +type NullableMergerSuite struct { + suite.Suite + db01 *sql.DB + db02 *sql.DB + db03 *sql.DB +} + +func (ms *NullableMergerSuite) SetupSuite() { + t := ms.T() + query := "CREATE TABLE t1 (\n id int primary key,\n `age` int,\n \t`name` varchar(20)\n );\n" + db01, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + ms.db01 = db01 + _, err = db01.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + db02, err := sql.Open("sqlite3", "file:test02.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + ms.db02 = db02 + _, err = db02.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } + db03, err := sql.Open("sqlite3", "file:test03.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + ms.db03 = db03 + _, err = db03.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } +} + +func (ms *NullableMergerSuite) TearDownSuite() { + _ = ms.db01.Close() + _ = ms.db02.Close() + _ = ms.db03.Close() +} + +func (ms *NullableMergerSuite) TestRows_Nullable() { + testcases := []struct { + name string + rowsList func() []*sql.Rows + sortColumns []SortColumn + wantErr error + afterFunc func() + wantVal []Nullable + }{ + { + name: "多个nullable类型排序 age asc,name desc", + rowsList: func() []*sql.Rows { + db1InsertSql := []string{ + "insert into t1 (id, name) values (1, 'zwl')", + "insert into t1 (id, age, name) values (2, 10, 'zwl')", + "insert into t1 (id, age, name) values (3, 20, 'zwl')", + "insert into t1 (id, age) values (4, 20)", + } + for _, sql := range db1InsertSql { + _, err := ms.db01.ExecContext(context.Background(), sql) + require.NoError(ms.T(), err) + } + db2InsertSql := []string{ + "insert into t1 (id, age, name) values (5, 5, 'zwl')", + "insert into t1 (id, age, name) values (6, 20, 'dm')", + } + for _, sql := range db2InsertSql { + _, err := ms.db02.ExecContext(context.Background(), sql) + require.NoError(ms.T(), err) + } + db3InsertSql := []string{ + "insert into t1 (id, name) values (7, 'xq')", + "insert into t1 (id, age) values (8, 5)", + "insert into t1 (id, age,name) values (9, 10,'xx')", + } + for _, sql := range db3InsertSql { + _, err := ms.db03.ExecContext(context.Background(), sql) + require.NoError(ms.T(), err) + } + dbs := []*sql.DB{ms.db01, ms.db02, ms.db03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + query := "SELECT `id`, `age`,`name` FROM `t1` order by age asc,name desc" + for _, db := range dbs { + rows, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, rows) + } + return rowsList + }, + sortColumns: []SortColumn{ + NewSortColumn("age", ASC), + NewSortColumn("name", DESC), + }, + afterFunc: func() { + dbs := []*sql.DB{ms.db01, ms.db02, ms.db03} + for _, db := range dbs { + _, err := db.Exec("DELETE FROM t1;") + require.NoError(ms.T(), err) + } + }, + wantVal: func() []Nullable { + return []Nullable{ + { + Id: sql.NullInt64{Valid: true, Int64: 1}, + Age: sql.NullInt64{Valid: false, Int64: 0}, + Name: sql.NullString{Valid: true, String: "zwl"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 7}, + Age: sql.NullInt64{Valid: false, Int64: 0}, + Name: sql.NullString{Valid: true, String: "xq"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 5}, + Age: sql.NullInt64{Valid: true, Int64: 5}, + Name: sql.NullString{Valid: true, String: "zwl"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 8}, + Age: sql.NullInt64{Valid: true, Int64: 5}, + Name: sql.NullString{Valid: false, String: ""}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 2}, + Age: sql.NullInt64{Valid: true, Int64: 10}, + Name: sql.NullString{Valid: true, String: "zwl"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 9}, + Age: sql.NullInt64{Valid: true, Int64: 10}, + Name: sql.NullString{Valid: true, String: "xx"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 3}, + Age: sql.NullInt64{Valid: true, Int64: 20}, + Name: sql.NullString{Valid: true, String: "zwl"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 6}, + Age: sql.NullInt64{Valid: true, Int64: 20}, + Name: sql.NullString{Valid: true, String: "dm"}, + }, + { + Id: sql.NullInt64{Valid: true, Int64: 4}, + Age: sql.NullInt64{Valid: true, Int64: 20}, + Name: sql.NullString{Valid: false, String: ""}, + }, + } + }(), + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger, err := NewMerger(tc.sortColumns...) + require.NoError(t, err) + rows, err := merger.Merge(context.Background(), tc.rowsList()) + require.NoError(t, err) + res := make([]Nullable, 0, len(tc.wantVal)) + for rows.Next() { + nullT := Nullable{} + err := rows.Scan(&nullT.Id, &nullT.Age, &nullT.Name) + require.NoError(ms.T(), err) + res = append(res, nullT) + } + require.True(t, rows.(*Rows).closed) + assert.NoError(t, rows.Err()) + assert.Equal(t, tc.wantVal, res) + tc.afterFunc() + }) + } +} + +type Nullable struct { + Id sql.NullInt64 + Age sql.NullInt64 + Name sql.NullString } diff --git a/internal/merger/utils/convert_Assign.go b/internal/merger/utils/convert_assign.go similarity index 65% rename from internal/merger/utils/convert_Assign.go rename to internal/merger/utils/convert_assign.go index e4abfc3..9bb2828 100644 --- a/internal/merger/utils/convert_Assign.go +++ b/internal/merger/utils/convert_assign.go @@ -15,9 +15,21 @@ package utils import ( - _ "database/sql" + "database/sql/driver" _ "unsafe" ) -//go:linkname ConvertAssign database/sql.convertAssign -func ConvertAssign(dest, src any) error +//go:linkname sqlConvertAssign database/sql.convertAssign +func sqlConvertAssign(dest, src any) error + +func ConvertAssign(dest, src any) error { + srcVal, ok := src.(driver.Valuer) + if ok { + var err error + src, err = srcVal.Value() + if err != nil { + return err + } + } + return sqlConvertAssign(dest, src) +} diff --git a/internal/merger/utils/convert_assign_test.go b/internal/merger/utils/convert_assign_test.go new file mode 100644 index 0000000..1059967 --- /dev/null +++ b/internal/merger/utils/convert_assign_test.go @@ -0,0 +1,237 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConvertNullable(t *testing.T) { + testcases := []struct { + name string + src any + dest any + wantVal any + hasErr bool + }{ + { + name: "sql.NUllbool", + src: sql.NullBool{Valid: true, Bool: true}, + dest: &sql.NullBool{Valid: false, Bool: false}, + wantVal: &sql.NullBool{Valid: true, Bool: true}, + }, + { + name: "sql.NUllbool的valid为false", + src: sql.NullBool{Valid: false, Bool: true}, + dest: &sql.NullBool{Valid: false, Bool: false}, + wantVal: &sql.NullBool{Valid: false, Bool: false}, + }, + { + name: "sql.NUllString", + src: sql.NullString{Valid: true, String: "xx"}, + dest: &sql.NullString{Valid: false, String: ""}, + wantVal: &sql.NullString{Valid: true, String: "xx"}, + }, + { + name: "sql.NUllString的valid为false", + src: sql.NullString{Valid: false, String: "xx"}, + dest: &sql.NullString{Valid: false, String: ""}, + wantVal: &sql.NullString{Valid: false, String: ""}, + }, + { + name: "sql.NUllByte", + src: sql.NullByte{Valid: true, Byte: 'a'}, + dest: &sql.NullByte{Valid: false, Byte: ' '}, + wantVal: &sql.NullByte{Valid: true, Byte: 'a'}, + }, + { + name: "sql.NUllByte的valid的false", + src: sql.NullByte{Valid: false, Byte: 'a'}, + dest: &sql.NullByte{Valid: false, Byte: 0}, + wantVal: &sql.NullByte{Valid: false, Byte: 0}, + }, + { + name: "sql.NUllInt32", + src: sql.NullInt32{Valid: true, Int32: 5}, + dest: &sql.NullInt32{Valid: false, Int32: 0}, + wantVal: &sql.NullInt32{Valid: true, Int32: 5}, + }, + { + name: "sql.NUllInt32的valid的false", + src: sql.NullInt32{Valid: false, Int32: 0}, + dest: &sql.NullInt32{Valid: false, Int32: 0}, + wantVal: &sql.NullInt32{Valid: false, Int32: 0}, + }, + { + name: "sql.NUllInt64", + src: sql.NullInt64{Valid: true, Int64: 5}, + dest: &sql.NullInt64{Valid: false, Int64: 0}, + wantVal: &sql.NullInt64{Valid: true, Int64: 5}, + }, + { + name: "sql.NUllInt64的valid的false", + src: sql.NullInt64{Valid: false, Int64: 0}, + dest: &sql.NullInt64{Valid: false, Int64: 0}, + wantVal: &sql.NullInt64{Valid: false, Int64: 0}, + }, + { + name: "sql.NUllInt16", + src: sql.NullInt16{Valid: true, Int16: 5}, + dest: &sql.NullInt16{Valid: false, Int16: 0}, + wantVal: &sql.NullInt16{Valid: true, Int16: 5}, + }, + { + name: "sql.NUllInt16的valid的false", + src: sql.NullInt16{Valid: false, Int16: 0}, + dest: &sql.NullInt16{Valid: false, Int16: 0}, + wantVal: &sql.NullInt16{Valid: false, Int16: 0}, + }, + { + name: "sql.NUllFloat64", + src: sql.NullFloat64{Valid: true, Float64: 5}, + dest: &sql.NullFloat64{Valid: false, Float64: 0}, + wantVal: &sql.NullFloat64{Valid: true, Float64: 5}, + }, + { + name: "sql.NUllfloat64的valid的false", + src: sql.NullFloat64{Valid: false, Float64: 0}, + dest: &sql.NullFloat64{Valid: false, Float64: 0}, + wantVal: &sql.NullFloat64{Valid: false, Float64: 0}, + }, + { + name: "sql.NUllTime", + src: sql.NullTime{Valid: true, Time: func() time.Time { + val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return val + }()}, + dest: &sql.NullTime{Valid: false, Time: time.Time{}}, + wantVal: &sql.NullTime{Valid: true, Time: func() time.Time { + val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(t, err) + return val + }()}, + }, + { + name: "sql.NUllTime的valid的false", + src: sql.NullTime{Valid: false, Time: time.Time{}}, + dest: &sql.NullTime{Valid: false, Time: time.Time{}}, + wantVal: &sql.NullTime{Valid: false, Time: time.Time{}}, + }, + { + name: "使用sql.NullInt32接收sql.NullInt64", + src: sql.NullInt64{Valid: true, Int64: 5}, + dest: &sql.NullInt32{Valid: false, Int32: 0}, + wantVal: &sql.NullInt32{Valid: true, Int32: 5}, + }, + { + name: "使用sql.NullInt16接收sql.NullInt64", + src: sql.NullInt64{Valid: true, Int64: 5}, + dest: &sql.NullInt16{Valid: false, Int16: 0}, + wantVal: &sql.NullInt16{Valid: true, Int16: 5}, + }, + { + name: "使用sql.NullInt32接收sql.NullInt64,Valid为false", + src: sql.NullInt64{Valid: false, Int64: 0}, + dest: &sql.NullInt32{Valid: false, Int32: 0}, + wantVal: &sql.NullInt32{Valid: false, Int32: 0}, + }, + { + name: "使用sql.NullInt16接收sql.NullInt64,Valid为false", + src: sql.NullInt64{Valid: false, Int64: 0}, + dest: &sql.NullInt16{Valid: false, Int16: 0}, + wantVal: &sql.NullInt16{Valid: false, Int16: 0}, + }, + { + name: "使用int32接收sql.NullInt64", + src: sql.NullInt64{Valid: true, Int64: 5}, + dest: func() *int32 { + var val int32 + return &val + }(), + wantVal: func() *int32 { + val := int32(5) + return &val + }(), + }, + { + name: "使用int16接收sql.NullInt64", + src: sql.NullInt64{Valid: true, Int64: 5}, + dest: func() *int16 { + var val int16 + return &val + }(), + wantVal: func() *int16 { + val := int16(5) + return &val + }(), + }, + { + name: "使用float32接收sql.Nullfloat64", + src: sql.NullFloat64{Valid: true, Float64: 5}, + dest: func() *float32 { + var val float32 + return &val + }(), + wantVal: func() *float32 { + val := float32(5) + return &val + }(), + }, + { + name: "使用int32接收sql.NullInt64,Valid为false", + src: sql.NullInt64{Valid: false, Int64: 0}, + dest: func() *int32 { + var val int32 + return &val + }(), + hasErr: true, + }, + { + name: "使用int16接收sql.NullInt64,valid为false", + src: sql.NullInt64{Valid: false, Int64: 0}, + dest: func() *int16 { + var val int16 + return &val + }(), + hasErr: true, + }, + { + name: "使用float32接收sql.Nullfloat64", + src: sql.NullFloat64{Valid: false, Float64: 0}, + dest: func() *float32 { + var val float32 + return &val + }(), + hasErr: true, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := ConvertAssign(tc.dest, tc.src) + if tc.hasErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.dest, tc.wantVal) + }) + } +}