Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: Save operation support for contrib/drivers/dm #3404

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 95 additions & 153 deletions contrib/drivers/dm/dm_do_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import (
"context"
"database/sql"
"fmt"

"strings"

"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
Expand All @@ -32,176 +32,118 @@ func (d *Driver) DoInsert(
)

case gdb.InsertOptionSave:
// This syntax currently only supports design tables whose primary key is ID.
listLength := len(list)
if listLength == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by dm driver`,
)
}
var (
keysSort []string
charL, charR = d.GetChars()
return d.doSave(ctx, link, table, list, option)
}
return d.Core.DoInsert(ctx, link, table, list, option)
}

// doSave support upsert for dm
func (d *Driver) doSave(ctx context.Context,
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
if len(option.OnConflict) == 0 {
return nil, gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
// Column names need to be aligned in the syntax
for k := range list[0] {
keysSort = append(keysSort, k)
}
var char = struct {
charL string
charR string
valueCharL string
valueCharR string
duplicateKey string
keys []string
}{
charL: charL,
charR: charR,
valueCharL: "'",
valueCharR: "'",
// TODO:: Need to dynamically set the primary key of the table
duplicateKey: "ID",
keys: keysSort,
}
}

// insertKeys: Handle valid keys that need to be inserted and updated
// insertValues: Handle values that need to be inserted
// updateValues: Handle values that need to be updated
// queryValues: Handle only one insert with column name
insertKeys, insertValues, updateValues, queryValues := parseValue(list[0], char)
// unionValues: Handling values that need to be inserted and updated
unionValues := parseUnion(list[1:], char)

batchResult := new(gdb.SqlResult)
// parseSql():
// MERGE INTO {{table}} T1
// USING ( SELECT {{queryValues}} FROM DUAL
// {{unionValues}} ) T2
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}})
// WHEN NOT MATCHED THEN
// INSERT {{insertKeys}} VALUES {{insertValues}}
// WHEN MATCHED THEN
// UPDATE SET {{updateValues}}
sqlStr := parseSql(
insertKeys, insertValues, updateValues, queryValues, unionValues, table, char.duplicateKey,
if len(list) == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
)
r, err := d.DoExec(ctx, link, sqlStr)
if err != nil {
return r, err
}
if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.Result = r
batchResult.Affected += n
}
return batchResult, nil
}
return d.Core.DoInsert(ctx, link, table, list, option)
}

func parseValue(listOne gdb.Map, char struct {
charL string
charR string
valueCharL string
valueCharR string
duplicateKey string
keys []string
}) (insertKeys []string, insertValues []string, updateValues []string, queryValues []string) {
for _, column := range char.keys {
if listOne[column] == nil {
// remove unassigned struct object
continue
}
insertKeys = append(insertKeys, char.charL+column+char.charR)
insertValues = append(insertValues, "T2."+char.charL+column+char.charR)
if column != char.duplicateKey {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, char.charL+column+char.charR, char.charL+column+char.charR),
)
}
var (
one = list[0]
charL, charR = d.GetChars()
valueCharL, valueCharR = "'", "'"

conflictKeys = option.OnConflict
conflictKeySet = gset.New(false)

// insertKeys: Handle valid keys that need to be inserted
// insertValues: Handle values that need to be inserted
// updateValues: Handle values that need to be updated
// queryValues: Handle data that need to be upsert
queryValues, insertKeys, insertValues, updateValues []string
)

// conflictKeys slice type conv to set type
for _, conflictKey := range conflictKeys {
conflictKeySet.Add(gstr.ToUpper(conflictKey))
}

saveValue := gconv.String(listOne[column])
for key, value := range one {
saveValue := gconv.String(value)
queryValues = append(
queryValues,
fmt.Sprintf(
char.valueCharL+"%s"+char.valueCharR+" AS "+char.charL+"%s"+char.charR,
saveValue, column,
valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR,
saveValue, key,
),
)
}
return
}

func parseUnion(list gdb.List, char struct {
charL string
charR string
valueCharL string
valueCharR string
duplicateKey string
keys []string
}) (unionValues []string) {
for _, mapper := range list {
var saveValue []string
for _, column := range char.keys {
if mapper[column] == nil {
continue
}
// va := reflect.ValueOf(mapper[column])
// ty := reflect.TypeOf(mapper[column])
// switch ty.Kind() {
// case reflect.String:
// saveValue = append(saveValue, char.valueCharL+va.String()+char.valueCharR)

// case reflect.Int:
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))

// case reflect.Int64:
// saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10))

// default:
// // The fish has no chance getting here.
// // Nothing to do.
// }
saveValue = append(saveValue,
fmt.Sprintf(
char.valueCharL+"%s"+char.valueCharR,
gconv.String(mapper[column]),
))
insertKeys = append(insertKeys, charL+key+charR)
insertValues = append(insertValues, "T2."+charL+key+charR)

// filter conflict keys in updateValues
if !conflictKeySet.Contains(key) {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR),
)
}
unionValues = append(
unionValues,
fmt.Sprintf(`UNION ALL SELECT %s FROM DUAL`, strings.Join(saveValue, ",")),
)
}
return

batchResult := new(gdb.SqlResult)
sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys)
r, err := d.DoExec(ctx, link, sqlStr)
if err != nil {
return r, err
}
if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.Result = r
batchResult.Affected += n
}
return batchResult, nil
}

func parseSql(
insertKeys, insertValues, updateValues, queryValues, unionValues []string, table, duplicateKey string,
// parseSqlForUpsert
// MERGE INTO {{table}} T1
// USING ( SELECT {{queryValues}} FROM DUAL T2
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
// WHEN NOT MATCHED THEN
// INSERT {{insertKeys}} VALUES {{insertValues}}
// WHEN MATCHED THEN
// UPDATE SET {{updateValues}}
func parseSqlForUpsert(table string,
queryValues, insertKeys, insertValues, updateValues, duplicateKey []string,
) (sqlStr string) {
var (
queryValueStr = strings.Join(queryValues, ",")
unionValueStr = strings.Join(unionValues, " ")
insertKeyStr = strings.Join(insertKeys, ",")
insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",")
pattern = gstr.Trim(`
MERGE INTO %s T1 USING (SELECT %s FROM DUAL %s) T2 ON %s
WHEN NOT MATCHED
THEN
INSERT(%s) VALUES (%s)
WHEN MATCHED
THEN
UPDATE SET %s;
COMMIT;
`)
queryValueStr = strings.Join(queryValues, ",")
insertKeyStr = strings.Join(insertKeys, ",")
insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",")
duplicateKeyStr string
pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`)
)
return fmt.Sprintf(
pattern,
table, queryValueStr, unionValueStr,
fmt.Sprintf("(T1.%s = T2.%s)", duplicateKey, duplicateKey),
insertKeyStr, insertValueStr, updateValueStr,

for index, keys := range duplicateKey {
if index != 0 {
duplicateKeyStr += " AND "
}
duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys)
duplicateKeyStr += duplicateTmp
}

