Skip to content

Commit

Permalink
Feat: add insert function and test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
CarolineZhang666 committed Oct 18, 2021
1 parent 36bee31 commit e49e182
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 9 deletions.
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
- [Support Aggregate Functions](https://github.com/gotomicro/eql/pull/37)
- [Updater implementation, excluding WHERE clause](https://github.com/gotomicro/eql/pull/36)
- [Force test and lint in pre-push](https://github.com/gotomicro/eql/pull/35)
- [Insert implementation ](https://github.com/gotomicro/eql/pull/38)
12 changes: 7 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type DBOption func(db *DB)

// DB represents a database
type DB struct {
metaRegistry MetaRegistry
metaRegistry MetaRegistry
dialect Dialect
nullAssertFunc NullAssertFunc
}
Expand Down Expand Up @@ -65,14 +65,16 @@ func (db *DB) Update(table interface{}) *Updater {

// Insert generate Inserter to builder insert query
func (db *DB) Insert() *Inserter {
return &Inserter{}
return &Inserter{
builder: db.builder(),
}
}

func (db *DB) builder() builder {
return builder{
registry: db.metaRegistry,
dialect: db.dialect,
buffer: &strings.Builder{},
dialect: db.dialect,
buffer: &strings.Builder{},
}
}

Expand All @@ -97,7 +99,7 @@ func NilAsNullFunc(val interface{}) bool {

// ZeroAsNullFunc means "zero value = null"
func ZeroAsNullFunc(val interface{}) bool {
if val == nil{
if val == nil {
return true
}
switch v := val.(type) {
Expand Down
85 changes: 81 additions & 4 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,111 @@

package eql

import (
"errors"
"fmt"
"reflect"
)

// Inserter is used to construct an insert query
type Inserter struct {
builder
columns []string
values []interface{}
}

func (i *Inserter) Build() (*Query, error) {
panic("implement me")
var err error
if len(i.values) == 0 {
return &Query{}, errors.New("no values")
}
i.buffer.WriteString("INSERT INTO ")
i.meta, err = i.registry.Get(i.values[0])
if err != nil {
return &Query{}, err
}
i.quote(i.meta.tableName)
i.buffer.WriteString("(")
fields, err := i.buildColumns()
if err != nil {
return &Query{}, err
}
i.buffer.WriteString(")")
i.buffer.WriteString(" VALUES")
for index, value := range i.values {
i.buffer.WriteString("(")
refVal := reflect.ValueOf(value).Elem()
for j, v := range fields {
field := refVal.FieldByName(v.fieldName)
if !field.IsValid() {
return &Query{}, fmt.Errorf("invalid column %s", v.fieldName)
}
val := field.Interface()
i.parameter(val)
if j != len(fields)-1 {
i.comma()
}
}
i.buffer.WriteString(")")
if index != len(i.values)-1 {
i.comma()
}
}
i.end()
return &Query{SQL: i.buffer.String(), Args: i.args}, nil
}

// Columns specifies the columns that need to be inserted
// if cs is empty, all columns will be inserted except auto increment columns
func (db *DB) Columns(cs ...string) *Inserter {
panic("implements me")
func (i *Inserter) Columns(cs ...string) *Inserter {
i.columns = cs
return i
}

// Values specify the rows
// all the elements must be the same structure
func (i *Inserter) Values(values ...interface{}) *Inserter {
panic("implement me")
i.values = values
return i
}

// OnDuplicateKey generate MysqlUpserter
// if the dialect is not MySQL, it will panic
func (i *Inserter) OnDuplicateKey() *MysqlUpserter {

panic("implement me")
}

// OnConflict generate PgSQLUpserter
// if the dialect is not PgSQL, it will panic
func (i *Inserter) OnConflict(cs ...string) *PgSQLUpserter {

panic("implement me")
}

func (i *Inserter) buildColumns() ([]*ColumnMeta, error) {
cs := i.meta.columns
if len(i.columns) != 0 {
cs = make([]*ColumnMeta, 0, len(i.columns))
for index, value := range i.columns {
v, isOk := i.meta.fieldMap[value]
if !isOk {
return cs, fmt.Errorf("invalid column %s", value)
}
i.quote(v.columnName)
if index != len(i.columns)-1 {
i.comma()
}
cs = append(cs, v)
}
} else {
for index, value := range i.meta.columns {
i.quote(value.columnName)
if index != len(cs)-1 {
i.comma()
}
}
}
return cs, nil

}
105 changes: 105 additions & 0 deletions insert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2021 gotomicro
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package eql

import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
"time"
)

func TestInserter_Values(t *testing.T) {
type User struct {
Id int64
FirstName string
Ctime time.Time
}
type Order struct {
Id int64
Name string
Price int64
}

n := time.Now()
u := &User{
Id: 12,
FirstName: "Tom",
Ctime: n,
}
u1 := &User{
Id: 13,
FirstName: "Jerry",
Ctime: n,
}
o1 := &Order{
Id: 14,
Name: "Hellen",
Price: 200,
}
testCases := []CommonTestCase{
{
name: "no examples of values",
builder: New().Insert().Values(),
wantErr: errors.New("no values"),
},
{
name: "single example of values",
builder: New().Insert().Values(u),
wantSql: "INSERT INTO `user`(`id`,`first_name`,`ctime`) VALUES(?,?,?);",
wantArgs: []interface{}{int64(12), "Tom", n},
},

{
name: "multiple values of same type",
builder: New().Insert().Values(u, u1),
wantSql: "INSERT INTO `user`(`id`,`first_name`,`ctime`) VALUES(?,?,?),(?,?,?);",
wantArgs: []interface{}{int64(12), "Tom", n, int64(13), "Jerry", n},
},

{
name: "no example of a whole columns",
builder: New().Insert().Columns("Id", "FirstName").Values(u),
wantSql: "INSERT INTO `user`(`id`,`first_name`) VALUES(?,?);",
wantArgs: []interface{}{int64(12), "Tom"},
},
{
name: "an example with invalid columns",
builder: New().Insert().Columns("id", "FirstName").Values(u),
wantErr: errors.New("invalid column id"),
},
{
name: "no whole columns and multiple values of same type",
builder: New().Insert().Columns("Id", "FirstName").Values(u, u1),
wantSql: "INSERT INTO `user`(`id`,`first_name`) VALUES(?,?),(?,?);",
wantArgs: []interface{}{int64(12), "Tom", int64(13), "Jerry"},
},
{
name: "multiple values of invalid column",
builder: New().Insert().Values(u, o1),
wantErr: errors.New("invalid column FirstName"),
},
}

for _, tc := range testCases {

c := tc
t.Run(tc.name, func(t *testing.T) {
q, err := c.builder.Build()
assert.Equal(t, c.wantErr, err)
assert.Equal(t, c.wantSql, q.SQL)
assert.Equal(t, c.wantArgs, q.Args)
})
}
}

0 comments on commit e49e182

Please sign in to comment.