From c090ba236b6f8bf727ea246a08c7024c13895f87 Mon Sep 17 00:00:00 2001 From: oldme Date: Sun, 25 Feb 2024 22:13:59 +0800 Subject: [PATCH 1/7] up --- contrib/drivers/pgsql/pgsql_do_insert.go | 63 +++++++++- database/gdb/gdb.go | 3 + database/gdb/gdb_core.go | 54 ++------- database/gdb/gdb_core_underlying.go | 45 +++++++ database/gdb/gdb_model.go | 5 +- database/gdb/gdb_model_insert.go | 143 +++++++++++++++-------- 6 files changed, 212 insertions(+), 101 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index b7586e3d79b..e17b4dcceba 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -9,20 +9,23 @@ package pgsql import ( "context" "database/sql" + "fmt" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" ) // DoInsert inserts or updates data forF given table. func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { switch option.InsertOption { case gdb.InsertOptionSave: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by pgsql driver`, - ) + //return nil, gerror.NewCode( + // gcode.CodeNotSupported, + // `Save operation is not supported by pgsql driver`, + //) case gdb.InsertOptionReplace: return nil, gerror.NewCode( @@ -50,3 +53,55 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list } return d.Core.DoInsert(ctx, link, table, list, option) } + +// DoFormatUpsert returns SQL clause of type upsert for PgSQL. +// For example: ON CONFLICT (id) DO UPDATE SET ... +func (d *Driver) DoFormatUpsert(columns []string, option gdb.DoInsertOption) (string, error) { + if len(option.OnConflict) == 0 { + return "", gerror.New("Please specify conflict columns") + } + + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case gdb.Raw, *gdb.Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + d.Core.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(k), + d.Core.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if d.Core.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(column), + d.Core.QuoteWord(column), + ) + } + } + + conflictKeys := gstr.Join(option.OnConflict, ",") + + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil +} diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 89f02b7995e..9ef6327cf9c 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -110,6 +110,8 @@ type DB interface { DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare. + DoFormatUpsert(columns []string, option DoInsertOption) (string, error) // See Core.DoFormatUpsert + // =========================================================================== // Query APIs for convenience purpose. // =========================================================================== @@ -320,6 +322,7 @@ type Sql struct { type DoInsertOption struct { OnDuplicateStr string // Custom string for `on duplicated` statement. OnDuplicateMap map[string]interface{} // Custom key-value map from `OnDuplicateEx` function for `on duplicated` statement. + OnConflict []string // Custom conflict key of upsert clause, if the database needs it. InsertOption InsertOption // Insert operation in constant value. BatchCount int // Batch count for batch inserting. } diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index c0292d042a0..8d06cfd58a0 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -487,9 +487,12 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, keysStr = charL + strings.Join(keys, charR+","+charL) + charR operation = GetInsertOperationByOption(option.InsertOption) ) - // `ON DUPLICATED...` statement only takes effect on Save operation. + // Upsert clause only takes effect on Save operation. if option.InsertOption == InsertOptionSave { - onDuplicateStr = c.formatOnDuplicate(keys, option) + onDuplicateStr, err = c.db.DoFormatUpsert(keys, option) + if err != nil { + return nil, err + } } var ( listLength = len(list) @@ -537,49 +540,6 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, return batchResult, nil } -func (c *Core) formatOnDuplicate(columns []string, option DoInsertOption) string { - var onDuplicateStr string - if option.OnDuplicateStr != "" { - onDuplicateStr = option.OnDuplicateStr - } else if len(option.OnDuplicateMap) > 0 { - for k, v := range option.OnDuplicateMap { - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - switch v.(type) { - case Raw, *Raw: - onDuplicateStr += fmt.Sprintf( - "%s=%s", - c.QuoteWord(k), - v, - ) - default: - onDuplicateStr += fmt.Sprintf( - "%s=VALUES(%s)", - c.QuoteWord(k), - c.QuoteWord(gconv.String(v)), - ) - } - } - } else { - for _, column := range columns { - // If it's SAVE operation, do not automatically update the creating time. - if c.isSoftCreatedFieldName(column) { - continue - } - if len(onDuplicateStr) > 0 { - onDuplicateStr += "," - } - onDuplicateStr += fmt.Sprintf( - "%s=VALUES(%s)", - c.QuoteWord(column), - c.QuoteWord(column), - ) - } - } - return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr -} - // Update does "UPDATE ... " statement for the table. // // The parameter `data` can be type of string/map/gmap/struct/*struct, etc. @@ -798,8 +758,8 @@ func (c *Core) GetTablesWithCache() ([]string, error) { return result.Strings(), nil } -// isSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time. -func (c *Core) isSoftCreatedFieldName(fieldName string) bool { +// IsSoftCreatedFieldName checks and returns whether given field name is an automatic-filled created time. +func (c *Core) IsSoftCreatedFieldName(fieldName string) bool { if fieldName == "" { return false } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 85b0aaa84ec..0338b96dc4c 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -10,6 +10,7 @@ package gdb import ( "context" "database/sql" + "fmt" "reflect" "go.opentelemetry.io/otel" @@ -352,6 +353,50 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt return out.Stmt, err } +// DoFormatUpsert returns SQL clause for type upsert. +func (c *Core) DoFormatUpsert(columns []string, option DoInsertOption) (string, error) { + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case Raw, *Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + c.QuoteWord(k), + v, + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(k), + c.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if c.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=VALUES(%s)", + c.QuoteWord(column), + c.QuoteWord(column), + ) + } + } + return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil +} + // RowsToResult converts underlying data record type sql.Rows to Result type. func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) { if rows == nil { diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index f854063640f..ab9cf4317a1 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -48,8 +48,9 @@ type Model struct { hookHandler HookHandler // Hook functions for model hook feature. unscoped bool // Disables soft deleting features when select/delete operations. safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. - onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement. - onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns ON "DUPLICATE KEY UPDATE" statement. + onDuplicate interface{} // onDuplicate is used for on Upsert clause. + onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns on Upsert clause. + onConflict interface{} // onConflict is used for conflict keys on Upsert clause. tableAliasMap map[string]string // Table alias to true table name, usually used in join statements. softTimeOption SoftTimeOption // SoftTimeOption is the option to customize soft time feature for Model. } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index a470d3d1c34..3a76583b459 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -118,8 +118,24 @@ func (m *Model) Data(data ...interface{}) *Model { return model } +// OnConflict sets the primary key or index when columns conflicts occurs. +// In MySQL, it's not needed. +func (m *Model) OnConflict(onConflict ...interface{}) *Model { + if len(onConflict) == 0 { + return m + } + model := m.getModel() + if len(onConflict) > 1 { + model.onConflict = onConflict + } else if len(onConflict) == 1 { + model.onConflict = onConflict[0] + } + return model +} + // OnDuplicate sets the operations when columns conflicts occurs. // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. // The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice. // Example: // @@ -148,6 +164,7 @@ func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model { // OnDuplicateEx sets the excluding columns for operations when columns conflict occurs. // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement. +// In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement. // The parameter `onDuplicateEx` can be type of string/map/slice. // Example: // @@ -320,63 +337,71 @@ func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []st InsertOption: insertOption, BatchCount: m.getBatch(), } - if insertOption == InsertOptionSave { - onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) - if err != nil { - return option, err - } - onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys) - if m.onDuplicate != nil { - switch m.onDuplicate.(type) { - case Raw, *Raw: - option.OnDuplicateStr = gconv.String(m.onDuplicate) + if insertOption != InsertOptionSave { + return + } - default: - reflectInfo := reflection.OriginValueAndKind(m.onDuplicate) - switch reflectInfo.OriginKind { - case reflect.String: - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v - } + onConflictKeys, err := m.formatOnConflictKeys(m.onConflict) + if err != nil { + return option, err + } + option.OnConflict = onConflictKeys + + onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx) + if err != nil { + return option, err + } + onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys) + if m.onDuplicate != nil { + switch m.onDuplicate.(type) { + case Raw, *Raw: + option.OnDuplicateStr = gconv.String(m.onDuplicate) - case reflect.Map: - option.OnDuplicateMap = make(map[string]interface{}) - for k, v := range gconv.Map(m.onDuplicate) { - if onDuplicateExKeySet.Contains(k) { - continue - } - option.OnDuplicateMap[k] = v + default: + reflectInfo := reflection.OriginValueAndKind(m.onDuplicate) + switch reflectInfo.OriginKind { + case reflect.String: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { + if onDuplicateExKeySet.Contains(v) { + continue } + option.OnDuplicateMap[v] = v + } - case reflect.Slice, reflect.Array: - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range gconv.Strings(m.onDuplicate) { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v + case reflect.Map: + option.OnDuplicateMap = make(map[string]interface{}) + for k, v := range gconv.Map(m.onDuplicate) { + if onDuplicateExKeySet.Contains(k) { + continue } + option.OnDuplicateMap[k] = v + } - default: - return option, gerror.NewCodef( - gcode.CodeInvalidParameter, - `unsupported OnDuplicate parameter type "%s"`, - reflect.TypeOf(m.onDuplicate), - ) + case reflect.Slice, reflect.Array: + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range gconv.Strings(m.onDuplicate) { + if onDuplicateExKeySet.Contains(v) { + continue + } + option.OnDuplicateMap[v] = v } + + default: + return option, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported OnDuplicate parameter type "%s"`, + reflect.TypeOf(m.onDuplicate), + ) } - } else if onDuplicateExKeySet.Size() > 0 { - option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range columnNames { - if onDuplicateExKeySet.Contains(v) { - continue - } - option.OnDuplicateMap[v] = v + } + } else if onDuplicateExKeySet.Size() > 0 { + option.OnDuplicateMap = make(map[string]interface{}) + for _, v := range columnNames { + if onDuplicateExKeySet.Contains(v) { + continue } + option.OnDuplicateMap[v] = v } } return @@ -407,6 +432,28 @@ func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, er } } +func (m *Model) formatOnConflictKeys(onConflict interface{}) ([]string, error) { + if onConflict == nil { + return nil, nil + } + + reflectInfo := reflection.OriginValueAndKind(onConflict) + switch reflectInfo.OriginKind { + case reflect.String: + return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil + + case reflect.Slice, reflect.Array: + return gconv.Strings(onConflict), nil + + default: + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `unsupported onConflict parameter type "%s"`, + reflect.TypeOf(onConflict), + ) + } +} + func (m *Model) getBatch() int { return m.batch } From 8bae0430890a432a3c2bbd82554a2bac15009838 Mon Sep 17 00:00:00 2001 From: oldme Date: Sun, 25 Feb 2024 22:40:54 +0800 Subject: [PATCH 2/7] unit test --- contrib/drivers/pgsql/pgsql_do_insert.go | 6 - .../drivers/pgsql/pgsql_z_unit_init_test.go | 12 +- .../drivers/pgsql/pgsql_z_unit_model_test.go | 266 +++++++++++++++++- 3 files changed, 269 insertions(+), 15 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index e17b4dcceba..a0f245079ae 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -21,12 +21,6 @@ import ( // DoInsert inserts or updates data forF given table. func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { switch option.InsertOption { - case gdb.InsertOptionSave: - //return nil, gerror.NewCode( - // gcode.CodeNotSupported, - // `Save operation is not supported by pgsql driver`, - //) - case gdb.InsertOptionReplace: return nil, gerror.NewCode( gcode.CodeNotSupported, diff --git a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go index 8e67af502f9..c2033e30125 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_init_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_init_test.go @@ -78,12 +78,12 @@ func createTableWithDb(db gdb.DB, table ...string) (name string) { if _, err := db.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( - id bigserial NOT NULL, - passport varchar(45) NOT NULL, - password varchar(32) NOT NULL, - nickname varchar(45) NOT NULL, - create_time timestamp NOT NULL, - PRIMARY KEY (id) + id bigserial NOT NULL, + passport varchar(45) NOT NULL, + password varchar(32) NOT NULL, + nickname varchar(45) NOT NULL, + create_time timestamp NOT NULL, + PRIMARY KEY (id) ) ;`, name, )); err != nil { gtest.Fatal(err) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 26ce0797a8d..69c4447d9c3 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -7,8 +7,10 @@ package pgsql_test import ( + "fmt" "testing" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" @@ -258,14 +260,16 @@ func Test_Model_Save(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + result, err := db.Model(table).Data(g.Map{ "id": 1, "passport": "t111", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T111", "create_time": "2018-10-24 10:00:00", - }).Save() - t.Assert(err, "Save operation is not supported by pgsql driver") + }).OnConflict("id").Save() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) }) } @@ -284,3 +288,259 @@ func Test_Model_Replace(t *testing.T) { t.Assert(err, "Replace operation is not supported by pgsql driver") }) } + +func Test_Model_OnConflict(t *testing.T) { + var ( + table = fmt.Sprintf(`%s_%d`, TablePrefix+"test", gtime.TimestampNano()) + uniqueName = fmt.Sprintf(`%s_%d`, TablePrefix+"test_unique", gtime.TimestampNano()) + ) + if _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id bigserial NOT NULL, + passport varchar(45) NOT NULL, + password varchar(32) NOT NULL, + nickname varchar(45) NOT NULL, + create_time timestamp NOT NULL, + PRIMARY KEY (id), + CONSTRAINT %s UNIQUE ("passport", "password") + ) ;`, table, uniqueName, + )); err != nil { + gtest.Fatal(err) + } + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("passport,password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("passport", "password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict(g.Slice{"passport", "password"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "n1") + }) +} + +func Test_Model_OnDuplicate(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate("passport,password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate("passport", "password").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Slice{"passport", "password"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "passport": "nickname", + "password": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["nickname"]) + t.Assert(one["password"], data["nickname"]) + t.Assert(one["nickname"], "name_1") + }) + + // map+raw. + gtest.C(t, func(t *gtest.T) { + data := g.MapStrStr{ + "id": "1", + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "passport": gdb.Raw("CONCAT(EXCLUDED.passport, '1')"), + "password": gdb.Raw("CONCAT(EXCLUDED.password, '2')"), + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]+"1") + t.Assert(one["password"], data["password"]+"2") + t.Assert(one["nickname"], "name_1") + }) +} + +func Test_Model_OnDuplicateEx(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // string type 1. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname,create_time").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // string type 2. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx("nickname", "create_time").Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // slice. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Slice{"nickname", "create_time"}).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) + + // map. + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicateEx(g.Map{ + "nickname": "nickname", + "create_time": "nickname", + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["passport"], data["passport"]) + t.Assert(one["password"], data["password"]) + t.Assert(one["nickname"], "name_1") + }) +} From 8e4b5d89bbe9fcc53682495a81b6a8e70ec319b0 Mon Sep 17 00:00:00 2001 From: oldme Date: Sun, 25 Feb 2024 23:12:47 +0800 Subject: [PATCH 3/7] up --- contrib/drivers/pgsql/pgsql_z_unit_model_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 69c4447d9c3..5e59652dd66 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -8,12 +8,11 @@ package pgsql_test import ( "fmt" - "testing" - "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" + "testing" ) func Test_Model_Insert(t *testing.T) { @@ -320,7 +319,7 @@ func Test_Model_OnConflict(t *testing.T) { } _, err := db.Model(table).OnConflict("passport,password").Data(data).Save() t.AssertNil(err) - one, err := db.Model(table).WherePri(1).One() + one, err := db.Model(table).Where("id", 1).One() t.AssertNil(err) t.Assert(one["passport"], data["passport"]) t.Assert(one["password"], data["password"]) @@ -338,7 +337,7 @@ func Test_Model_OnConflict(t *testing.T) { } _, err := db.Model(table).OnConflict("passport", "password").Data(data).Save() t.AssertNil(err) - one, err := db.Model(table).WherePri(1).One() + one, err := db.Model(table).Where("id", 1).One() t.AssertNil(err) t.Assert(one["passport"], data["passport"]) t.Assert(one["password"], data["password"]) @@ -356,7 +355,7 @@ func Test_Model_OnConflict(t *testing.T) { } _, err := db.Model(table).OnConflict(g.Slice{"passport", "password"}).Data(data).Save() t.AssertNil(err) - one, err := db.Model(table).WherePri(1).One() + one, err := db.Model(table).Where("id", 1).One() t.AssertNil(err) t.Assert(one["passport"], data["passport"]) t.Assert(one["password"], data["password"]) From 761cba2a7d0cb40eace574d7b2cd9c4d240d04fb Mon Sep 17 00:00:00 2001 From: oldme Date: Mon, 26 Feb 2024 10:54:37 +0800 Subject: [PATCH 4/7] up --- contrib/drivers/pgsql/pgsql_z_unit_model_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 5e59652dd66..81411ec49cb 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -8,11 +8,12 @@ package pgsql_test import ( "fmt" + "testing" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" - "testing" ) func Test_Model_Insert(t *testing.T) { From 00601bd95a022943a6392ba4b56a0e6877ef16bd Mon Sep 17 00:00:00 2001 From: oldme Date: Wed, 28 Feb 2024 17:48:33 +0800 Subject: [PATCH 5/7] up --- database/gdb/gdb_model_insert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index 3a76583b459..3f99f23fb19 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -119,7 +119,7 @@ func (m *Model) Data(data ...interface{}) *Model { } // OnConflict sets the primary key or index when columns conflicts occurs. -// In MySQL, it's not needed. +// It's not necessary for MySQL driver. func (m *Model) OnConflict(onConflict ...interface{}) *Model { if len(onConflict) == 0 { return m From 9b42895253e1f0fe84eb5c550305617c5a1661f1 Mon Sep 17 00:00:00 2001 From: oldme Date: Fri, 1 Mar 2024 14:01:47 +0800 Subject: [PATCH 6/7] up --- contrib/drivers/pgsql/pgsql_do_insert.go | 4 ++-- database/gdb/gdb.go | 3 +-- database/gdb/gdb_core.go | 2 +- database/gdb/gdb_core_underlying.go | 6 ++++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index a0f245079ae..be82ad0c410 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -48,9 +48,9 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list return d.Core.DoInsert(ctx, link, table, list, option) } -// DoFormatUpsert returns SQL clause of type upsert for PgSQL. +// FormatUpsert returns SQL clause of type upsert for PgSQL. // For example: ON CONFLICT (id) DO UPDATE SET ... -func (d *Driver) DoFormatUpsert(columns []string, option gdb.DoInsertOption) (string, error) { +func (d *Driver) FormatUpsert(columns []string, option gdb.DoInsertOption) (string, error) { if len(option.OnConflict) == 0 { return "", gerror.New("Please specify conflict columns") } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 9ef6327cf9c..fd6533a26f5 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -110,8 +110,6 @@ type DB interface { DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare. - DoFormatUpsert(columns []string, option DoInsertOption) (string, error) // See Core.DoFormatUpsert - // =========================================================================== // Query APIs for convenience purpose. // =========================================================================== @@ -177,6 +175,7 @@ type DB interface { ConvertValueForField(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForField ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (LocalType, error) // See Core.CheckLocalTypeForField + FormatUpsert(columns []string, option DoInsertOption) (string, error) // See Core.DoFormatUpsert } // TX defines the interfaces for ORM transaction operations. diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 8d06cfd58a0..4295708328f 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -489,7 +489,7 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, ) // Upsert clause only takes effect on Save operation. if option.InsertOption == InsertOptionSave { - onDuplicateStr, err = c.db.DoFormatUpsert(keys, option) + onDuplicateStr, err = c.db.FormatUpsert(keys, option) if err != nil { return nil, err } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 0338b96dc4c..05fe7cb8fcb 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -353,8 +353,10 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt return out.Stmt, err } -// DoFormatUpsert returns SQL clause for type upsert. -func (c *Core) DoFormatUpsert(columns []string, option DoInsertOption) (string, error) { +// FormatUpsert formats and returns SQL clause part for upsert statement. +// In default implements, this function performs upsert statement for MySQL like: +// `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...` +func (c *Core) FormatUpsert(columns []string, option DoInsertOption) (string, error) { var onDuplicateStr string if option.OnDuplicateStr != "" { onDuplicateStr = option.OnDuplicateStr From 8095157ea39f8924cd4f4e20cb3da799e731011c Mon Sep 17 00:00:00 2001 From: oldme Date: Fri, 1 Mar 2024 14:08:53 +0800 Subject: [PATCH 7/7] up --- contrib/drivers/pgsql/pgsql_do_insert.go | 2 +- database/gdb/gdb.go | 2 +- database/gdb/gdb_core.go | 2 +- database/gdb/gdb_core_underlying.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index be82ad0c410..84995ff3710 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -50,7 +50,7 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list // FormatUpsert returns SQL clause of type upsert for PgSQL. // For example: ON CONFLICT (id) DO UPDATE SET ... -func (d *Driver) FormatUpsert(columns []string, option gdb.DoInsertOption) (string, error) { +func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { if len(option.OnConflict) == 0 { return "", gerror.New("Please specify conflict columns") } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index fd6533a26f5..d63c12ffe51 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -175,7 +175,7 @@ type DB interface { ConvertValueForField(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForField ConvertValueForLocal(ctx context.Context, fieldType string, fieldValue interface{}) (interface{}, error) // See Core.ConvertValueForLocal CheckLocalTypeForField(ctx context.Context, fieldType string, fieldValue interface{}) (LocalType, error) // See Core.CheckLocalTypeForField - FormatUpsert(columns []string, option DoInsertOption) (string, error) // See Core.DoFormatUpsert + FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) // See Core.DoFormatUpsert } // TX defines the interfaces for ORM transaction operations. diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 4295708328f..c48f45dd556 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -489,7 +489,7 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, ) // Upsert clause only takes effect on Save operation. if option.InsertOption == InsertOptionSave { - onDuplicateStr, err = c.db.FormatUpsert(keys, option) + onDuplicateStr, err = c.db.FormatUpsert(keys, list, option) if err != nil { return nil, err } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 05fe7cb8fcb..d3b5b5b88aa 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -356,7 +356,7 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (stmt *Stmt // FormatUpsert formats and returns SQL clause part for upsert statement. // In default implements, this function performs upsert statement for MySQL like: // `INSERT INTO ... ON DUPLICATE KEY UPDATE x=VALUES(z),m=VALUES(y)...` -func (c *Core) FormatUpsert(columns []string, option DoInsertOption) (string, error) { +func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) (string, error) { var onDuplicateStr string if option.OnDuplicateStr != "" { onDuplicateStr = option.OnDuplicateStr