return fmt.Sprintf(pattern,
table,
queryValueStr,
duplicateKeyStr,
insertKeyStr,
insertValueStr,
updateValueStr,
)
}
82 changes: 42 additions & 40 deletions contrib/drivers/dm/dm_z_unit_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package dm_test

import (
"database/sql"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -138,52 +139,53 @@ func Test_DB_Query(t *testing.T) {
}

func TestModelSave(t *testing.T) {
table := "A_tables"
createInitTable(table)
table := createTable("test")
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
data := []User{
{
ID: 100,
AccountName: "user_100",
AttrIndex: 100,
CreatedTime: time.Now(),
},
type User struct {
Id int
AccountName string
AttrIndex int
}
_, err := db.Model(table).Data(data).Save()
gtest.Assert(err, nil)
var (
user User
count int
result sql.Result
err error
)
db.SetDebug(true)

data2 := []User{
{
ID: 101,
AccountName: "user_101",
},
}
_, err = db.Model(table).Data(&data2).Save()
gtest.Assert(err, nil)
result, err = db.Model(table).Data(g.Map{
"id": 1,
"accountName": "ac1",
"attrIndex": 100,
}).OnConflict("id").Save()

data3 := []User{
{
ID: 10,
AccountName: "user_10",
PwdReset: 10,
},
}
_, err = db.Model(table).Save(data3)
gtest.Assert(err, nil)
t.AssertNil(err)
n, _ := result.RowsAffected()
t.Assert(n, 1)

data4 := []User{
{
ID: 9,
AccountName: "user_9",
CreatedTime: time.Now(),
},
}
_, err = db.Model(table).Save(&data4)
gtest.Assert(err, nil)
err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.Id, 1)
t.Assert(user.AccountName, "ac1")
t.Assert(user.AttrIndex, 100)

_, err = db.Model(table).Data(g.Map{
"id": 1,
"accountName": "ac2",
"attrIndex": 200,
}).OnConflict("id").Save()
t.AssertNil(err)

// TODO:: Should be Supported 'Replace' Operation
// _, err = db.Schema(TestDBName).Replace(ctx, "DoInsert", data, 10)
// gtest.Assert(err, nil)
err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.AccountName, "ac2")
t.Assert(user.AttrIndex, 200)

count, err = db.Model(table).Count()
t.AssertNil(err)
t.Assert(count, 1)
})
}

Expand Down
Loading