diff --git a/database_test.go b/database_test.go index ba77b38a..a4e717b7 100644 --- a/database_test.go +++ b/database_test.go @@ -67,7 +67,14 @@ func (ds *databaseSuite) TestScanStructs() { WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) - + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) mock.ExpectQuery(`SELECT "test" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) @@ -124,6 +131,12 @@ func (ds *databaseSuite) TestScanVals() { mock.ExpectQuery(`SELECT "id" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) db := goqu.New("mock", mDB) var ids []uint32 @@ -492,7 +505,14 @@ func (tds *txdatabaseSuite) TestScanStructs() { WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) - + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) mock.ExpectQuery(`SELECT "test" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) @@ -554,6 +574,12 @@ func (tds *txdatabaseSuite) TestScanVals() { mDB, mock, err := sqlmock.New() tds.NoError(err) mock.ExpectBegin() + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) mock.ExpectQuery(`SELECT "id" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) diff --git a/exec/query_executor.go b/exec/query_executor.go index 09ff98f6..ce460dff 100644 --- a/exec/query_executor.go +++ b/exec/query_executor.go @@ -79,18 +79,12 @@ func (q QueryExecutor) ScanStructs(i interface{}) error { // // i: A pointer to a slice of structs. func (q QueryExecutor) ScanStructsContext(ctx context.Context, i interface{}) error { - val := reflect.ValueOf(i) - if !util.IsPointer(val.Kind()) { - return errUnsupportedScanStructsType - } - val = reflect.Indirect(val) - if !util.IsSlice(val.Kind()) { - return errUnsupportedScanStructsType + scanner, err := q.ScannerContext(ctx) + if err != nil { + return err } - - return q.scanIntoSlice(ctx, val, func(sc Scanner, r interface{}) error { - return sc.ScanStruct(r) - }) + defer func() { _ = scanner.Close() }() + return scanner.ScanStructs(i) } // This will execute the SQL and fill out the struct with the fields returned. @@ -136,7 +130,7 @@ func (q QueryExecutor) ScanStructContext(ctx context.Context, i interface{}) (bo return false, err } - defer scanner.Close() + defer func() { _ = scanner.Close() }() if scanner.Next() { err = scanner.ScanStruct(i) @@ -169,18 +163,12 @@ func (q QueryExecutor) ScanVals(i interface{}) error { // // i: Takes a pointer to a slice of primitive values. func (q QueryExecutor) ScanValsContext(ctx context.Context, i interface{}) error { - val := reflect.ValueOf(i) - if !util.IsPointer(val.Kind()) { - return errUnsupportedScanValsType - } - val = reflect.Indirect(val) - if !util.IsSlice(val.Kind()) { - return errUnsupportedScanValsType + scanner, err := q.ScannerContext(ctx) + if err != nil { + return err } - - return q.scanIntoSlice(ctx, val, func(sc Scanner, r interface{}) error { - return sc.ScanVal(r) - }) + defer func() { _ = scanner.Close() }() + return scanner.ScanVals(i) } // This will execute the SQL and set the value of the primitive. This method will return false if no record is found. @@ -230,7 +218,7 @@ func (q QueryExecutor) ScanValContext(ctx context.Context, i interface{}) (bool, return false, err } - defer scanner.Close() + defer func() { _ = scanner.Close() }() if scanner.Next() { err = scanner.ScanVal(i) @@ -257,26 +245,3 @@ func (q QueryExecutor) ScannerContext(ctx context.Context) (Scanner, error) { } return NewScanner(rows), nil } - -func (q QueryExecutor) scanIntoSlice(ctx context.Context, val reflect.Value, it func(sc Scanner, i interface{}) error) error { - elemType := util.GetSliceElementType(val) - - scanner, err := q.ScannerContext(ctx) - if err != nil { - return err - } - - defer scanner.Close() - - for scanner.Next() { - row := reflect.New(elemType) - err = it(scanner, row.Interface()) - if err != nil { - return err - } - - util.AppendSliceElement(val, row) - } - - return scanner.Err() -} diff --git a/exec/query_executor_internal_test.go b/exec/query_executor_internal_test.go index f098059c..04256f0c 100644 --- a/exec/query_executor_internal_test.go +++ b/exec/query_executor_internal_test.go @@ -486,14 +486,33 @@ func (qes *queryExecutorSuite) TestScanStructs_badValue() { Name string `db:"name"` } - db, _, err := sqlmock.New() - qes.NoError(err) - - e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) - - var items []StructWithTags - qes.Equal(errUnsupportedScanStructsType, e.ScanStructs(items)) - qes.Equal(errUnsupportedScanStructsType, e.ScanStructs(&StructWithTags{})) + tests := []struct { + name string + items interface{} + }{ + { + name: "non-pointer items", + items: []StructWithTags{}, + }, + { + name: "non-slice items", + items: &StructWithTags{}, + }, + } + for i := range tests { + test := tests[i] + qes.Run(test.name, func() { + db, mock, err := sqlmock.New() + qes.NoError(err) + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + AddRow(testAddr1, testName1).AddRow(testAddr2, testName2), + ) + e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) + qes.Equal(errUnsupportedScanStructsType, e.ScanStructs(test.items)) + }) + } } func (qes *queryExecutorSuite) TestScanStructs_queryError() { @@ -788,15 +807,33 @@ func (qes *queryExecutorSuite) TestScanStructsContext_badValue() { Name string `db:"name"` } - ctx := context.Background() - db, _, err := sqlmock.New() - qes.NoError(err) - - e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) - - var items []StructWithTags - qes.Equal(errUnsupportedScanStructsType, e.ScanStructsContext(ctx, items)) - qes.Equal(errUnsupportedScanStructsType, e.ScanStructsContext(ctx, &StructWithTags{})) + tests := []struct { + name string + items interface{} + }{ + { + name: "non-pointer items", + items: []StructWithTags{}, + }, + { + name: "non-slice items", + items: &StructWithTags{}, + }, + } + for i := range tests { + test := tests[i] + qes.Run(test.name, func() { + db, mock, err := sqlmock.New() + qes.NoError(err) + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + AddRow(testAddr1, testName1).AddRow(testAddr2, testName2), + ) + e := newQueryExecutor(db, nil, `SELECT * FROM "items"`) + qes.Equal(errUnsupportedScanStructsType, e.ScanStructsContext(context.Background(), test.items)) + }) + } } func (qes *queryExecutorSuite) TestScanStructsContext_queryError() { @@ -1040,10 +1077,7 @@ func (qes *queryExecutorSuite) TestScanVals() { e := newQueryExecutor(db, nil, `SELECT "id" FROM "items"`) - var id int64 var ids []int64 - qes.Equal(errUnsupportedScanValsType, e.ScanVals(ids)) - qes.Equal(errUnsupportedScanValsType, e.ScanVals(&id)) qes.EqualError(e.ScanVals(&ids), "queryExecutor error") qes.EqualError(e.ScanVals(&ids), "row error") qes.Error(e.ScanVals(&ids)) @@ -1059,6 +1093,37 @@ func (qes *queryExecutorSuite) TestScanVals() { qes.Equal(&id2, pointers[1]) } +func (qes *queryExecutorSuite) TestScanValsError() { + var id int64 + + tests := []struct { + name string + items interface{} + }{ + { + name: "non-pointer items", + items: []int64{}, + }, + { + name: "non-slice items", + items: &id, + }, + } + for i := range tests { + test := tests[i] + qes.Run(test.name, func() { + db, mock, err := sqlmock.New() + qes.NoError(err) + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2)) + + e := newQueryExecutor(db, nil, `SELECT "id" FROM "items"`) + qes.Equal(errUnsupportedScanValsType, e.ScanVals(test.items)) + }) + } +} + func (qes *queryExecutorSuite) TestScanVal() { db, mock, err := sqlmock.New() qes.NoError(err) diff --git a/exec/scanner.go b/exec/scanner.go index 506f0b11..c05411b9 100644 --- a/exec/scanner.go +++ b/exec/scanner.go @@ -14,7 +14,9 @@ type ( Scanner interface { Next() bool ScanStruct(i interface{}) error + ScanStructs(i interface{}) error ScanVal(i interface{}) error + ScanVals(i interface{}) error Close() error Err() error } @@ -90,6 +92,17 @@ func (s *scanner) ScanStruct(i interface{}) error { return s.Err() } +// ScanStructs scans results in slice of structs +func (s *scanner) ScanStructs(i interface{}) error { + val, err := checkScanStructsTarget(i) + if err != nil { + return err + } + return s.scanIntoSlice(val, func(i interface{}) error { + return s.ScanStruct(i) + }) +} + // ScanVal will scan the current row and column into i. func (s *scanner) ScanVal(i interface{}) error { if err := s.rows.Scan(i); err != nil { @@ -99,8 +112,57 @@ func (s *scanner) ScanVal(i interface{}) error { return s.Err() } +// ScanStructs scans results in slice of values +func (s *scanner) ScanVals(i interface{}) error { + val, err := checkScanValsTarget(i) + if err != nil { + return err + } + return s.scanIntoSlice(val, func(i interface{}) error { + return s.ScanVal(i) + }) +} + // Close closes the Rows, preventing further enumeration. See sql.Rows#Close // for more info. func (s *scanner) Close() error { return s.rows.Close() } + +func (s *scanner) scanIntoSlice(val reflect.Value, it func(i interface{}) error) error { + elemType := util.GetSliceElementType(val) + + for s.Next() { + row := reflect.New(elemType) + if rowErr := it(row.Interface()); rowErr != nil { + return rowErr + } + util.AppendSliceElement(val, row) + } + + return s.Err() +} + +func checkScanStructsTarget(i interface{}) (reflect.Value, error) { + val := reflect.ValueOf(i) + if !util.IsPointer(val.Kind()) { + return val, errUnsupportedScanStructsType + } + val = reflect.Indirect(val) + if !util.IsSlice(val.Kind()) { + return val, errUnsupportedScanStructsType + } + return val, nil +} + +func checkScanValsTarget(i interface{}) (reflect.Value, error) { + val := reflect.ValueOf(i) + if !util.IsPointer(val.Kind()) { + return val, errUnsupportedScanValsType + } + val = reflect.Indirect(val) + if !util.IsSlice(val.Kind()) { + return val, errUnsupportedScanValsType + } + return val, nil +} diff --git a/exec/scanner_internal_test.go b/exec/scanner_internal_test.go new file mode 100644 index 00000000..0a73b1ff --- /dev/null +++ b/exec/scanner_internal_test.go @@ -0,0 +1,69 @@ +package exec + +import ( + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/suite" +) + +type scannerSuite struct { + suite.Suite +} + +func TestScanner(t *testing.T) { + suite.Run(t, &scannerSuite{}) +} + +func (s *scannerSuite) TestScanStructs() { + type StructWithTags struct { + Address string `db:"address"` + Name string `db:"name"` + } + db, mock, err := sqlmock.New() + s.Require().NoError(err) + + mock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). + AddRow(testAddr1, testName1). + AddRow(testAddr2, testName2), + ) + rows, err := db.Query(`SELECT * FROM "items"`) + s.Require().NoError(err) + + sc := NewScanner(rows) + + result := make([]StructWithTags, 0) + err = sc.ScanStructs(result) + s.Require().EqualError(err, errUnsupportedScanStructsType.Error()) + + err = sc.ScanStructs(&result) + s.Require().NoError(err) + s.Require().ElementsMatch( + []StructWithTags{{Address: testAddr1, Name: testName1}, {Address: testAddr2, Name: testName2}}, + result, + ) +} + +func (s *scannerSuite) TestScanVals() { + db, mock, err := sqlmock.New() + s.Require().NoError(err) + + mock.ExpectQuery(`SELECT "id" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2)) + + rows, err := db.Query(`SELECT "id" FROM "items"`) + s.Require().NoError(err) + + sc := NewScanner(rows) + + result := make([]int, 0) + err = sc.ScanVals(result) + s.Require().EqualError(err, errUnsupportedScanValsType.Error()) + + err = sc.ScanVals(&result) + s.Require().NoError(err) + s.Require().ElementsMatch([]int{1, 2}, result) +} diff --git a/select_dataset_test.go b/select_dataset_test.go index d1361bc0..a98d4304 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -1146,7 +1146,12 @@ func (sds *selectDatasetSuite) TestScanStructs() { WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) - + sqlMock.ExpectQuery(`SELECT "address", "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) + sqlMock.ExpectQuery(`SELECT "address", "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) sqlMock.ExpectQuery(`SELECT "test" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) @@ -1187,6 +1192,13 @@ func (sds *selectDatasetSuite) TestScanStructs_WithPreparedStatements() { WillReturnRows(sqlmock.NewRows([]string{"address", "name"}). FromCSVString("111 Test Addr,Test1\n211 Test Addr,Test2")) + sqlMock.ExpectQuery(`SELECT "address", "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) + sqlMock.ExpectQuery(`SELECT "address", "name" FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"test"}).FromCSVString("test1\ntest2")) + sqlMock.ExpectQuery( `SELECT "test" FROM "items" WHERE \(\("address" = \?\) AND \("name" IN \(\?, \?, \?\)\)\)`, ). @@ -1299,6 +1311,12 @@ func (sds *selectDatasetSuite) TestScanVals() { sqlMock.ExpectQuery(`SELECT "id" FROM "items"`). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + sqlMock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + sqlMock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) db := goqu.New("mock", mDB) var ids []uint32 @@ -1323,6 +1341,13 @@ func (sds *selectDatasetSuite) TestScanVals_WithPreparedStatment() { WithArgs("111 Test Addr", "Bob", "Sally", "Billy"). WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + sqlMock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + sqlMock.ExpectQuery(`SELECT \* FROM "items"`). + WithArgs(). + WillReturnRows(sqlmock.NewRows([]string{"id"}).FromCSVString("1\n2\n3\n4\n5")) + db := goqu.New("mock", mDB) var ids []uint32 sds.NoError(db.From("items"). @@ -1334,6 +1359,7 @@ func (sds *selectDatasetSuite) TestScanVals_WithPreparedStatment() { sds.EqualError(db.From("items").ScanVals([]uint32{}), "goqu: type must be a pointer to a slice when scanning into vals") + sds.EqualError(db.From("items").ScanVals(dsTestActionItem{}), "goqu: type must be a pointer to a slice when scanning into vals") }