Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ScanStructs, ScanVals to Scanner interface #273

Merged
merged 2 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 16 additions & 45 deletions exec/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,15 @@ func (q QueryExecutor) ScanStructs(i interface{}) error {
//
// i: A pointer to a slice of structs.
func (q QueryExecutor) ScanStructsContext(ctx context.Context, i interface{}) error {
val := reflect.ValueOf(i)
if !util.IsPointer(val.Kind()) {
return errUnsupportedScanStructsType
if _, err := checkScanStructsTarget(i); err != nil {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to do this check both here and in the Scanner#ScanStructs? Same for ScanVals?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should clarify this, I think we should just do the check in the single place for each to keep it DRY.

Copy link
Contributor Author

@vlanse vlanse May 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to call checkScanStructsTarget (checkScanValsTarget) only in corresponding Scanner methods because errors that cause checks to fail seem likely to not occur at runtime (they should be detected early by simple tests in fact, along with errors like invalid SQL) , so it does not matter that check will be done before actual query execution.
Will update PR appropriately

return err
}
val = reflect.Indirect(val)
if !util.IsSlice(val.Kind()) {
return errUnsupportedScanStructsType
scanner, err := q.ScannerContext(ctx)
if err != nil {
return err
}

return q.scanIntoSlice(ctx, val, func(sc Scanner, r interface{}) error {
return sc.ScanStruct(r)
})
defer func() { _ = scanner.Close() }()
return scanner.ScanStructs(i)
}

// This will execute the SQL and fill out the struct with the fields returned.
Expand Down Expand Up @@ -136,7 +133,7 @@ func (q QueryExecutor) ScanStructContext(ctx context.Context, i interface{}) (bo
return false, err
}

defer scanner.Close()
defer func() { _ = scanner.Close() }()

if scanner.Next() {
err = scanner.ScanStruct(i)
Expand Down Expand Up @@ -169,18 +166,15 @@ func (q QueryExecutor) ScanVals(i interface{}) error {
//
// i: Takes a pointer to a slice of primitive values.
func (q QueryExecutor) ScanValsContext(ctx context.Context, i interface{}) error {
val := reflect.ValueOf(i)
if !util.IsPointer(val.Kind()) {
return errUnsupportedScanValsType
if _, err := checkScanValsTarget(i); err != nil {
return err
}
val = reflect.Indirect(val)
if !util.IsSlice(val.Kind()) {
return errUnsupportedScanValsType
scanner, err := q.ScannerContext(ctx)
if err != nil {
return err
}

return q.scanIntoSlice(ctx, val, func(sc Scanner, r interface{}) error {
return sc.ScanVal(r)
})
defer func() { _ = scanner.Close() }()
return scanner.ScanVals(i)
}

// This will execute the SQL and set the value of the primitive. This method will return false if no record is found.
Expand Down Expand Up @@ -230,7 +224,7 @@ func (q QueryExecutor) ScanValContext(ctx context.Context, i interface{}) (bool,
return false, err
}

defer scanner.Close()
defer func() { _ = scanner.Close() }()

if scanner.Next() {
err = scanner.ScanVal(i)
Expand All @@ -257,26 +251,3 @@ func (q QueryExecutor) ScannerContext(ctx context.Context) (Scanner, error) {
}
return NewScanner(rows), nil
}

func (q QueryExecutor) scanIntoSlice(ctx context.Context, val reflect.Value, it func(sc Scanner, i interface{}) error) error {
elemType := util.GetSliceElementType(val)

scanner, err := q.ScannerContext(ctx)
if err != nil {
return err
}

defer scanner.Close()

for scanner.Next() {
row := reflect.New(elemType)
err = it(scanner, row.Interface())
if err != nil {
return err
}

util.AppendSliceElement(val, row)
}

return scanner.Err()
}
62 changes: 62 additions & 0 deletions exec/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ type (
Scanner interface {
Next() bool
ScanStruct(i interface{}) error
ScanStructs(i interface{}) error
ScanVal(i interface{}) error
ScanVals(i interface{}) error
Close() error
Err() error
}
Expand Down Expand Up @@ -90,6 +92,17 @@ func (s *scanner) ScanStruct(i interface{}) error {
return s.Err()
}

// ScanStructs scans results in slice of structs
func (s *scanner) ScanStructs(i interface{}) error {
val, err := checkScanStructsTarget(i)
if err != nil {
return err
}
return s.scanIntoSlice(val, func(i interface{}) error {
return s.ScanStruct(i)
})
}

// ScanVal will scan the current row and column into i.
func (s *scanner) ScanVal(i interface{}) error {
if err := s.rows.Scan(i); err != nil {
Expand All @@ -99,8 +112,57 @@ func (s *scanner) ScanVal(i interface{}) error {
return s.Err()
}

// ScanStructs scans results in slice of values
func (s *scanner) ScanVals(i interface{}) error {
val, err := checkScanValsTarget(i)
if err != nil {
return err
}
return s.scanIntoSlice(val, func(i interface{}) error {
return s.ScanVal(i)
})
}

// Close closes the Rows, preventing further enumeration. See sql.Rows#Close
// for more info.
func (s *scanner) Close() error {
return s.rows.Close()
}

func (s *scanner) scanIntoSlice(val reflect.Value, it func(i interface{}) error) error {
elemType := util.GetSliceElementType(val)

for s.Next() {
row := reflect.New(elemType)
if rowErr := it(row.Interface()); rowErr != nil {
return rowErr
}
util.AppendSliceElement(val, row)
}

return s.Err()
}

func checkScanStructsTarget(i interface{}) (reflect.Value, error) {
val := reflect.ValueOf(i)
if !util.IsPointer(val.Kind()) {
return val, errUnsupportedScanStructsType
}
val = reflect.Indirect(val)
if !util.IsSlice(val.Kind()) {
return val, errUnsupportedScanStructsType
}
return val, nil
}

func checkScanValsTarget(i interface{}) (reflect.Value, error) {
val := reflect.ValueOf(i)
if !util.IsPointer(val.Kind()) {
return val, errUnsupportedScanValsType
}
val = reflect.Indirect(val)
if !util.IsSlice(val.Kind()) {
return val, errUnsupportedScanValsType
}
return val, nil
}
69 changes: 69 additions & 0 deletions exec/scanner_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package exec

import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/suite"
)

type scannerSuite struct {
suite.Suite
}

func TestScanner(t *testing.T) {
suite.Run(t, &scannerSuite{})
}

func (s *scannerSuite) TestScanStructs() {
type StructWithTags struct {
Address string `db:"address"`
Name string `db:"name"`
}
db, mock, err := sqlmock.New()
s.Require().NoError(err)

mock.ExpectQuery(`SELECT \* FROM "items"`).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"address", "name"}).
AddRow(testAddr1, testName1).
AddRow(testAddr2, testName2),
)
rows, err := db.Query(`SELECT * FROM "items"`)
s.Require().NoError(err)

sc := NewScanner(rows)

result := make([]StructWithTags, 0)
err = sc.ScanStructs(result)
s.Require().EqualError(err, errUnsupportedScanStructsType.Error())

err = sc.ScanStructs(&result)
s.Require().NoError(err)
s.Require().ElementsMatch(
[]StructWithTags{{Address: testAddr1, Name: testName1}, {Address: testAddr2, Name: testName2}},
result,
)
}

func (s *scannerSuite) TestScanVals() {
db, mock, err := sqlmock.New()
s.Require().NoError(err)

mock.ExpectQuery(`SELECT "id" FROM "items"`).
WithArgs().
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2))

rows, err := db.Query(`SELECT "id" FROM "items"`)
s.Require().NoError(err)

sc := NewScanner(rows)

result := make([]int, 0)
err = sc.ScanVals(result)
s.Require().EqualError(err, errUnsupportedScanValsType.Error())

err = sc.ScanVals(&result)
s.Require().NoError(err)
s.Require().ElementsMatch([]int{1, 2}, result)
}