Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rows: 同库事务语句合并执行,提前读取所有数据 #219

Merged
merged 1 commit into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
- [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204)
- [merger: 使用 sqlx.Scanner 来读取数据](https://github.com/ecodeclub/eorm/pull/216)
- [rows, merger: 使用 sqlx.Rows 作为接口,并重构 merger 包 ](https://github.com/ecodeclub/eorm/pull/217)

- [rows: 同库事务语句合并执行,提前读取所有数据](https://github.com/ecodeclub/eorm/pull/219)
## v0.0.1:
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)
- [Selector Definition](https://github.com/ecodeclub/eorm/pull/2)
Expand Down
34 changes: 14 additions & 20 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type DBOption func(db *DB)

// DB represents a database
type DB struct {
core
baseSession
ds datasource.DataSource
}

Expand All @@ -62,14 +62,6 @@ func UseReflection() DBOption {
}
}

func (db *DB) queryContext(ctx context.Context, q datasource.Query) (*sql.Rows, error) {
return db.ds.Query(ctx, q)
}

func (db *DB) execContext(ctx context.Context, q datasource.Query) (sql.Result, error) {
return db.ds.Exec(ctx, q)
}

// Open 创建一个 ORM 实例
// 注意该实例是一个无状态的对象,你应该尽可能复用它
func Open(driver string, dsn string, opts ...DBOption) (*DB, error) {
Expand All @@ -86,12 +78,15 @@ func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, err
return nil, err
}
orm := &DB{
core: core{
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
baseSession: baseSession{
executor: ds,
core: core{
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
},
},
},
ds: ds,
Expand All @@ -111,13 +106,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
if err != nil {
return nil, err
}
return &Tx{tx: tx, core: db.getCore()}, nil
return &Tx{tx: tx, baseSession: baseSession{
executor: tx,
core: db.core,
}}, nil
}

func (db *DB) Close() error {
return db.ds.Close()
}

func (db *DB) getCore() core {
return db.core
}
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ go 1.20

require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b
github.com/go-sql-driver/mysql v1.6.0
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d
github.com/mattn/go-sqlite3 v1.14.15
github.com/stretchr/testify v1.8.1
github.com/valyala/bytebufferpool v1.0.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 h1:4Rp8WrJhISj8GDtnueoD22ygPuppajnCVZuEfRjg6w8=
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs=
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b h1:T1OvEeJJEOhkrhkg55//A5kzX7lgdeX9gDJuVDahSpw=
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d h1:kmDgYRZ06UifBqAfew+cj02juQQ3Ko349NzsDIZ0QPw=
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d/go.mod h1:ISYxgxcx3SOYGm/Hg9+M+pHVhN5G6W7p91/Pn7x6Hz8=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
Expand Down
19 changes: 7 additions & 12 deletions internal/datasource/transaction/delay_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {},
afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {},
txFunc: func() (*eorm.Tx, error) {
s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
"1.db.cluster.company.com:3306": s.clusterDB,
})
r := model.NewMetaRegistry()
_, err := r.Register(&test.OrderDetail{},
model.WithTableShardingAlgorithm(s.algorithm))
require.NoError(t, err)
db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r))
db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r))
require.NoError(t, err)
return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{})
},
Expand All @@ -98,15 +98,15 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
mockOrder: func(mock1, mock2 sqlmock.Sqlmock) {},
afterFunc: func(t *testing.T, tx *eorm.Tx, values []*test.OrderDetail) {},
txFunc: func() (*eorm.Tx, error) {
s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
"0.db.cluster.company.com:3306": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves(
newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))),
})
r := model.NewMetaRegistry()
_, err := r.Register(&test.OrderDetail{},
model.WithTableShardingAlgorithm(s.algorithm))
require.NoError(t, err)
db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r))
db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r))
require.NoError(t, err)
return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{})
},
Expand All @@ -123,14 +123,14 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
"order_detail_db_0": masterslave.NewMasterSlavesDB(s.mockMaster1DB, masterslave.MasterSlavesWithSlaves(
newSlaves(t, s.mockSlave1DB, s.mockSlave2DB, s.mockSlave3DB))),
})
s.DataSource = shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
ds := shardingsource.NewShardingDataSource(map[string]datasource.DataSource{
"0.db.cluster.company.com:3306": clusterDB,
})
r := model.NewMetaRegistry()
_, err := r.Register(&test.OrderDetail{},
model.WithTableShardingAlgorithm(s.algorithm))
require.NoError(t, err)
db, err := eorm.OpenDS("mysql", s.DataSource, eorm.DBWithMetaRegistry(r))
db, err := eorm.OpenDS("mysql", ds, eorm.DBWithMetaRegistry(r))
require.NoError(t, err)
return db.BeginTx(transaction.UsingTxType(context.Background(), transaction.Delay), &sql.TxOptions{})
},
Expand Down Expand Up @@ -483,10 +483,6 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
rows := s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"})
s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);")).
WithArgs(199, 299, 199, 299).WillReturnRows(rows)

queryVal := s.findTgt(t, values)
var wantOds []*test.OrderDetail
assert.ElementsMatch(t, wantOds, queryVal)
},
},
}
Expand All @@ -496,10 +492,9 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
tx, err := tc.txFunc()
require.NoError(t, err)

