From aff28bb27e3da547ca665e2a6a33d354aa319ffd Mon Sep 17 00:00:00 2001 From: Andrew Moon Date: Fri, 28 Aug 2015 11:50:06 -0500 Subject: [PATCH] Change to check if struct is Anonymous when recursing through an embedded struct. --- dataset.go | 15 +------------ dataset_insert.go | 6 ++---- dataset_insert_test.go | 23 ++++++++++++++++++++ dataset_update.go | 5 ++--- dataset_update_test.go | 49 +++++++++++++++++++++++++++++++++++++----- 5 files changed, 72 insertions(+), 26 deletions(-) diff --git a/dataset.go b/dataset.go index f49e65b7..e0c90b8e 100644 --- a/dataset.go +++ b/dataset.go @@ -288,17 +288,4 @@ func (me *Dataset) expressionSql(buf *SqlBuilder, expression Expression) error { return me.adapter.ExpressionOrMapSql(buf, e) } return NewGoquError("Unsupported expression type %T", expression) -} - -func (me *Dataset) isSpecialType(value reflect.Value) bool { - i := value.Interface() - if _, ok := i.(time.Time); ok { - return true - } else if _, ok := i.(*time.Time); ok { - return true - } else if _, ok := i.(driver.Valuer); ok { - return true - } - - return false -} +} \ No newline at end of file diff --git a/dataset_insert.go b/dataset_insert.go index c8beb812..5a7dccb0 100644 --- a/dataset_insert.go +++ b/dataset_insert.go @@ -110,10 +110,8 @@ func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{}, if value.IsValid() { for i := 0; i < value.NumField(); i++ { v := value.Field(i) - - kind := v.Kind() - if me.isSpecialType(v) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) { - t := value.Type().Field(i) + t := value.Type().Field(i) + if !t.Anonymous { if me.canInsertField(t) { rowCols = append(rowCols, t.Tag.Get("db")) rowVals = append(rowVals, v.Interface()) diff --git a/dataset_insert_test.go b/dataset_insert_test.go index 11064730..f9f80d85 100644 --- a/dataset_insert_test.go +++ b/dataset_insert_test.go @@ -134,6 +134,29 @@ func (me *datasetTest) TestInsertSqlWithValuer() { assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test1', 10), ('211 Test Addr', 'Test2', 10), ('311 Test Addr', 'Test3', 10), ('411 Test Addr', 'Test4', 10)`) } +func (me *datasetTest) TestInsertSqlWithValuerNull() { + t := me.T() + ds1 := From("items") + + type item struct { + Address string `db:"address"` + Name string `db:"name"` + Valuer sql.NullInt64 `db:"valuer"` + } + sqlString, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"}) + assert.NoError(t, err) + assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test', NULL)`) + + sqlString, _, err = ds1.ToInsertSql( + item{Address: "111 Test Addr", Name: "Test1"}, + item{Address: "211 Test Addr", Name: "Test2"}, + item{Address: "311 Test Addr", Name: "Test3"}, + item{Address: "411 Test Addr", Name: "Test4"}, + ) + assert.NoError(t, err) + assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test1', NULL), ('211 Test Addr', 'Test2', NULL), ('311 Test Addr', 'Test3', NULL), ('411 Test Addr', 'Test4', NULL)`) +} + func (me *datasetTest) TestInsertSqlWithMaps() { t := me.T() ds1 := From("items") diff --git a/dataset_update.go b/dataset_update.go index 626fe3d8..e2e113ce 100644 --- a/dataset_update.go +++ b/dataset_update.go @@ -79,9 +79,8 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) { for i := 0; i < value.NumField(); i++ { v := value.Field(i) - kind := v.Kind() - if me.isSpecialType(v) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) { - t := value.Type().Field(i) + t := value.Type().Field(i) + if !t.Anonymous { if me.canUpdateField(t) { updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface())) } diff --git a/dataset_update_test.go b/dataset_update_test.go index c26fcd3c..95d12cc1 100644 --- a/dataset_update_test.go +++ b/dataset_update_test.go @@ -1,6 +1,7 @@ package goqu import ( + "database/sql" "database/sql/driver" "fmt" "time" @@ -97,7 +98,7 @@ func (j valuerType) Value() (driver.Value, error) { return []byte(fmt.Sprintf("%s World", string(j))), nil } -func (me *datasetTest) TestUpdateSqlWithValuer() { +func (me *datasetTest) TestUpdateSqlWithCustomValuer() { t := me.T() ds1 := From("items") type item struct { @@ -109,6 +110,31 @@ func (me *datasetTest) TestUpdateSqlWithValuer() { assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"='Hello World' RETURNING "items".*`) } +func (me *datasetTest) TestUpdateSqlWithValuer() { + t := me.T() + ds1 := From("items") + type item struct { + Name string `db:"name"` + Data sql.NullString `db:"data"` + } + + sql, _, err := ds1.Returning(I("items").All()).ToUpdateSql(item{Name: "Test", Data: sql.NullString{String: "Hello World", Valid: true}}) + assert.NoError(t, err) + assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"='Hello World' RETURNING "items".*`) +} + +func (me *datasetTest) TestUpdateSqlWithValuerNull() { + t := me.T() + ds1 := From("items") + type item struct { + Name string `db:"name"` + Data sql.NullString `db:"data"` + } + sql, _, err := ds1.Returning(I("items").All()).ToUpdateSql(item{Name: "Test"}) + assert.NoError(t, err) + assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"=NULL RETURNING "items".*`) +} + func (me *datasetTest) TestUpdateSqlWithUnsupportedType() { t := me.T() ds1 := From("items") @@ -196,7 +222,7 @@ func (me *datasetTest) TestPreparedUpdateSqlWithByteSlice() { assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`) } -func (me *datasetTest) TestPreparedUpdateSqlWithValuer() { +func (me *datasetTest) TestPreparedUpdateSqlWithCustomValuer() { t := me.T() ds1 := From("items") type item struct { @@ -209,6 +235,19 @@ func (me *datasetTest) TestPreparedUpdateSqlWithValuer() { assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`) } +func (me *datasetTest) TestPreparedUpdateSqlWithValuer() { + t := me.T() + ds1 := From("items") + type item struct { + Name string `db:"name"` + Data sql.NullString `db:"data"` + } + sql, args, err := ds1.Returning(I("items").All()).Prepared(true).ToUpdateSql(item{Name: "Test", Data: sql.NullString{String: "Hello World", Valid: true}}) + assert.NoError(t, err) + assert.Equal(t, args, []interface{}{"Test", "Hello World"}) + assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`) +} + func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() { t := me.T() ds1 := From("items") @@ -232,9 +271,9 @@ func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() { } type item struct { phone - Address string `db:"address" goqu:"skipupdate"` - Name string `db:"name"` - Created time.Time `db:"created"` + Address string `db:"address" goqu:"skipupdate"` + Name string `db:"name"` + Created time.Time `db:"created"` NilPointer interface{} `db:"nil_pointer"` } created, _ := time.Parse("2006-01-02", "2015-01-01")