Skip to content

Commit

Permalink
Convert slice arguments in mappers into flat values
Browse files Browse the repository at this point in the history
As discussed in [1], sometimes the caller wants to use a slice in the mapper
arguments and the driver does not support it. This patch will convert slices
from the mapper into simple values, modifying the SQL to add more placeholders
and the mapper, to put each slice item as an unique entry.

This behaviour is optional and by default will be disabled.

[1] go-gorp#5
  • Loading branch information
rafaeljusto committed Jun 18, 2018
1 parent 6032c66 commit 7f61e35
Showing 1 changed file with 201 additions and 0 deletions.
201 changes: 201 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,57 @@ type DbMap struct {

TypeConverter TypeConverter

// ExpandSlices when enabled will convert slice arguments in mappers into flat
// values. It will modify the query, adding more placeholders, and the mapper,
// adding each item of the slice as a new unique entry in the mapper. For
// example, given the scenario bellow:
//
// dbmap.Select(&output, "SELECT 1 FROM example WHERE id IN (:IDs)", map[string]interface{}{
// "IDs": []int64{1, 2, 3},
// })
//
// The executed query would be:
//
// SELECT 1 FROM example WHERE id IN (:IDs0,:IDs1,:IDs2)
//
// With the mapper:
//
// map[string]interface{}{
// "IDs": []int64{1, 2, 3},
// "IDs0": int64(1),
// "IDs1": int64(2),
// "IDs2": int64(3),
// }
//
// It is also flexible for custom slice types. The value just need to
// implement stringer or numberer interfaces.
//
// type CustomValue string
//
// const (
// CustomValueHey CustomValue = "hey"
// CustomValueOh CustomValue = "oh"
// )
//
// type CustomValues []CustomValue
//
// func (c CustomValues) ToStringSlice() []string {
// values := make([]string, len(c))
// for i := range c {
// values[i] = string(c[i])
// }
// return values
// }
//
// func query() {
// // ...
// result, err := dbmap.Select(&output, "SELECT 1 FROM example WHERE value IN (:Values)", map[string]interface{}{
// "Values": CustomValues([]CustomValue{CustomValueHey}),
// })
// // ...
// }
ExpandSliceArgs bool

tables []*TableMap
tablesDynamic map[string]*TableMap // tables that use same go-struct and different db table names
logger GorpLogger
Expand Down Expand Up @@ -605,12 +656,20 @@ func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) {
//
// i does NOT need to be registered with AddTable()
func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return hookedselect(m, m, i, query, args...)
}

// Exec runs an arbitrary SQL statement. args represent the bind parameters.
// This is equivalent to running: Exec() using database/sql
func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -620,36 +679,64 @@ func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {

// SelectInt is a convenience wrapper around the gorp.SelectInt function
func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectInt(m, query, args...)
}

// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function
func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectNullInt(m, query, args...)
}

// SelectFloat is a convenience wrapper around the gorp.SelectFloat function
func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectFloat(m, query, args...)
}

// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function
func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectNullFloat(m, query, args...)
}

// SelectStr is a convenience wrapper around the gorp.SelectStr function
func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectStr(m, query, args...)
}

// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function
func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectNullStr(m, query, args...)
}

// SelectOne is a convenience wrapper around the gorp.SelectOne function
func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

return SelectOne(m, m, holder, query, args...)
}

Expand Down Expand Up @@ -764,6 +851,10 @@ func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, refle
}

func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, query, args...)
Expand All @@ -772,6 +863,10 @@ func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row {
}

func (m *DbMap) Query(q string, args ...interface{}) (*sql.Rows, error) {
if m.ExpandSliceArgs {
expandSliceArgs(&q, args...)
}

if m.logger != nil {
now := time.Now()
defer m.trace(now, q, args...)
Expand All @@ -780,8 +875,114 @@ func (m *DbMap) Query(q string, args ...interface{}) (*sql.Rows, error) {
}

func (m *DbMap) trace(started time.Time, query string, args ...interface{}) {
if m.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}

if m.logger != nil {
var margs = argsString(args...)
m.logger.Printf("%s%s [%s] (%v)", m.logPrefix, query, margs, (time.Now().Sub(started)))
}
}

type stringer interface {
ToStringSlice() []string
}

type numberer interface {
ToInt64Slice() []int64
}

func expandSliceArgs(query *string, args ...interface{}) {
for _, arg := range args {
mapper, ok := arg.(map[string]interface{})
if !ok {
continue
}

for key, value := range mapper {
var replacements []string

// add flexibility for any custom type to be convert to one of the
// acceptable formats.
if v, ok := value.(stringer); ok {
value = v.ToStringSlice()
}
if v, ok := value.(numberer); ok {
value = v.ToInt64Slice()
}

switch v := value.(type) {
case []string:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint8:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint16:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []uint64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int8:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int16:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []int64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []float32:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
case []float64:
for id, replace := range v {
mapper[fmt.Sprintf("%s%d", key, id)] = replace
replacements = append(replacements, fmt.Sprintf(":%s%d", key, id))
}
default:
continue
}

*query = strings.Replace(*query, fmt.Sprintf(":%s", key), strings.Join(replacements, ","), -1)
}
}
}

0 comments on commit 7f61e35

Please sign in to comment.