diff --git a/pkg/filter/dt/exec/query_insert.go b/pkg/filter/dt/exec/query_insert.go index 4eb24cc..79d6dbb 100644 --- a/pkg/filter/dt/exec/query_insert.go +++ b/pkg/filter/dt/exec/query_insert.go @@ -18,12 +18,14 @@ package exec import ( "context" + "fmt" "strings" "github.com/cectc/dbpack/pkg/driver" "github.com/cectc/dbpack/pkg/dt/schema" "github.com/cectc/dbpack/pkg/log" "github.com/cectc/dbpack/pkg/meta" + "github.com/cectc/dbpack/pkg/misc" "github.com/cectc/dbpack/pkg/proto" "github.com/cectc/dbpack/pkg/resource" "github.com/cectc/dbpack/third_party/parser/ast" @@ -52,8 +54,54 @@ func (executor *queryInsertExecutor) BeforeImage(ctx context.Context) (*schema.T } func (executor *queryInsertExecutor) AfterImage(ctx context.Context) (*schema.TableRecords, error) { - // todo - return nil, nil + var afterImage *schema.TableRecords + var err error + pkValues, err := executor.getPKValuesByColumn(ctx) + if err != nil { + return nil, err + } + if executor.getPKIndex(ctx) >= 0 { + afterImage, err = executor.buildTableRecords(ctx, pkValues) + } else { + pk, _ := executor.result.LastInsertId() + afterImage, err = executor.buildTableRecords(ctx, []interface{}{pk}) + } + if err != nil { + return nil, err + } + return afterImage, nil +} + +func (executor *queryInsertExecutor) buildTableRecords(ctx context.Context, pkValues []interface{}) (*schema.TableRecords, error) { + tableMeta, err := executor.GetTableMeta(ctx) + if err != nil { + return nil, err + } + + afterImageSql := executor.buildAfterImageSql(tableMeta, pkValues) + result, _, err := executor.conn.ExecuteWithWarningCount(afterImageSql, true) + if err != nil { + return nil, err + } + return schema.BuildTextRecords(tableMeta, result), nil +} + +func (executor *queryInsertExecutor) buildAfterImageSql(tableMeta schema.TableMeta, pkValues []interface{}) string { + var b strings.Builder + b.WriteString("SELECT ") + columnCount := len(tableMeta.Columns) + for i, column := range tableMeta.Columns { + b.WriteString(misc.CheckAndReplace(column)) + if i < columnCount-1 { + b.WriteByte(',') + } else { + b.WriteByte(' ') + } + } + b.WriteString(fmt.Sprintf("FROM %s ", executor.GetTableName())) + b.WriteString(fmt.Sprintf(" WHERE `%s` IN ", tableMeta.GetPKName())) + b.WriteString(misc.MysqlAppendInParamWithValue(pkValues)) + return b.String() } func (executor *queryInsertExecutor) GetTableMeta(ctx context.Context) (schema.TableMeta, error) { @@ -69,3 +117,61 @@ func (executor *queryInsertExecutor) GetTableName() string { } return sb.String() } + +func (executor *queryInsertExecutor) getPKValuesByColumn(ctx context.Context) ([]interface{}, error) { + pkValues := make([]interface{}, 0) + pkIndex := executor.getPKIndex(ctx) + for j := range executor.stmt.Lists { + for i, value := range executor.stmt.Lists[j] { + if i == pkIndex { + var sb strings.Builder + if err := value.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)); err != nil { + log.Panic(err) + } + pkValues = append(pkValues, sb.String()) + break + } + } + } + + return pkValues, nil +} + +func (executor *queryInsertExecutor) getPKIndex(ctx context.Context) int { + insertColumns := executor.GetInsertColumns() + tableMeta, _ := executor.GetTableMeta(ctx) + + if len(insertColumns) > 0 { + for i, columnName := range insertColumns { + if strings.EqualFold(tableMeta.GetPKName(), columnName) { + return i + } + } + } else { + allColumns := tableMeta.Columns + for i, column := range allColumns { + if strings.EqualFold(tableMeta.GetPKName(), column) { + return i + } + } + } + return -1 +} + +func (executor *queryInsertExecutor) getColumnLen(ctx context.Context) int { + insertColumns := executor.GetInsertColumns() + if insertColumns != nil { + return len(insertColumns) + } + tableMeta, _ := executor.GetTableMeta(ctx) + + return len(tableMeta.Columns) +} + +func (executor *queryInsertExecutor) GetInsertColumns() []string { + result := make([]string, 0) + for _, col := range executor.stmt.Columns { + result = append(result, col.Name.String()) + } + return result +} diff --git a/pkg/misc/sql.go b/pkg/misc/sql.go index e4522b7..33daa62 100644 --- a/pkg/misc/sql.go +++ b/pkg/misc/sql.go @@ -46,3 +46,16 @@ func PgsqlAppendInParam(size int) string { fmt.Fprintf(&sb, ")") return sb.String() } + +func MysqlAppendInParamWithValue(values []interface{}) string { + var sb strings.Builder + fmt.Fprintf(&sb, "(") + for i, value := range values { + fmt.Fprintf(&sb, "'%v'", value) + if i < len(values)-1 { + fmt.Fprint(&sb, ",") + } + } + fmt.Fprintf(&sb, ")") + return sb.String() +} diff --git a/test/sdb/distributed_transaction_test.go b/test/sdb/distributed_transaction_test.go index 1a2a429..2b2f2db 100644 --- a/test/sdb/distributed_transaction_test.go +++ b/test/sdb/distributed_transaction_test.go @@ -161,17 +161,17 @@ func (suite *_DistributedTransactionSuite) TestDistributedTransactionQueryReques if suite.NoErrorf(err, "begin global transaction error: %v", err) { tx, err := suite.db2.Begin() if suite.NoErrorf(err, "begin tx error: %v", err) { - //insertSql := fmt.Sprintf(`INSERT /*+ XID('%s') */ INTO dept_manager ( id, emp_no, dept_no, from_date, to_date ) VALUES (?, ?, ?, ?, ?)`, xid) - //result, err := tx.Exec(insertSql, 1, 100002, 1002, "2022-01-01", "2024-01-01") - //if suite.NoErrorf(err, "insert row error: %v", err) { - // affected, err := result.RowsAffected() - // if suite.NoErrorf(err, "insert row error: %v", err) { - // suite.Equal(int64(1), affected) - // } - //} - // + insertSql := fmt.Sprintf(`INSERT /*+ XID('%s') */ INTO dept_manager ( id, emp_no, dept_no, from_date, to_date ) VALUES (?, ?, ?, ?, ?)`, xid) + result, err := tx.Exec(insertSql, 2, 100002, 1002, "2022-01-01", "2024-01-01") + if suite.NoErrorf(err, "insert row error: %v", err) { + affected, err := result.RowsAffected() + if suite.NoErrorf(err, "insert row error: %v", err) { + suite.Equal(int64(1), affected) + } + } + deleteSql := fmt.Sprintf(`DELETE /*+ XID('%s') */ FROM dept_emp WHERE emp_no = ? and dept_no = ?`, xid) - result, err := tx.Exec(deleteSql, 100002, 1002) + result, err = tx.Exec(deleteSql, 100002, 1002) if suite.NoErrorf(err, "delete row error: %v", err) { affected, err := result.RowsAffected() if suite.NoErrorf(err, "delete row error: %v", err) {