Skip to content

Commit

Permalink
Change to check if struct is Anonymous when recursing through an embe…
Browse files Browse the repository at this point in the history
…dded struct.
  • Loading branch information
andymoon committed Aug 28, 2015
1 parent 96ece5b commit aff28bb
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 26 deletions.
15 changes: 1 addition & 14 deletions dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,17 +288,4 @@ func (me *Dataset) expressionSql(buf *SqlBuilder, expression Expression) error {
return me.adapter.ExpressionOrMapSql(buf, e)
}
return NewGoquError("Unsupported expression type %T", expression)
}

func (me *Dataset) isSpecialType(value reflect.Value) bool {
i := value.Interface()
if _, ok := i.(time.Time); ok {
return true
} else if _, ok := i.(*time.Time); ok {
return true
} else if _, ok := i.(driver.Valuer); ok {
return true
}

return false
}
}
6 changes: 2 additions & 4 deletions dataset_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,8 @@ func (me *Dataset) getFieldsValues(value reflect.Value) (rowCols []interface{},
if value.IsValid() {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)

kind := v.Kind()
if me.isSpecialType(v) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) {
t := value.Type().Field(i)
t := value.Type().Field(i)
if !t.Anonymous {
if me.canInsertField(t) {
rowCols = append(rowCols, t.Tag.Get("db"))
rowVals = append(rowVals, v.Interface())
Expand Down
23 changes: 23 additions & 0 deletions dataset_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,29 @@ func (me *datasetTest) TestInsertSqlWithValuer() {
assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test1', 10), ('211 Test Addr', 'Test2', 10), ('311 Test Addr', 'Test3', 10), ('411 Test Addr', 'Test4', 10)`)
}

func (me *datasetTest) TestInsertSqlWithValuerNull() {
t := me.T()
ds1 := From("items")

type item struct {
Address string `db:"address"`
Name string `db:"name"`
Valuer sql.NullInt64 `db:"valuer"`
}
sqlString, _, err := ds1.ToInsertSql(item{Name: "Test", Address: "111 Test Addr"})
assert.NoError(t, err)
assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test', NULL)`)

sqlString, _, err = ds1.ToInsertSql(
item{Address: "111 Test Addr", Name: "Test1"},
item{Address: "211 Test Addr", Name: "Test2"},
item{Address: "311 Test Addr", Name: "Test3"},
item{Address: "411 Test Addr", Name: "Test4"},
)
assert.NoError(t, err)
assert.Equal(t, sqlString, `INSERT INTO "items" ("address", "name", "valuer") VALUES ('111 Test Addr', 'Test1', NULL), ('211 Test Addr', 'Test2', NULL), ('311 Test Addr', 'Test3', NULL), ('411 Test Addr', 'Test4', NULL)`)
}

func (me *datasetTest) TestInsertSqlWithMaps() {
t := me.T()
ds1 := From("items")
Expand Down
5 changes: 2 additions & 3 deletions dataset_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ func (me *Dataset) ToUpdateSql(update interface{}) (string, []interface{}, error
func (me *Dataset) getUpdateExpression(value reflect.Value) (updates []UpdateExpression) {
for i := 0; i < value.NumField(); i++ {
v := value.Field(i)
kind := v.Kind()
if me.isSpecialType(v) || ((kind != reflect.Struct) && (kind != reflect.Ptr)) {
t := value.Type().Field(i)
t := value.Type().Field(i)
if !t.Anonymous {
if me.canUpdateField(t) {
updates = append(updates, I(t.Tag.Get("db")).Set(v.Interface()))
}
Expand Down
49 changes: 44 additions & 5 deletions dataset_update_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goqu

import (
"database/sql"
"database/sql/driver"
"fmt"
"time"
Expand Down Expand Up @@ -97,7 +98,7 @@ func (j valuerType) Value() (driver.Value, error) {
return []byte(fmt.Sprintf("%s World", string(j))), nil
}

func (me *datasetTest) TestUpdateSqlWithValuer() {
func (me *datasetTest) TestUpdateSqlWithCustomValuer() {
t := me.T()
ds1 := From("items")
type item struct {
Expand All @@ -109,6 +110,31 @@ func (me *datasetTest) TestUpdateSqlWithValuer() {
assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"='Hello World' RETURNING "items".*`)
}

func (me *datasetTest) TestUpdateSqlWithValuer() {
t := me.T()
ds1 := From("items")
type item struct {
Name string `db:"name"`
Data sql.NullString `db:"data"`
}

sql, _, err := ds1.Returning(I("items").All()).ToUpdateSql(item{Name: "Test", Data: sql.NullString{String: "Hello World", Valid: true}})
assert.NoError(t, err)
assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"='Hello World' RETURNING "items".*`)
}

func (me *datasetTest) TestUpdateSqlWithValuerNull() {
t := me.T()
ds1 := From("items")
type item struct {
Name string `db:"name"`
Data sql.NullString `db:"data"`
}
sql, _, err := ds1.Returning(I("items").All()).ToUpdateSql(item{Name: "Test"})
assert.NoError(t, err)
assert.Equal(t, sql, `UPDATE "items" SET "name"='Test',"data"=NULL RETURNING "items".*`)
}

func (me *datasetTest) TestUpdateSqlWithUnsupportedType() {
t := me.T()
ds1 := From("items")
Expand Down Expand Up @@ -196,7 +222,7 @@ func (me *datasetTest) TestPreparedUpdateSqlWithByteSlice() {
assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithValuer() {
func (me *datasetTest) TestPreparedUpdateSqlWithCustomValuer() {
t := me.T()
ds1 := From("items")
type item struct {
Expand All @@ -209,6 +235,19 @@ func (me *datasetTest) TestPreparedUpdateSqlWithValuer() {
assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithValuer() {
t := me.T()
ds1 := From("items")
type item struct {
Name string `db:"name"`
Data sql.NullString `db:"data"`
}
sql, args, err := ds1.Returning(I("items").All()).Prepared(true).ToUpdateSql(item{Name: "Test", Data: sql.NullString{String: "Hello World", Valid: true}})
assert.NoError(t, err)
assert.Equal(t, args, []interface{}{"Test", "Hello World"})
assert.Equal(t, sql, `UPDATE "items" SET "name"=?,"data"=? RETURNING "items".*`)
}

func (me *datasetTest) TestPreparedUpdateSqlWithSkipupdateTag() {
t := me.T()
ds1 := From("items")
Expand All @@ -232,9 +271,9 @@ func (me *datasetTest) TestPreparedUpdateSqlWithEmbeddedStruct() {
}
type item struct {
phone
Address string `db:"address" goqu:"skipupdate"`
Name string `db:"name"`
Created time.Time `db:"created"`
Address string `db:"address" goqu:"skipupdate"`
Name string `db:"name"`
Created time.Time `db:"created"`
NilPointer interface{} `db:"nil_pointer"`
}
created, _ := time.Parse("2006-01-02", "2015-01-01")
Expand Down

0 comments on commit aff28bb

Please sign in to comment.