// TODO GetMultiV2 待将 table 维度改成 db 维度
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
Expand Down
3 changes: 1 addition & 2 deletions internal/datasource/transaction/transaction_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ func (s *ShardingTransactionSuite) findTgt(t *testing.T, values []*test.OrderDet
od = values[i]
pre = pre.Or(eorm.C(s.shardingKey).EQ(od.OrderId))
}
// TODO GetMultiV2 待将 table 维度改成 db 维度
querySet, err := eorm.NewShardingSelector[test.OrderDetail](s.shardingDB).
Where(pre).GetMultiV2(masterslave.UseMaster(context.Background()))
Where(pre).GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
return querySet
}
Expand Down
6 changes: 3 additions & 3 deletions internal/integration/sharding_delay_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ func (s *ShardingDelayTxTestSuite) TestDoubleShardingSelect() {
defer tx.Commit()
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
})
Expand Down Expand Up @@ -228,7 +228,7 @@ func (s *ShardingDelayTxTestSuite) TestShardingSelectUpdateInsert_Commit_Or_Roll
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

Expand Down
8 changes: 4 additions & 4 deletions internal/integration/sharding_single_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ func (s *ShardingSingleTxTestSuite) TestDoubleShardingSelect() {
defer tx.Commit()
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
})
Expand Down Expand Up @@ -137,7 +137,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectInsert_Commit_Or_Rollback(
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
res := eorm.NewShardingInsert[test.OrderDetail](tx).
Expand Down Expand Up @@ -220,7 +220,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectUpdate_Commit_Or_Rollback(
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
res := eorm.NewShardingUpdater[test.OrderDetail](tx).Update(tc.target).
Expand Down
4 changes: 2 additions & 2 deletions internal/merger/groupby_merger/aggregator_merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ import (

"go.uber.org/multierr"

"github.com/ecodeclub/ekit/mapx"
"github.com/ecodeclub/eorm/internal/merger"
"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
"github.com/ecodeclub/eorm/internal/merger/internal/errs"
"github.com/gotomicro/ekit/mapx"
)

type AggregatorMerger struct {
Expand Down Expand Up @@ -109,7 +109,7 @@ func (a *AggregatorMerger) getCols(rowsList []rows.Rows) (*mapx.TreeMap[Key, [][
val, ok := treeMap.Get(key)
if ok {
val = append(val, colData)
err = treeMap.Set(key, val)
err = treeMap.Put(key, val)
if err != nil {
return nil, nil, err
}
Expand Down
1 change: 0 additions & 1 deletion internal/merger/internal/errs/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ var (
ErrMergerAggregateHasEmptyRows = errors.New("merger: 聚合函数计算时rowsList有一个或多个为空")
ErrMergerInvalidAggregateColumnIndex = errors.New("merger: ColumnInfo的index不合法")
ErrMergerAggregateFuncNotFound = errors.New("merger: 聚合函数方法未找到")
ErrMergerNullable = errors.New("merger: 接收数据的类型需要为sql.Nullable")
)

func NewRepeatSortColumn(column string) error {
Expand Down
10 changes: 10 additions & 0 deletions internal/rows/convert_assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package rows

import (
"database/sql"
"database/sql/driver"
_ "unsafe"
)
Expand All @@ -31,5 +32,14 @@
return err
}
}
// 预处理一下 sqlConvertAssign 不支持的转换,遇到一个加一个
switch sv := src.(type) {
case sql.RawBytes:
switch dv := dest.(type) {
case *string:
*dv = string(sv)
return nil

Check warning on line 41 in internal/rows/convert_assign.go

View check run for this annotation

Codecov / codecov/patch

internal/rows/convert_assign.go#L37-L41

Added lines #L37 - L41 were not covered by tests
}
}
return sqlConvertAssign(dest, src)
}
88 changes: 88 additions & 0 deletions internal/rows/data_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rows

import (
"database/sql"

"github.com/ecodeclub/eorm/internal/errs"
)

var _ Rows = (*DataRows)(nil)

// DataRows 直接传入数据,伪装成了一个 Rows
// 非线程安全实现
type DataRows struct {
data [][]any
len int
columns []string
columnTypes []*sql.ColumnType
// 第几行
idx int
}

func (*DataRows) NextResultSet() bool {
return false
}

func (d *DataRows) ColumnTypes() ([]*sql.ColumnType, error) {
return d.columnTypes, nil

Check warning on line 41 in internal/rows/data_rows.go

View check run for this annotation

Codecov / codecov/patch

internal/rows/data_rows.go#L40-L41

Added lines #L40 - L41 were not covered by tests
}

func NewDataRows(data [][]any, columns []string, columnTypes []*sql.ColumnType) *DataRows {
// 这里并没有什么必要检查 data 和 columns 的输入
// 因为只有在很故意的情况下,data 和 columns 才可能会有问题
return &DataRows{
data: data,
len: len(data),
columns: columns,
idx: -1,
columnTypes: columnTypes,
}
}

func (d *DataRows) Next() bool {
if d.idx >= d.len-1 {
return false
}
d.idx++
return true
}

func (d *DataRows) Scan(dest ...any) error {
// 不需要检测,作为内部代码我们可以预期用户会主动控制
data := d.data[d.idx]
if len(data) != len(dest) {
return errs.NewErrScanWrongDestinationArguments(len(data), len(dest))
}
for idx, dst := range dest {
if err := ConvertAssign(dst, data[idx]); err != nil {
return err
}
}
return nil
}

func (*DataRows) Close() error {
return nil
}

func (d *DataRows) Columns() ([]string, error) {
return d.columns, nil
}

func (*DataRows) Err() error {
return nil
}
Loading
Loading