From a6a27b71b91fa833474353571363d22b4541c687 Mon Sep 17 00:00:00 2001 From: Ivo Gosemann Date: Mon, 11 Dec 2023 17:33:56 +0100 Subject: [PATCH] fix: make no argument passed validation opt-in --- expectations.go | 33 ++++++++++++++++++++++++++++ expectations_before_go18.go | 3 ++- expectations_before_go18_test.go | 10 +++++++-- expectations_go18.go | 3 ++- expectations_go18_test.go | 19 ++++++++++++---- expectations_test.go | 22 +++++++++++++++++++ sqlmock_go18_test.go | 2 +- sqlmock_test.go | 37 ++++++++++++++++++++++++++------ 8 files changed, 114 insertions(+), 15 deletions(-) diff --git a/expectations.go b/expectations.go index 5adf608..8a6cd44 100644 --- a/expectations.go +++ b/expectations.go @@ -134,11 +134,27 @@ type ExpectedQuery struct { // WithArgs will match given expected args to actual database query arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. +// Must not be used together with WithoutArgs() func (e *ExpectedQuery) WithArgs(args ...driver.Value) *ExpectedQuery { + if e.noArgs { + panic("WithArgs() and WithoutArgs() must not be used together") + } e.args = args return e } +// WithoutArgs will ensure that no arguments are passed for this query. +// if at least one argument is passed, it will return an error. This allows +// for stricter validation of the query arguments. +// Must no be used together with WithArgs() +func (e *ExpectedQuery) WithoutArgs() *ExpectedQuery { + if len(e.args) > 0 { + panic("WithoutArgs() and WithArgs() must not be used together") + } + e.noArgs = true + return e +} + // RowsWillBeClosed expects this query rows to be closed. func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { e.rowsMustBeClosed = true @@ -195,11 +211,27 @@ type ExpectedExec struct { // WithArgs will match given expected args to actual database exec operation arguments. // if at least one argument does not match, it will return an error. For specific // arguments an sqlmock.Argument interface can be used to match an argument. +// Must not be used together with WithoutArgs() func (e *ExpectedExec) WithArgs(args ...driver.Value) *ExpectedExec { + if len(e.args) > 0 { + panic("WithArgs() and WithoutArgs() must not be used together") + } e.args = args return e } +// WithoutArgs will ensure that no args are passed for this expected database exec action. +// if at least one argument is passed, it will return an error. This allows for stricter +// validation of the query arguments. +// Must not be used together with WithArgs() +func (e *ExpectedExec) WithoutArgs() *ExpectedExec { + if len(e.args) > 0 { + panic("WithoutArgs() and WithArgs() must not be used together") + } + e.noArgs = true + return e +} + // WillReturnError allows to set an error for expected database exec action func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { e.err = err @@ -338,6 +370,7 @@ type queryBasedExpectation struct { expectSQL string converter driver.ValueConverter args []driver.Value + noArgs bool // ensure no args are passed } // ExpectedPing is used to manage *sql.DB.Ping expectations. diff --git a/expectations_before_go18.go b/expectations_before_go18.go index 0831863..67c08dc 100644 --- a/expectations_before_go18.go +++ b/expectations_before_go18.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package sqlmock @@ -17,7 +18,7 @@ func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery { func (e *queryBasedExpectation) argsMatches(args []namedValue) error { if nil == e.args { - if len(args) > 0 { + if e.noArgs && len(args) > 0 { return fmt.Errorf("expected 0, but got %d arguments", len(args)) } return nil diff --git a/expectations_before_go18_test.go b/expectations_before_go18_test.go index 81dc8cf..4234cd6 100644 --- a/expectations_before_go18_test.go +++ b/expectations_before_go18_test.go @@ -1,3 +1,4 @@ +//go:build !go1.8 // +build !go1.8 package sqlmock @@ -9,10 +10,15 @@ import ( ) func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []namedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{5, "str"} diff --git a/expectations_go18.go b/expectations_go18.go index 767ebd4..07227ed 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -30,7 +31,7 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error { if nil == e.args { - if len(args) > 0 { + if e.noArgs && len(args) > 0 { return fmt.Errorf("expected 0, but got %d arguments", len(args)) } return nil diff --git a/expectations_go18_test.go b/expectations_go18_test.go index d5638bc..cd633b7 100644 --- a/expectations_go18_test.go +++ b/expectations_go18_test.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -10,10 +11,15 @@ import ( ) func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}} if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{5, "str"} @@ -102,10 +108,15 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) { } func TestQueryExpectationNamedArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter, noArgs: true} against := []driver.NamedValue{{Value: int64(5), Name: "id"}} if err := e.argsMatches(against); err == nil { - t.Errorf("arguments should not match, since no expectation was set, but argument was passed") + t.Error("arguments should not match, since argument was passed, but noArgs was set") + } + + e.noArgs = false + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument was passed, but no expected args or noArgs was set") } e.args = []driver.Value{ diff --git a/expectations_test.go b/expectations_test.go index afda582..cf0251a 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -101,3 +101,25 @@ func TestCustomValueConverterQueryScan(t *testing.T) { t.Error(err) } } + +func TestQueryWithNoArgsAndWithArgsPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Error("Expected panic for using WithArgs and ExpectNoArgs together") + }() + mock := &sqlmock{} + mock.ExpectQuery("SELECT (.+) FROM user").WithArgs("John").WithoutArgs() +} + +func TestExecWithNoArgsAndWithArgsPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } + t.Error("Expected panic for using WithArgs and ExpectNoArgs together") + }() + mock := &sqlmock{} + mock.ExpectExec("^INSERT INTO user").WithArgs("John").WithoutArgs() +} diff --git a/sqlmock_go18_test.go b/sqlmock_go18_test.go index cf56e67..6267f38 100644 --- a/sqlmock_go18_test.go +++ b/sqlmock_go18_test.go @@ -1,3 +1,4 @@ +//go:build go1.8 // +build go1.8 package sqlmock @@ -437,7 +438,6 @@ func TestContextExecErrorDelay(t *testing.T) { // test that return of error is delayed var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). - WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) diff --git a/sqlmock_test.go b/sqlmock_test.go index 982a32a..2129a16 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -749,6 +749,16 @@ func TestRunExecsWithExpectedErrorMeetsExpectations(t *testing.T) { } } +func TestRunExecsWithNoArgsExpectedMeetsExpectations(t *testing.T) { + db, dbmock, _ := New() + dbmock.ExpectExec("THE FIRST EXEC").WithoutArgs().WillReturnResult(NewResult(0, 0)) + + _, err := db.Exec("THE FIRST EXEC", "foobar") + if err == nil { + t.Fatalf("expected error, but there wasn't any") + } +} + func TestRunQueryWithExpectedErrorMeetsExpectations(t *testing.T) { db, dbmock, _ := New() dbmock.ExpectQuery("THE FIRST QUERY").WillReturnError(fmt.Errorf("big bad bug")) @@ -959,7 +969,7 @@ func TestPrepareExec(t *testing.T) { mock.ExpectBegin() ep := mock.ExpectPrepare("INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") for i := 0; i < 3; i++ { - ep.ExpectExec().WithArgs(i, "Hello"+strconv.Itoa(i)).WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) } mock.ExpectCommit() tx, _ := db.Begin() @@ -1073,7 +1083,7 @@ func TestPreparedStatementCloseExpectation(t *testing.T) { defer db.Close() ep := mock.ExpectPrepare("INSERT INTO ORDERS").WillBeClosed() - ep.ExpectExec().WithArgs(1, "Hello").WillReturnResult(NewResult(1, 1)) + ep.ExpectExec().WillReturnResult(NewResult(1, 1)) stmt, err := db.Prepare("INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") if err != nil { @@ -1104,7 +1114,6 @@ func TestExecExpectationErrorDelay(t *testing.T) { // test that return of error is delayed var delay time.Duration = 100 * time.Millisecond mock.ExpectExec("^INSERT INTO articles"). - WithArgs("hello"). WillReturnError(errors.New("slow fail")). WillDelayFor(delay) @@ -1230,10 +1239,10 @@ func Test_sqlmock_Prepare_and_Exec(t *testing.T) { mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)") expected := NewResult(1, 1) - mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)").WithArgs("test"). + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). WillReturnResult(expected) expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) got, err := mock.(*sqlmock).Prepare(query) if err != nil { @@ -1326,7 +1335,7 @@ func Test_sqlmock_Query(t *testing.T) { } defer db.Close() expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") - mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WithArgs("test").WillReturnRows(expectedRows) + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) query := "SELECT name, email FROM users WHERE name = ?" rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"}) if err != nil { @@ -1340,3 +1349,19 @@ func Test_sqlmock_Query(t *testing.T) { return } } + +func Test_sqlmock_QueryExpectWithoutArgs(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows).WithoutArgs() + query := "SELECT name, email FROM users WHERE name = ?" + _, err = mock.(*sqlmock).Query(query, []driver.Value{"test"}) + if err == nil { + t.Errorf("error expected") + return + } +}