diff --git a/assert_test.go b/assert_test.go index 37b47333a..b38c96591 100644 --- a/assert_test.go +++ b/assert_test.go @@ -34,6 +34,10 @@ func assertTrueF(t *testing.T, actual bool, descriptions ...string) { fatalOnNonEmpty(t, validateEqual(actual, true, descriptions...)) } +func assertFalseF(t *testing.T, actual bool, descriptions ...string) { + fatalOnNonEmpty(t, validateEqual(actual, false, descriptions...)) +} + func assertStringContainsE(t *testing.T, actual string, expectedToContain string, descriptions ...string) { errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) } diff --git a/chunk_downloader.go b/chunk_downloader.go index 91071843e..768e20bf5 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -271,6 +271,9 @@ func getChunk( } func (scd *snowflakeChunkDownloader) startArrowBatches() error { + if scd.RowSet.RowSetBase64 == "" { + return nil + } var err error chunkMetaLen := len(scd.ChunkMetas) var loc *time.Location diff --git a/chunk_downloader_test.go b/chunk_downloader_test.go index 5d55f9f0a..b63ef5736 100644 --- a/chunk_downloader_test.go +++ b/chunk_downloader_test.go @@ -2,6 +2,7 @@ package gosnowflake import ( "context" + "database/sql/driver" "testing" ) @@ -26,3 +27,86 @@ func TestChunkDownloaderDoesNotStartWhenArrowParsingCausesError(t *testing.T) { }) } } + +func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + var rows driver.Rows + var err error + err = dbt.conn.Raw(func(x interface{}) error { + rows, err = x.(driver.QueryerContext).QueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 0 = 1", nil) + return err + }) + assertNilF(t, err) + rows.Close() + }) +} + +func TestWithArrowBatchesWhenQueryReturnsNoRows(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1") + defer rows.Close() + assertFalseF(t, rows.Next()) + }) +} + +func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLInterface(t *testing.T) { + for _, tc := range []struct { + useJSON bool + desc string + }{ + { + useJSON: true, + desc: "json", + }, + { + useJSON: false, + desc: "arrow", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + if tc.useJSON { + dbt.mustExec(forceJSON) + } + var rows driver.Rows + var err error + err = dbt.conn.Raw(func(x interface{}) error { + rows, err = x.(driver.QueryerContext).QueryContext(WithArrowBatches(context.Background()), "SELECT 1", nil) + return err + }) + assertNilF(t, err) + defer rows.Close() + values := make([]driver.Value, 1) + rows.Next(values) + assertEqualE(t, values[0], nil) + }) + }) + } +} + +func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormat(t *testing.T) { + for _, tc := range []struct { + useJSON bool + desc string + }{ + { + useJSON: true, + desc: "json", + }, + { + useJSON: false, + desc: "arrow", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + if tc.useJSON { + dbt.mustExec(forceJSON) + } + rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1") + defer rows.Close() + assertFalseF(t, rows.Next()) + }) + }) + } +}