Skip to content

Commit

Permalink
feat: support com_query insert sql request (#129)
Browse files Browse the repository at this point in the history
* feat: support com_query insert sql request
  • Loading branch information
wybrobin authored Jun 7, 2022
1 parent f3bc97f commit 1aa20fc
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 12 deletions.
110 changes: 108 additions & 2 deletions pkg/filter/dt/exec/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package exec

import (
"context"
"fmt"
"strings"

"github.com/cectc/dbpack/pkg/driver"
"github.com/cectc/dbpack/pkg/dt/schema"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/meta"
"github.com/cectc/dbpack/pkg/misc"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/resource"
"github.com/cectc/dbpack/third_party/parser/ast"
Expand Down Expand Up @@ -52,8 +54,54 @@ func (executor *queryInsertExecutor) BeforeImage(ctx context.Context) (*schema.T
}

func (executor *queryInsertExecutor) AfterImage(ctx context.Context) (*schema.TableRecords, error) {
// todo
return nil, nil
var afterImage *schema.TableRecords
var err error
pkValues, err := executor.getPKValuesByColumn(ctx)
if err != nil {
return nil, err
}
if executor.getPKIndex(ctx) >= 0 {
afterImage, err = executor.buildTableRecords(ctx, pkValues)
} else {
pk, _ := executor.result.LastInsertId()
afterImage, err = executor.buildTableRecords(ctx, []interface{}{pk})
}
if err != nil {
return nil, err
}
return afterImage, nil
}

func (executor *queryInsertExecutor) buildTableRecords(ctx context.Context, pkValues []interface{}) (*schema.TableRecords, error) {
tableMeta, err := executor.GetTableMeta(ctx)
if err != nil {
return nil, err
}

afterImageSql := executor.buildAfterImageSql(tableMeta, pkValues)
result, _, err := executor.conn.ExecuteWithWarningCount(afterImageSql, true)
if err != nil {
return nil, err
}
return schema.BuildTextRecords(tableMeta, result), nil
}

func (executor *queryInsertExecutor) buildAfterImageSql(tableMeta schema.TableMeta, pkValues []interface{}) string {
var b strings.Builder
b.WriteString("SELECT ")
columnCount := len(tableMeta.Columns)
for i, column := range tableMeta.Columns {
b.WriteString(misc.CheckAndReplace(column))
if i < columnCount-1 {
b.WriteByte(',')
} else {
b.WriteByte(' ')
}
}
b.WriteString(fmt.Sprintf("FROM %s ", executor.GetTableName()))
b.WriteString(fmt.Sprintf(" WHERE `%s` IN ", tableMeta.GetPKName()))
b.WriteString(misc.MysqlAppendInParamWithValue(pkValues))
return b.String()
}

func (executor *queryInsertExecutor) GetTableMeta(ctx context.Context) (schema.TableMeta, error) {
Expand All @@ -69,3 +117,61 @@ func (executor *queryInsertExecutor) GetTableName() string {
}
return sb.String()
}

func (executor *queryInsertExecutor) getPKValuesByColumn(ctx context.Context) ([]interface{}, error) {
pkValues := make([]interface{}, 0)
pkIndex := executor.getPKIndex(ctx)
for j := range executor.stmt.Lists {
for i, value := range executor.stmt.Lists[j] {
if i == pkIndex {
var sb strings.Builder
if err := value.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil {
log.Panic(err)
}
pkValues = append(pkValues, sb.String())
break
}
}
}

return pkValues, nil
}

func (executor *queryInsertExecutor) getPKIndex(ctx context.Context) int {
insertColumns := executor.GetInsertColumns()
tableMeta, _ := executor.GetTableMeta(ctx)

if len(insertColumns) > 0 {
for i, columnName := range insertColumns {
if strings.EqualFold(tableMeta.GetPKName(), columnName) {
return i
}
}
} else {
allColumns := tableMeta.Columns
for i, column := range allColumns {
if strings.EqualFold(tableMeta.GetPKName(), column) {
return i
}
}
}
return -1
}

func (executor *queryInsertExecutor) getColumnLen(ctx context.Context) int {
insertColumns := executor.GetInsertColumns()
if insertColumns != nil {
return len(insertColumns)
}
tableMeta, _ := executor.GetTableMeta(ctx)

return len(tableMeta.Columns)
}

func (executor *queryInsertExecutor) GetInsertColumns() []string {
result := make([]string, 0)
for _, col := range executor.stmt.Columns {
result = append(result, col.Name.String())
}
return result
}
13 changes: 13 additions & 0 deletions pkg/misc/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,16 @@ func PgsqlAppendInParam(size int) string {
fmt.Fprintf(&sb, ")")
return sb.String()
}

func MysqlAppendInParamWithValue(values []interface{}) string {
var sb strings.Builder
fmt.Fprintf(&sb, "(")
for i, value := range values {
fmt.Fprintf(&sb, "'%v'", value)
if i < len(values)-1 {
fmt.Fprint(&sb, ",")
}
}
fmt.Fprintf(&sb, ")")
return sb.String()
}
20 changes: 10 additions & 10 deletions test/sdb/distributed_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,17 @@ func (suite *_DistributedTransactionSuite) TestDistributedTransactionQueryReques
if suite.NoErrorf(err, "begin global transaction error: %v", err) {
tx, err := suite.db2.Begin()
if suite.NoErrorf(err, "begin tx error: %v", err) {
//insertSql := fmt.Sprintf(`INSERT /*+ XID('%s') */ INTO dept_manager ( id, emp_no, dept_no, from_date, to_date ) VALUES (?, ?, ?, ?, ?)`, xid)
//result, err := tx.Exec(insertSql, 1, 100002, 1002, "2022-01-01", "2024-01-01")
//if suite.NoErrorf(err, "insert row error: %v", err) {
// affected, err := result.RowsAffected()
// if suite.NoErrorf(err, "insert row error: %v", err) {
// suite.Equal(int64(1), affected)
// }
//}
//
insertSql := fmt.Sprintf(`INSERT /*+ XID('%s') */ INTO dept_manager ( id, emp_no, dept_no, from_date, to_date ) VALUES (?, ?, ?, ?, ?)`, xid)
result, err := tx.Exec(insertSql, 2, 100002, 1002, "2022-01-01", "2024-01-01")
if suite.NoErrorf(err, "insert row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "insert row error: %v", err) {
suite.Equal(int64(1), affected)
}
}

deleteSql := fmt.Sprintf(`DELETE /*+ XID('%s') */ FROM dept_emp WHERE emp_no = ? and dept_no = ?`, xid)
result, err := tx.Exec(deleteSql, 100002, 1002)
result, err = tx.Exec(deleteSql, 100002, 1002)
if suite.NoErrorf(err, "delete row error: %v", err) {
affected, err := result.RowsAffected()
if suite.NoErrorf(err, "delete row error: %v", err) {
Expand Down

0 comments on commit 1aa20fc

Please sign in to comment.