From e49e182afee9e4094fb2ed41715b33f76ab87b4d Mon Sep 17 00:00:00 2001 From: zhangyingqi Date: Sun, 17 Oct 2021 18:20:31 +0800 Subject: [PATCH] Feat: add insert function and test for it --- .CHANGELOG.md | 1 + db.go | 12 +++--- insert.go | 85 +++++++++++++++++++++++++++++++++++++-- insert_test.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 9 deletions(-) create mode 100644 insert_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index fceb71f..3173957 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -10,3 +10,4 @@ - [Support Aggregate Functions](https://github.com/gotomicro/eql/pull/37) - [Updater implementation, excluding WHERE clause](https://github.com/gotomicro/eql/pull/36) - [Force test and lint in pre-push](https://github.com/gotomicro/eql/pull/35) +- [Insert implementation ](https://github.com/gotomicro/eql/pull/38) \ No newline at end of file diff --git a/db.go b/db.go index 082d76a..6644cbe 100644 --- a/db.go +++ b/db.go @@ -24,7 +24,7 @@ type DBOption func(db *DB) // DB represents a database type DB struct { - metaRegistry MetaRegistry + metaRegistry MetaRegistry dialect Dialect nullAssertFunc NullAssertFunc } @@ -65,14 +65,16 @@ func (db *DB) Update(table interface{}) *Updater { // Insert generate Inserter to builder insert query func (db *DB) Insert() *Inserter { - return &Inserter{} + return &Inserter{ + builder: db.builder(), + } } func (db *DB) builder() builder { return builder{ registry: db.metaRegistry, - dialect: db.dialect, - buffer: &strings.Builder{}, + dialect: db.dialect, + buffer: &strings.Builder{}, } } @@ -97,7 +99,7 @@ func NilAsNullFunc(val interface{}) bool { // ZeroAsNullFunc means "zero value = null" func ZeroAsNullFunc(val interface{}) bool { - if val == nil{ + if val == nil { return true } switch v := val.(type) { diff --git a/insert.go b/insert.go index 4dbdf5a..fdda623 100644 --- a/insert.go +++ b/insert.go @@ -14,34 +14,111 @@ package eql +import ( + "errors" + "fmt" + "reflect" +) + // Inserter is used to construct an insert query type Inserter struct { + builder + columns []string + values []interface{} } func (i *Inserter) Build() (*Query, error) { - panic("implement me") + var err error + if len(i.values) == 0 { + return &Query{}, errors.New("no values") + } + i.buffer.WriteString("INSERT INTO ") + i.meta, err = i.registry.Get(i.values[0]) + if err != nil { + return &Query{}, err + } + i.quote(i.meta.tableName) + i.buffer.WriteString("(") + fields, err := i.buildColumns() + if err != nil { + return &Query{}, err + } + i.buffer.WriteString(")") + i.buffer.WriteString(" VALUES") + for index, value := range i.values { + i.buffer.WriteString("(") + refVal := reflect.ValueOf(value).Elem() + for j, v := range fields { + field := refVal.FieldByName(v.fieldName) + if !field.IsValid() { + return &Query{}, fmt.Errorf("invalid column %s", v.fieldName) + } + val := field.Interface() + i.parameter(val) + if j != len(fields)-1 { + i.comma() + } + } + i.buffer.WriteString(")") + if index != len(i.values)-1 { + i.comma() + } + } + i.end() + return &Query{SQL: i.buffer.String(), Args: i.args}, nil } // Columns specifies the columns that need to be inserted // if cs is empty, all columns will be inserted except auto increment columns -func (db *DB) Columns(cs ...string) *Inserter { - panic("implements me") +func (i *Inserter) Columns(cs ...string) *Inserter { + i.columns = cs + return i } // Values specify the rows // all the elements must be the same structure func (i *Inserter) Values(values ...interface{}) *Inserter { - panic("implement me") + i.values = values + return i } // OnDuplicateKey generate MysqlUpserter // if the dialect is not MySQL, it will panic func (i *Inserter) OnDuplicateKey() *MysqlUpserter { + panic("implement me") } // OnConflict generate PgSQLUpserter // if the dialect is not PgSQL, it will panic func (i *Inserter) OnConflict(cs ...string) *PgSQLUpserter { + panic("implement me") } + +func (i *Inserter) buildColumns() ([]*ColumnMeta, error) { + cs := i.meta.columns + if len(i.columns) != 0 { + cs = make([]*ColumnMeta, 0, len(i.columns)) + for index, value := range i.columns { + v, isOk := i.meta.fieldMap[value] + if !isOk { + return cs, fmt.Errorf("invalid column %s", value) + } + i.quote(v.columnName) + if index != len(i.columns)-1 { + i.comma() + } + cs = append(cs, v) + } + } else { + for index, value := range i.meta.columns { + i.quote(value.columnName) + if index != len(cs)-1 { + i.comma() + } + } + } + return cs, nil + +} diff --git a/insert_test.go b/insert_test.go new file mode 100644 index 0000000..3f53107 --- /dev/null +++ b/insert_test.go @@ -0,0 +1,105 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package eql + +import ( + "errors" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestInserter_Values(t *testing.T) { + type User struct { + Id int64 + FirstName string + Ctime time.Time + } + type Order struct { + Id int64 + Name string + Price int64 + } + + n := time.Now() + u := &User{ + Id: 12, + FirstName: "Tom", + Ctime: n, + } + u1 := &User{ + Id: 13, + FirstName: "Jerry", + Ctime: n, + } + o1 := &Order{ + Id: 14, + Name: "Hellen", + Price: 200, + } + testCases := []CommonTestCase{ + { + name: "no examples of values", + builder: New().Insert().Values(), + wantErr: errors.New("no values"), + }, + { + name: "single example of values", + builder: New().Insert().Values(u), + wantSql: "INSERT INTO `user`(`id`,`first_name`,`ctime`) VALUES(?,?,?);", + wantArgs: []interface{}{int64(12), "Tom", n}, + }, + + { + name: "multiple values of same type", + builder: New().Insert().Values(u, u1), + wantSql: "INSERT INTO `user`(`id`,`first_name`,`ctime`) VALUES(?,?,?),(?,?,?);", + wantArgs: []interface{}{int64(12), "Tom", n, int64(13), "Jerry", n}, + }, + + { + name: "no example of a whole columns", + builder: New().Insert().Columns("Id", "FirstName").Values(u), + wantSql: "INSERT INTO `user`(`id`,`first_name`) VALUES(?,?);", + wantArgs: []interface{}{int64(12), "Tom"}, + }, + { + name: "an example with invalid columns", + builder: New().Insert().Columns("id", "FirstName").Values(u), + wantErr: errors.New("invalid column id"), + }, + { + name: "no whole columns and multiple values of same type", + builder: New().Insert().Columns("Id", "FirstName").Values(u, u1), + wantSql: "INSERT INTO `user`(`id`,`first_name`) VALUES(?,?),(?,?);", + wantArgs: []interface{}{int64(12), "Tom", int64(13), "Jerry"}, + }, + { + name: "multiple values of invalid column", + builder: New().Insert().Values(u, o1), + wantErr: errors.New("invalid column FirstName"), + }, + } + + for _, tc := range testCases { + + c := tc + t.Run(tc.name, func(t *testing.T) { + q, err := c.builder.Build() + assert.Equal(t, c.wantErr, err) + assert.Equal(t, c.wantSql, q.SQL) + assert.Equal(t, c.wantArgs, q.Args) + }) + } +}