diff --git a/contrib/drivers/dm/dm_do_insert.go b/contrib/drivers/dm/dm_do_insert.go index edad1a95a34..003a9923c54 100644 --- a/contrib/drivers/dm/dm_do_insert.go +++ b/contrib/drivers/dm/dm_do_insert.go @@ -10,9 +10,9 @@ import ( "context" "database/sql" "fmt" - "strings" + "github.com/gogf/gf/v2/container/gset" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -32,176 +32,118 @@ func (d *Driver) DoInsert( ) case gdb.InsertOptionSave: - // This syntax currently only supports design tables whose primary key is ID. - listLength := len(list) - if listLength == 0 { - return nil, gerror.NewCode( - gcode.CodeInvalidRequest, `Save operation list is empty by dm driver`, - ) - } - var ( - keysSort []string - charL, charR = d.GetChars() + return d.doSave(ctx, link, table, list, option) + } + return d.Core.DoInsert(ctx, link, table, list, option) +} + +// doSave support upsert for dm +func (d *Driver) doSave(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + if len(option.OnConflict) == 0 { + return nil, gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, ) - // Column names need to be aligned in the syntax - for k := range list[0] { - keysSort = append(keysSort, k) - } - var char = struct { - charL string - charR string - valueCharL string - valueCharR string - duplicateKey string - keys []string - }{ - charL: charL, - charR: charR, - valueCharL: "'", - valueCharR: "'", - // TODO:: Need to dynamically set the primary key of the table - duplicateKey: "ID", - keys: keysSort, - } + } - // insertKeys: Handle valid keys that need to be inserted and updated - // insertValues: Handle values that need to be inserted - // updateValues: Handle values that need to be updated - // queryValues: Handle only one insert with column name - insertKeys, insertValues, updateValues, queryValues := parseValue(list[0], char) - // unionValues: Handling values that need to be inserted and updated - unionValues := parseUnion(list[1:], char) - - batchResult := new(gdb.SqlResult) - // parseSql(): - // MERGE INTO {{table}} T1 - // USING ( SELECT {{queryValues}} FROM DUAL - // {{unionValues}} ) T2 - // ON (T1.{{duplicateKey}} = T2.{{duplicateKey}}) - // WHEN NOT MATCHED THEN - // INSERT {{insertKeys}} VALUES {{insertValues}} - // WHEN MATCHED THEN - // UPDATE SET {{updateValues}} - sqlStr := parseSql( - insertKeys, insertValues, updateValues, queryValues, unionValues, table, char.duplicateKey, + if len(list) == 0 { + return nil, gerror.NewCode( + gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`, ) - r, err := d.DoExec(ctx, link, sqlStr) - if err != nil { - return r, err - } - if n, err := r.RowsAffected(); err != nil { - return r, err - } else { - batchResult.Result = r - batchResult.Affected += n - } - return batchResult, nil } - return d.Core.DoInsert(ctx, link, table, list, option) -} -func parseValue(listOne gdb.Map, char struct { - charL string - charR string - valueCharL string - valueCharR string - duplicateKey string - keys []string -}) (insertKeys []string, insertValues []string, updateValues []string, queryValues []string) { - for _, column := range char.keys { - if listOne[column] == nil { - // remove unassigned struct object - continue - } - insertKeys = append(insertKeys, char.charL+column+char.charR) - insertValues = append(insertValues, "T2."+char.charL+column+char.charR) - if column != char.duplicateKey { - updateValues = append( - updateValues, - fmt.Sprintf(`T1.%s = T2.%s`, char.charL+column+char.charR, char.charL+column+char.charR), - ) - } + var ( + one = list[0] + charL, charR = d.GetChars() + valueCharL, valueCharR = "'", "'" + + conflictKeys = option.OnConflict + conflictKeySet = gset.New(false) + + // insertKeys: Handle valid keys that need to be inserted + // insertValues: Handle values that need to be inserted + // updateValues: Handle values that need to be updated + // queryValues: Handle data that need to be upsert + queryValues, insertKeys, insertValues, updateValues []string + ) + + // conflictKeys slice type conv to set type + for _, conflictKey := range conflictKeys { + conflictKeySet.Add(gstr.ToUpper(conflictKey)) + } - saveValue := gconv.String(listOne[column]) + for key, value := range one { + saveValue := gconv.String(value) queryValues = append( queryValues, fmt.Sprintf( - char.valueCharL+"%s"+char.valueCharR+" AS "+char.charL+"%s"+char.charR, - saveValue, column, + valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR, + saveValue, key, ), ) - } - return -} -func parseUnion(list gdb.List, char struct { - charL string - charR string - valueCharL string - valueCharR string - duplicateKey string - keys []string -}) (unionValues []string) { - for _, mapper := range list { - var saveValue []string - for _, column := range char.keys { - if mapper[column] == nil { - continue - } - // va := reflect.ValueOf(mapper[column]) - // ty := reflect.TypeOf(mapper[column]) - // switch ty.Kind() { - // case reflect.String: - // saveValue = append(saveValue, char.valueCharL+va.String()+char.valueCharR) - - // case reflect.Int: - // saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10)) - - // case reflect.Int64: - // saveValue = append(saveValue, strconv.FormatInt(va.Int(), 10)) - - // default: - // // The fish has no chance getting here. - // // Nothing to do. - // } - saveValue = append(saveValue, - fmt.Sprintf( - char.valueCharL+"%s"+char.valueCharR, - gconv.String(mapper[column]), - )) + insertKeys = append(insertKeys, charL+key+charR) + insertValues = append(insertValues, "T2."+charL+key+charR) + + // filter conflict keys in updateValues + if !conflictKeySet.Contains(key) { + updateValues = append( + updateValues, + fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR), + ) } - unionValues = append( - unionValues, - fmt.Sprintf(`UNION ALL SELECT %s FROM DUAL`, strings.Join(saveValue, ",")), - ) } - return + + batchResult := new(gdb.SqlResult) + sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys) + r, err := d.DoExec(ctx, link, sqlStr) + if err != nil { + return r, err + } + if n, err := r.RowsAffected(); err != nil { + return r, err + } else { + batchResult.Result = r + batchResult.Affected += n + } + return batchResult, nil } -func parseSql( - insertKeys, insertValues, updateValues, queryValues, unionValues []string, table, duplicateKey string, +// parseSqlForUpsert +// MERGE INTO {{table}} T1 +// USING ( SELECT {{queryValues}} FROM DUAL T2 +// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) +// WHEN NOT MATCHED THEN +// INSERT {{insertKeys}} VALUES {{insertValues}} +// WHEN MATCHED THEN +// UPDATE SET {{updateValues}} +func parseSqlForUpsert(table string, + queryValues, insertKeys, insertValues, updateValues, duplicateKey []string, ) (sqlStr string) { var ( - queryValueStr = strings.Join(queryValues, ",") - unionValueStr = strings.Join(unionValues, " ") - insertKeyStr = strings.Join(insertKeys, ",") - insertValueStr = strings.Join(insertValues, ",") - updateValueStr = strings.Join(updateValues, ",") - pattern = gstr.Trim(` -MERGE INTO %s T1 USING (SELECT %s FROM DUAL %s) T2 ON %s -WHEN NOT MATCHED -THEN -INSERT(%s) VALUES (%s) -WHEN MATCHED -THEN -UPDATE SET %s; -COMMIT; -`) + queryValueStr = strings.Join(queryValues, ",") + insertKeyStr = strings.Join(insertKeys, ",") + insertValueStr = strings.Join(insertValues, ",") + updateValueStr = strings.Join(updateValues, ",") + duplicateKeyStr string + pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`) ) - return fmt.Sprintf( - pattern, - table, queryValueStr, unionValueStr, - fmt.Sprintf("(T1.%s = T2.%s)", duplicateKey, duplicateKey), - insertKeyStr, insertValueStr, updateValueStr, + + for index, keys := range duplicateKey { + if index != 0 { + duplicateKeyStr += " AND " + } + duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) + duplicateKeyStr += duplicateTmp + } + + return fmt.Sprintf(pattern, + table, + queryValueStr, + duplicateKeyStr, + insertKeyStr, + insertValueStr, + updateValueStr, ) } diff --git a/contrib/drivers/dm/dm_z_unit_basic_test.go b/contrib/drivers/dm/dm_z_unit_basic_test.go index 71504d7c265..3ad8ccfc710 100644 --- a/contrib/drivers/dm/dm_z_unit_basic_test.go +++ b/contrib/drivers/dm/dm_z_unit_basic_test.go @@ -7,6 +7,7 @@ package dm_test import ( + "database/sql" "fmt" "strings" "testing" @@ -138,52 +139,53 @@ func Test_DB_Query(t *testing.T) { } func TestModelSave(t *testing.T) { - table := "A_tables" - createInitTable(table) + table := createTable("test") + defer dropTable(table) gtest.C(t, func(t *gtest.T) { - data := []User{ - { - ID: 100, - AccountName: "user_100", - AttrIndex: 100, - CreatedTime: time.Now(), - }, + type User struct { + Id int + AccountName string + AttrIndex int } - _, err := db.Model(table).Data(data).Save() - gtest.Assert(err, nil) + var ( + user User + count int + result sql.Result + err error + ) + db.SetDebug(true) - data2 := []User{ - { - ID: 101, - AccountName: "user_101", - }, - } - _, err = db.Model(table).Data(&data2).Save() - gtest.Assert(err, nil) + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "accountName": "ac1", + "attrIndex": 100, + }).OnConflict("id").Save() - data3 := []User{ - { - ID: 10, - AccountName: "user_10", - PwdReset: 10, - }, - } - _, err = db.Model(table).Save(data3) - gtest.Assert(err, nil) + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) - data4 := []User{ - { - ID: 9, - AccountName: "user_9", - CreatedTime: time.Now(), - }, - } - _, err = db.Model(table).Save(&data4) - gtest.Assert(err, nil) + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 1) + t.Assert(user.AccountName, "ac1") + t.Assert(user.AttrIndex, 100) + + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "accountName": "ac2", + "attrIndex": 200, + }).OnConflict("id").Save() + t.AssertNil(err) - // TODO:: Should be Supported 'Replace' Operation - // _, err = db.Schema(TestDBName).Replace(ctx, "DoInsert", data, 10) - // gtest.Assert(err, nil) + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.AccountName, "ac2") + t.Assert(user.AttrIndex, 200) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) }) }