Skip to content

Commit

Permalink
bugfix: make RowIter.HasNext idempotent (pingcap#18)
Browse files Browse the repository at this point in the history
* bugfix: make RowIter.HasNext idemponent

* add a test for rowIter
  • Loading branch information
tangenta authored and kennytm committed Dec 29, 2019
1 parent 75dc405 commit 7caf19b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
27 changes: 19 additions & 8 deletions v4/export/ir_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,31 @@ import (
)

// rowIter implements the SQLRowIter interface.
// Note: To create a rowIter, please use `newRowIter()` instead of struct literal.
type rowIter struct {
rows *sql.Rows
args []interface{}
rows *sql.Rows
hasNext bool
args []interface{}
}

func newRowIter(rows *sql.Rows, argLen int) *rowIter {
r := &rowIter{
rows: rows,
hasNext: false,
args: make([]interface{}, argLen),
}
r.hasNext = r.rows.Next()
return r
}

func (iter *rowIter) Next(row RowReceiver) error {
return decodeFromRows(iter.rows, iter.args, row)
err := decodeFromRows(iter.rows, iter.args, row)
iter.hasNext = iter.rows.Next()
return err
}

func (iter *rowIter) HasNext() bool {
return iter.rows.Next()
return iter.hasNext
}

type sizedRowIter struct {
Expand Down Expand Up @@ -94,10 +108,7 @@ func (td *tableData) ColumnCount() uint {
}

func (td *tableData) Rows() SQLRowIter {
return &rowIter{
rows: td.rows,
args: make([]interface{}, len(td.colTypes)),
}
return newRowIter(td.rows, len(td.colTypes))
}

func (td *tableData) SpecialComments() StringIter {
Expand Down
40 changes: 40 additions & 0 deletions v4/export/ir_impl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package export

import (
"github.com/DATA-DOG/go-sqlmock"
. "github.com/pingcap/check"
)

var _ = Suite(&testIRImplSuite{})

type testIRImplSuite struct{}

func (s *testIRImplSuite) TestRowIter(c *C) {
db, mock, err := sqlmock.New()
c.Assert(err, IsNil)
defer db.Close()

expectedRows := mock.NewRows([]string{"id"}).
AddRow("1").
AddRow("2").
AddRow("3")
mock.ExpectQuery("SELECT id from t").WillReturnRows(expectedRows)
rows, err := db.Query("SELECT id from t")
c.Assert(err, IsNil)

iter := newRowIter(rows, 1)
for i := 0; i < 100; i += 1 {
c.Assert(iter.HasNext(), IsTrue)
}
res := make(dumplingRow, 1)
c.Assert(iter.Next(res), IsNil)
c.Assert(res[0].String, Equals, "1")
c.Assert(iter.HasNext(), IsTrue)
c.Assert(iter.HasNext(), IsTrue)
c.Assert(iter.Next(res), IsNil)
c.Assert(res[0].String, Equals, "2")
c.Assert(iter.HasNext(), IsTrue)
c.Assert(iter.Next(res), IsNil)
c.Assert(res[0].String, Equals, "3")
c.Assert(iter.HasNext(), IsFalse)
}

0 comments on commit 7caf19b

Please sign in to comment.