diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 4a26ca4..db6ae94 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -3,6 +3,7 @@ - [Selector Definition](https://github.com/gotomicro/eql/pull/2) - [Deleter Definition](https://github.com/gotomicro/eql/pull/4) - [Updater Definition](https://github.com/gotomicro/eql/pull/8) +- [Rft: remove NilAsNullFunc](https://github.com/gotomicro/eql/pull/48) - [Metadata API](https://github.com/gotomicro/eql/pull/16) - [tagMetaRegistry: default implementation of MetaRegistry](https://github.com/gotomicro/eql/pull/25) - [Rft: remove defaultRegistry](https://github.com/gotomicro/eql/pull/46) diff --git a/README.md b/README.md index 30c6800..09c3cf7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ We are not English native speaker, so we use Chinese to write the design documen Here is a good one: https://www.deepl.com/en/translator [设计思路](./docs/design.md) + [B站视频](https://space.bilibili.com/324486985) ## Contribution @@ -21,3 +22,10 @@ You must follow these rules: - You must add license header to every new files [style guide](https://github.com/uber-go/guide/blob/master/style.md) + +### Setup Develop Environment + +#### install golangci-lint +Please refer [Install golangci-lint](https://golangci-lint.run/usage/install/) +#### setup pre-push github hook +Please move the `.github/pre-push` to your `.git` directory \ No newline at end of file diff --git a/db.go b/db.go index 06512e7..32223f6 100644 --- a/db.go +++ b/db.go @@ -15,8 +15,6 @@ package eql import ( - "reflect" - "github.com/valyala/bytebufferpool" ) @@ -27,7 +25,6 @@ type DBOption func(db *DB) type DB struct { metaRegistry MetaRegistry dialect Dialect - nullAssertFunc NullAssertFunc } // New returns DB. It's the entry of EQL @@ -35,7 +32,6 @@ func New(opts ...DBOption) *DB { db := &DB{ metaRegistry: &tagMetaRegistry{}, dialect: mysql, - nullAssertFunc: NilAsNullFunc, } for _, o := range opts { o(db) @@ -62,7 +58,6 @@ func (db *DB) Update(table interface{}) *Updater { return &Updater{ builder: db.builder(), table: table, - nullAssertFunc: db.nullAssertFunc, } } @@ -80,62 +75,3 @@ func (db *DB) builder() builder { buffer: bytebufferpool.Get(), } } - -func WithNullAssertFunc(nullable NullAssertFunc) DBOption { - return func(db *DB) { - db.nullAssertFunc = nullable - } -} - -// NullAssertFunc determined if the value is NULL. -// As we know, there is a gap between NULL and nil -// There are two kinds of nullAssertFunc -// 1. nil = NULL, see NilAsNullFunc -// 2. zero value = NULL, see ZeroAsNullFunc -type NullAssertFunc func(val interface{}) bool - -// NilAsNullFunc use the strict definition of "nullAssertFunc" -// if and only if the val is nil, indicates value is null -func NilAsNullFunc(val interface{}) bool { - return val == nil -} - -// ZeroAsNullFunc means "zero value = null" -func ZeroAsNullFunc(val interface{}) bool { - if val == nil { - return true - } - switch v := val.(type) { - case int: - return v == 0 - case int8: - return v == 0 - case int16: - return v == 0 - case int32: - return v == 0 - case int64: - return v == 0 - case uint: - return v == 0 - case uint8: - return v == 0 - case uint16: - return v == 0 - case uint32: - return v == 0 - case uint64: - return v == 0 - case float32: - return v == 0 - case float64: - return v == 0 - case bool: - return v - case string: - return v == "" - default: - valRef := reflect.ValueOf(val) - return valRef.IsZero() - } -} diff --git a/db_test.go b/db_test.go deleted file mode 100644 index 604319b..0000000 --- a/db_test.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package eql - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestStrictNullableFunc(t *testing.T) { - str := "Hello" - assert.False(t, NilAsNullFunc(str)) - str = "" - assert.False(t, NilAsNullFunc(str)) - var err error - assert.True(t, NilAsNullFunc(err)) - - var i int - assert.False(t, NilAsNullFunc(i)) -} - -func TestZeroAsNullableFunc(t *testing.T) { - assert.True(t, ZeroAsNullFunc(0)) - assert.True(t, ZeroAsNullFunc(int8(0))) - assert.True(t, ZeroAsNullFunc(int16(0))) - assert.True(t, ZeroAsNullFunc(int32(0))) - assert.True(t, ZeroAsNullFunc(int64(0))) - assert.True(t, ZeroAsNullFunc(uint(0))) - assert.True(t, ZeroAsNullFunc(uint8(0))) - assert.True(t, ZeroAsNullFunc(uint16(0))) - assert.True(t, ZeroAsNullFunc(uint32(0))) - assert.True(t, ZeroAsNullFunc(uint64(0))) - assert.True(t, ZeroAsNullFunc(float32(0))) - assert.True(t, ZeroAsNullFunc(float64(0))) - assert.True(t, ZeroAsNullFunc("")) - var err error - assert.True(t, ZeroAsNullFunc(err)) -} - diff --git a/internal/error.go b/internal/error.go index 6dd85ed..7cea769 100644 --- a/internal/error.go +++ b/internal/error.go @@ -14,7 +14,12 @@ package internal -import "fmt" +import ( + "errors" + "fmt" +) + +var errValueNotSet = errors.New("value unset") // NewInvalidColumnError returns an error represents invalid field name // TODO(do we need errors pkg?) @@ -22,3 +27,7 @@ func NewInvalidColumnError(field string) error { return fmt.Errorf("eql: invalid column name %s, " + "it must be a valid field name of structure", field) } + +func NewValueNotSetError() error { + return errValueNotSet +} diff --git a/model_test.go b/model_test.go index de7ea86..f2e8eb5 100644 --- a/model_test.go +++ b/model_test.go @@ -48,7 +48,7 @@ func TestTagMetaRegistry(t *testing.T) { idMetaLastName := meta.fieldMap["LastName"] assert.Equal(t, "last_name", idMetaLastName.columnName) assert.Equal(t, "LastName", idMetaLastName.fieldName) - assert.Equal(t, reflect.TypeOf(string("")), idMetaLastName.typ) + assert.Equal(t, reflect.TypeOf((*string)(nil)), idMetaLastName.typ) idMetaLastAge := meta.fieldMap["Age"] assert.Equal(t, "age", idMetaLastAge.columnName) diff --git a/predicate_test.go b/predicate_test.go index 994adae..0e45f41 100644 --- a/predicate_test.go +++ b/predicate_test.go @@ -103,7 +103,7 @@ type TestModel struct { Id int64 `eql:"auto_increment,primary_key"` FirstName string Age int8 - LastName string + LastName *string } type CommonTestCase struct { @@ -112,4 +112,8 @@ type CommonTestCase struct { wantArgs []interface{} wantSql string wantErr error +} + +func stringPtr(val string) *string { + return &val } \ No newline at end of file diff --git a/update.go b/update.go index 15e49ac..08ee60b 100644 --- a/update.go +++ b/update.go @@ -15,7 +15,6 @@ package eql import ( - "errors" "fmt" "reflect" @@ -30,7 +29,8 @@ type Updater struct { tableEle reflect.Value where []Predicate assigns []Assignable - nullAssertFunc NullAssertFunc + withNil bool + withZero bool } // Build returns UPDATE query @@ -45,9 +45,9 @@ func (u *Updater) Build() (*Query, error) { u.tableEle = reflect.ValueOf(u.table).Elem() u.args = make([]interface{}, 0, len(u.meta.columns)) - u.buffer.WriteString("UPDATE ") + _, _ = u.buffer.WriteString("UPDATE ") u.quote(u.meta.tableName) - u.buffer.WriteString(" SET ") + _, _ = u.buffer.WriteString(" SET ") if len(u.assigns) == 0 { err = u.buildDefaultColumns() } else { @@ -57,7 +57,13 @@ func (u *Updater) Build() (*Query, error) { return nil, err } - // TODO WHERE + if len(u.where) > 0 { + _, _ = u.buffer.WriteString(" WHERE ") + err = u.buildPredicates(u.where) + if err != nil { + return nil, err + } + } u.end() return &Query{ @@ -74,21 +80,35 @@ func (u *Updater) buildAssigns() error { } switch a := assign.(type) { case Column: - set, err := u.buildColumn(a.name) - if err != nil { - return err + c, ok := u.meta.fieldMap[a.name] + if !ok { + return internal.NewInvalidColumnError(a.name) } - has = has || set + val, ok := u.getValue(a.name) + if !ok { + continue + } + u.quote(c.columnName) + _ = u.buffer.WriteByte('=') + u.parameter(val) + has = true case columns: - for _, c := range a.cs { + for _, name := range a.cs { + c, ok := u.meta.fieldMap[name] + if !ok { + return internal.NewInvalidColumnError(name) + } + val, ok := u.getValue(name) + if !ok { + continue + } if has { u.comma() } - set, err := u.buildColumn(c) - if err != nil { - return err - } - has = has || set + u.quote(c.columnName) + _ = u.buffer.WriteByte('=') + u.parameter(val) + has = true } case Assignment: if err := u.buildExpr(binaryExpr(a)); err != nil { @@ -100,52 +120,45 @@ func (u *Updater) buildAssigns() error { } } if !has { - return errors.New("eql: value unset") + return internal.NewValueNotSetError() } return nil } -func (u *Updater) buildColumn(field string) (bool, error) { - c, ok := u.meta.fieldMap[field] - if !ok { - return false, internal.NewInvalidColumnError(field) - } - return u.setColumn(c), nil -} - -func (u *Updater) setColumn(c *ColumnMeta) bool { - val := u.tableEle.FieldByName(c.fieldName).Interface() - isNull := u.nullAssertFunc(val) - if !isNull { - u.quote(c.columnName) - u.buffer.WriteByte('=') - u.parameter(val) - return true - } - return false -} - func (u *Updater) buildDefaultColumns() error { has := false for _, c := range u.meta.columns { - if has { - u.buffer.WriteByte(',') + val, ok := u.getValue(c.fieldName) + if !ok { + continue } - val := u.tableEle.FieldByName(c.fieldName).Interface() - isNull := u.nullAssertFunc(val) - if !isNull { - u.quote(c.columnName) - u.buffer.WriteByte('=') - u.parameter(val) - has = true + if has { + _ = u.buffer.WriteByte(',') } + u.quote(c.columnName) + _ = u.buffer.WriteByte('=') + u.parameter(val) + has = true } if !has { - return errors.New("value unset") + return internal.NewValueNotSetError() } return nil } +func (u *Updater) getValue(fieldName string) (interface{}, bool) { + val := u.tableEle.FieldByName(fieldName) + res := val.Interface() + + if !u.withNil && val.Kind() == reflect.Ptr && val.IsNil() { + return nil, false + } + if !u.withZero && val.Kind() != reflect.Ptr && val.IsZero() { + return nil, false + } + return res, true +} + // Set represents SET clause func (u *Updater) Set(assigns ...Assignable) *Updater { u.assigns = assigns @@ -157,3 +170,17 @@ func (u *Updater) Where(predicates ...Predicate) *Updater { u.where = predicates return u } + +// WithNil use nil to update database +func (u *Updater) WithNil() *Updater { + u.withNil = true + return u +} + +// WithZero specific use zero value to update databases. +// but "zero value" here is different from reflect.IsZero, it doesn't contain nil value +// for example if the int value is 0, it will be used to update database, but if the pointer is nil, it won't +func (u *Updater) WithZero() *Updater { + u.withZero = true + return u +} diff --git a/update_test.go b/update_test.go index f6bda8b..2510167 100644 --- a/update_test.go +++ b/update_test.go @@ -25,14 +25,14 @@ func TestUpdater_Set(t *testing.T) { Id: 12, FirstName: "Tom", Age: 18, - LastName: "Jerry", + LastName: stringPtr("Jerry"), } testCases := []CommonTestCase { { name: "no set", builder: New().Update(tm), wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?,`last_name`=?;", - wantArgs: []interface{}{int64(12), "Tom", int8(18), "Jerry"}, + wantArgs: []interface{}{int64(12), "Tom", int8(18), stringPtr("Jerry")}, }, { name: "set columns", @@ -93,6 +93,47 @@ func TestUpdater_Set(t *testing.T) { wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=((`id`+(`age`*?))*?);", wantArgs: []interface{}{"Tom", 100, 110}, }, + { + name: "without zero no value", + builder: New().Update(&TestModel{Id: 13}).Set(C("FirstName")), + wantErr: internal.NewValueNotSetError(), + }, + { + name: "without zero", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).Set(C("FirstName")), + wantSql: "UPDATE `test_model` SET `first_name`=?;", + wantArgs: []interface{}{"Tom"}, + }, + { + name: "with zero", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).WithZero(), + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?;", + wantArgs: []interface{}{int64(13), "Tom", int8(0)}, + }, + { + name: "with nil", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).WithNil(), + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`last_name`=?;", + wantArgs: []interface{}{int64(13), "Tom", (*string)(nil)}, + }, + { + name: "with nil, with zero", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).WithNil().WithZero(), + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?,`last_name`=?;", + wantArgs: []interface{}{int64(13), "Tom", int8(0), (*string)(nil)}, + }, + { + name: "no where", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).WithNil().WithZero().Where(), + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?,`last_name`=?;", + wantArgs: []interface{}{int64(13), "Tom", int8(0), (*string)(nil)}, + }, + { + name: "where", + builder: New().Update(&TestModel{Id: 13, FirstName: "Tom"}).WithNil().WithZero().Where(C("Id").EQ(14)), + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?,`last_name`=? WHERE `id`=?;", + wantArgs: []interface{}{int64(13), "Tom", int8(0), (*string)(nil), 14}, + }, } for _, tc := range testCases {