diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index bb403c60d..53aaf238e 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -378,7 +378,7 @@ jobs: - name: Install tox run: python -m pip install tox>=4 - name: Run tests - run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-aio-ci` + run: python -m tox run -e aio env: PYTHON_VERSION: ${{ matrix.python-version }} cloud_provider: ${{ matrix.cloud-provider }} diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 8d0ba0996..180117a5b 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -265,10 +265,13 @@ async def _all_async_queries_finished(self) -> bool: async def async_query_check_helper( sfq_id: str, ) -> bool: - nonlocal found_unfinished_query - return found_unfinished_query or self.is_still_running( - await self.get_query_status(sfq_id) - ) + try: + nonlocal found_unfinished_query + return found_unfinished_query or self.is_still_running( + await self.get_query_status(sfq_id) + ) + except asyncio.CancelledError: + pass tasks = [ asyncio.create_task(async_query_check_helper(sfqid)) for sfqid in queries @@ -279,6 +282,7 @@ async def async_query_check_helper( break for task in tasks: task.cancel() + await asyncio.gather(*tasks) return not found_unfinished_query async def _authenticate(self, auth_instance: AuthByPlugin): diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py index cfb586281..c71c2b3e7 100644 --- a/src/snowflake/connector/aio/_cursor.py +++ b/src/snowflake/connector/aio/_cursor.py @@ -6,9 +6,11 @@ import asyncio import collections +import logging import re import signal import sys +import typing import uuid from logging import getLogger from types import TracebackType @@ -30,8 +32,15 @@ create_batches_from_response, ) from snowflake.connector.aio._result_set import ResultSet, 
ResultSetIterator -from snowflake.connector.constants import PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT -from snowflake.connector.cursor import DESC_TABLE_RE +from snowflake.connector.constants import ( + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, + QueryStatus, +) +from snowflake.connector.cursor import ( + ASYNC_NO_DATA_MAX_RETRY, + ASYNC_RETRY_PATTERN, + DESC_TABLE_RE, +) from snowflake.connector.cursor import DictCursor as DictCursorSync from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync @@ -43,7 +52,7 @@ ER_INVALID_VALUE, ER_NOT_POSITIVE_SIZE, ) -from snowflake.connector.errors import BindUploadError +from snowflake.connector.errors import BindUploadError, DatabaseError from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage from snowflake.connector.telemetry import TelemetryField from snowflake.connector.time_util import get_time_millis @@ -65,9 +74,11 @@ def __init__( ): super().__init__(connection, use_dict_result) # the following fixes type hint - self._connection: SnowflakeConnection = connection + self._connection = typing.cast("SnowflakeConnection", self._connection) + self._inner_cursor = typing.cast(SnowflakeCursor, self._inner_cursor) self._lock_canceling = asyncio.Lock() self._timebomb: asyncio.Task | None = None + self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None def __aiter__(self): return self @@ -87,6 +98,18 @@ async def __anext__(self): async def __aenter__(self): return self + def __enter__(self): + # async cursor does not support sync context manager + raise TypeError( + "'SnowflakeCursor' object does not support the context manager protocol" + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + # async cursor does not support sync context manager + raise TypeError( + "'SnowflakeCursor' object does not support the context manager protocol" + ) + def __del__(self): # do 
nothing in async, __del__ is unreliable pass @@ -337,6 +360,7 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None: self._total_rowcount += updated_rows async def _init_multi_statement_results(self, data: dict) -> None: + # TODO: async telemetry SNOW-1572217 # self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE) self.multi_statement_savedIds = data["resultIds"].split(",") self._multi_statement_resultIds = collections.deque( @@ -357,7 +381,45 @@ async def _init_multi_statement_results(self, data: dict) -> None: async def _log_telemetry_job_data( self, telemetry_field: TelemetryField, value: Any ) -> None: - raise NotImplementedError("Telemetry is not supported in async.") + # TODO: async telemetry SNOW-1572217 + pass + + async def _preprocess_pyformat_query( + self, + command: str, + params: Sequence[Any] | dict[Any, Any] | None = None, + ) -> str: + # pyformat/format paramstyle + # client side binding + processed_params = self._connection._process_params_pyformat(params, self) + # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement + # TODO: async telemetry support + # if params is not None and len(params) == 0: + # await self._log_telemetry_job_data( + # TelemetryField.EMPTY_SEQ_INTERPOLATION, + # ( + # TelemetryData.TRUE + # if self.connection._interpolate_empty_sequences + # else TelemetryData.FALSE + # ), + # ) + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.debug( + f"binding: [{self._format_query_for_log(command)}] " + f"with input=[{params}], " + f"processed=[{processed_params}]", + ) + if ( + self.connection._interpolate_empty_sequences + and processed_params is not None + ) or ( + not self.connection._interpolate_empty_sequences + and len(processed_params) > 0 + ): + query = command % processed_params + else: + query = command + return query async def abort_query(self, qid: str) -> bool: url = f"/queries/{qid}/abort-request" @@ -387,6 +449,10 @@ 
async def callproc(self, procname: str, args=tuple()): await self.execute(command, args) return args + @property + def connection(self) -> SnowflakeConnection: + return self._connection + async def close(self): """Closes the cursor object. @@ -471,7 +537,7 @@ async def execute( } if self._connection.is_pyformat: - query = self._preprocess_pyformat_query(command, params) + query = await self._preprocess_pyformat_query(command, params) else: # qmark and numeric paramstyle query = command @@ -538,7 +604,7 @@ async def execute( self._connection.converter.set_parameter(param, value) if "resultIds" in data: - self._init_multi_statement_results(data) + await self._init_multi_statement_results(data) return self else: self.multi_statement_savedIds = [] @@ -707,7 +773,7 @@ async def executemany( command = command + "; " if self._connection.is_pyformat: processed_queries = [ - self._preprocess_pyformat_query(command, params) + await self._preprocess_pyformat_query(command, params) for params in seqparams ] query = "".join(processed_queries) @@ -752,7 +818,7 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]: async def fetchone(self) -> dict | tuple | None: """Fetches one row.""" if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._result is None and self._result_set is not None: self._result: ResultSetIterator = await self._result_set._create_iter() self._result_state = ResultState.VALID @@ -804,7 +870,7 @@ async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: async def fetchall(self) -> list[tuple] | list[dict]: """Fetches all of the results.""" if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._result is None and self._result_set is not None: self._result: ResultSetIterator = await self._result_set._create_iter( is_fetch_all=True @@ -822,9 +888,10 @@ async def fetchall(self) -> list[tuple] | list[dict]: async def 
fetch_arrow_batches(self) -> AsyncIterator[Table]: self.check_can_use_arrow_resultset() if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._query_result_format != "arrow": raise NotSupportedError + # TODO: async telemetry SNOW-1572217 # self._log_telemetry_job_data( # TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE # ) @@ -848,9 +915,10 @@ async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | Non self.check_can_use_arrow_resultset() if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._query_result_format != "arrow": raise NotSupportedError + # TODO: async telemetry SNOW-1572217 # self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE) return await self._result_set._fetch_arrow_all( force_return_table=force_return_table @@ -860,7 +928,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]: """Fetches a single Arrow Table.""" self.check_can_use_pandas() if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._query_result_format != "arrow": raise NotSupportedError # TODO: async telemetry @@ -872,7 +940,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]: async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame: self.check_can_use_pandas() if self._prefetch_hook is not None: - self._prefetch_hook() + await self._prefetch_hook() if self._query_result_format != "arrow": raise NotSupportedError # # TODO: async telemetry @@ -917,8 +985,70 @@ async def get_result_batches(self) -> list[ResultBatch] | None: return self._result_set.batches async def get_results_from_sfqid(self, sfqid: str) -> None: - """Gets the results from previously ran query.""" - raise NotImplementedError("Not implemented in async") + """Gets the results from previously ran query. 
This method differs from ``SnowflakeCursor.query_result`` + in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results. + """ + + async def wait_until_ready() -> None: + """Makes sure the query has finished executing and, once it has, retrieves the results.""" + no_data_counter = 0 + retry_pattern_pos = 0 + while True: + status, status_resp = await self.connection._get_query_status(sfqid) + self.connection._cache_query_status(sfqid, status) + if not self.connection.is_still_running(status): + break + if status == QueryStatus.NO_DATA: # pragma: no cover + no_data_counter += 1 + if no_data_counter > ASYNC_NO_DATA_MAX_RETRY: + raise DatabaseError( + "Cannot retrieve data on the status of this query. No information returned " + f"from server for query '{sfqid}'" + ) + await asyncio.sleep( + 0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos] + ) # Same wait as JDBC + # If we can advance in ASYNC_RETRY_PATTERN then do so + if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1): + retry_pattern_pos += 1 + if status != QueryStatus.SUCCESS: + logger.info(f"Status of query '{sfqid}' is {status.name}") + self.connection._process_error_query_status( + sfqid, + status_resp, + error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable", + error_cls=DatabaseError, + ) + await self._inner_cursor.execute( + f"select * from table(result_scan('{sfqid}'))" + ) + self._result = self._inner_cursor._result + self._query_result_format = self._inner_cursor._query_result_format + self._total_rowcount = self._inner_cursor._total_rowcount + self._description = self._inner_cursor._description + self._result_set = self._inner_cursor._result_set + self._result_state = ResultState.VALID + self._rownumber = 0 + # Unset this function, so that we don't block anymore + self._prefetch_hook = None + + if ( + self._inner_cursor._total_rowcount == 1 + and await self._inner_cursor.fetchall() + == [("Multiple statements executed successfully.",)] + ): + url = 
f"/queries/{sfqid}/result" + ret = await self._connection.rest.request(url=url, method="get") + if "data" in ret and "resultIds" in ret["data"]: + await self._init_multi_statement_results(ret["data"]) + + await self.connection.get_query_status_throw_if_error( + sfqid + ) # Trigger an exception if query failed + klass = self.__class__ + self._inner_cursor = klass(self.connection) + self._sfqid = sfqid + self._prefetch_hook = wait_until_ready async def query_result(self, qid: str) -> SnowflakeCursor: url = f"/queries/{qid}/result" diff --git a/src/snowflake/connector/aio/_network.py b/src/snowflake/connector/aio/_network.py index 80b6ef8a8..a3eb1b350 100644 --- a/src/snowflake/connector/aio/_network.py +++ b/src/snowflake/connector/aio/_network.py @@ -136,7 +136,7 @@ async def close(self): """Closes all active and idle sessions in this session pool.""" if self._active_sessions: logger.debug(f"Closing {len(self._active_sessions)} active sessions") - for s in itertools.chain(self._active_sessions, self._idle_sessions): + for s in itertools.chain(set(self._active_sessions), set(self._idle_sessions)): try: await s.close() except Exception as e: @@ -289,7 +289,7 @@ async def _token_request(self, request_type): token=header_token, ) if ret.get("success") and ret.get("data", {}).get("sessionToken"): - logger.debug("success: %s", ret) + logger.debug("success: %s", SecretDetector.mask_secrets(str(ret))) await self.update_tokens( ret["data"]["sessionToken"], ret["data"].get("masterToken"), diff --git a/src/snowflake/connector/aio/_result_set.py b/src/snowflake/connector/aio/_result_set.py index 797554e35..4879860f9 100644 --- a/src/snowflake/connector/aio/_result_set.py +++ b/src/snowflake/connector/aio/_result_set.py @@ -31,7 +31,11 @@ from snowflake.connector.options import pandas from snowflake.connector.result_set import ResultSet as ResultSetSync +from .. 
import NotSupportedError from ..options import pyarrow as pa +from ..result_batch import DownloadMetrics +from ..telemetry import TelemetryField +from ..time_util import get_time_millis if TYPE_CHECKING: from pandas import DataFrame @@ -155,6 +159,16 @@ def __init__( list[JSONResultBatch] | list[ArrowResultBatch], self.batches ) + def _can_create_arrow_iter(self) -> None: + # For now we don't support mixed ResultSets, so assume first partition's type + # represents them all + head_type = type(self.batches[0]) + if head_type != ArrowResultBatch: + raise NotSupportedError( + f"Trying to use arrow fetching on {head_type} which " + f"is not ArrowResultChunk" + ) + async def _create_iter( self, **kwargs, @@ -214,7 +228,7 @@ async def _fetch_arrow_all(self, force_return_table: bool = False) -> Table | No if tables: return pa.concat_tables(tables) else: - return self.batches[0].to_arrow() if force_return_table else None + return await self.batches[0].to_arrow() if force_return_table else None async def _fetch_pandas_batches(self, **kwargs) -> AsyncIterator[DataFrame]: self._can_create_arrow_iter() @@ -238,7 +252,7 @@ async def _fetch_pandas_all(self, **kwargs) -> DataFrame: **concat_kwargs, ) # Empty dataframe - return self.batches[0].to_pandas(**kwargs) + return await self.batches[0].to_pandas(**kwargs) async def _finish_iterating(self) -> None: await self._report_metrics() @@ -246,4 +260,26 @@ async def _finish_iterating(self) -> None: async def _report_metrics(self) -> None: """Report metrics for the result set.""" # TODO: SNOW-1572217 async telemetry - super()._report_metrics() + """Report all metrics totalled up. + + This includes TIME_CONSUME_LAST_RESULT, TIME_DOWNLOADING_CHUNKS and + TIME_PARSING_CHUNKS in that order. 
+ """ + if self._cursor._first_chunk_time is not None: + time_consume_last_result = ( + get_time_millis() - self._cursor._first_chunk_time + ) + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_CONSUME_LAST_RESULT, time_consume_last_result + ) + metrics = self._get_metrics() + if DownloadMetrics.download.value in metrics: + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_DOWNLOADING_CHUNKS, + metrics.get(DownloadMetrics.download.value), + ) + if DownloadMetrics.parse.value in metrics: + await self._cursor._log_telemetry_job_data( + TelemetryField.TIME_PARSING_CHUNKS, + metrics.get(DownloadMetrics.parse.value), + ) diff --git a/test/helpers.py b/test/helpers.py index 3f2846e21..19558564e 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import base64 import functools import math @@ -42,6 +43,10 @@ from snowflake.connector.constants import QueryStatus except ImportError: QueryStatus = None +try: + import snowflake.connector.aio +except ImportError: + pass def create_mock_response(status_code: int) -> Mock: @@ -123,6 +128,40 @@ def _wait_until_query_success( ) +async def _wait_while_query_running_async( + con: snowflake.connector.aio.SnowflakeConnection, + sfqid: str, + sleep_time: int, + dont_cache: bool = False, +) -> None: + """ + Checks if the provided query still reports that it is running, and if so, + sleeps for the specified time in a while loop. 
+ """ + query_status = con._get_query_status if dont_cache else con.get_query_status + while con.is_still_running(await query_status(sfqid)): + await asyncio.sleep(sleep_time) + + +async def _wait_until_query_success_async( + con: snowflake.connector.aio.SnowflakeConnection, + sfqid: str, + num_checks: int, + sleep_per_check: int, +) -> None: + for _ in range(num_checks): + status = await con.get_query_status(sfqid) + if status == QueryStatus.SUCCESS: + break + await asyncio.sleep(sleep_per_check) + else: + pytest.fail( + "We should have broken out of wait loop for query success. " + f"Query ID: {sfqid}. " + f"Final query status: {status}" + ) + + def create_nanoarrow_pyarrow_iterator(input_data, use_table_iterator): # create nanoarrow based iterator return ( diff --git a/test/integ/aio/__init__.py b/test/integ/aio/__init__.py new file mode 100644 index 000000000..ef416f64a --- /dev/null +++ b/test/integ/aio/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/test/integ/aio/pandas/__init__.py b/test/integ/aio/pandas/__init__.py new file mode 100644 index 000000000..ef416f64a --- /dev/null +++ b/test/integ/aio/pandas/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py b/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py new file mode 100644 index 000000000..8ac2ddbee --- /dev/null +++ b/test/integ/aio/pandas/test_arrow_chunk_iterator_async.py @@ -0,0 +1,80 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +import datetime +import random +from typing import Callable + +import pytest + +try: + from snowflake.connector.options import installed_pandas +except ImportError: + installed_pandas = False + +try: + import snowflake.connector.nanoarrow_arrow_iterator # NOQA + + no_arrow_iterator_ext = False +except ImportError: + no_arrow_iterator_ext = True + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas option is not installed.", +) +@pytest.mark.parametrize("timestamp_type", ("TZ", "LTZ", "NTZ")) +async def test_iterate_over_timestamp_chunk(conn_cnx, timestamp_type): + seed = datetime.datetime.now().timestamp() + row_numbers = 10 + random.seed(seed) + + # Generate random test data + def generator_test_data(scale: int) -> Callable[[], int]: + def generate_test_data() -> int: + nonlocal scale + epoch = random.randint(-100_355_968, 2_534_023_007) + frac = random.randint(0, 10**scale - 1) + if scale == 8: + frac *= 10 ** (9 - scale) + scale = 9 + return int(f"{epoch}{str(frac).rjust(scale, '0')}") + + return generate_test_data + + test_generators = [generator_test_data(i) for i in range(10)] + test_data = [[g() for g in test_generators] for _ in range(row_numbers)] + + async with conn_cnx( + session_parameters={ + "PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "ARROW_FORCE", + "TIMESTAMP_TZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 TZHTZM", + "TIMESTAMP_LTZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 TZHTZM", + "TIMESTAMP_NTZ_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF6 ", + } + ) as conn: + async with conn.cursor() as cur: + results = await ( + await cur.execute( + "select " + + ", ".join( + f"to_timestamp_{timestamp_type}(${s + 1}, {s if s != 8 else 9}) c_{s}" + for s in range(10) + ) + + ", " + + ", ".join(f"c_{i}::varchar" for i in range(10)) + + f" from values {', '.join(str(tuple(e)) for e in test_data)}" + ) + ).fetch_arrow_all() + retrieved_results = [ + list(map(lambda e: 
e.as_py().strftime("%Y-%m-%d %H:%M:%S.%f %z"), line)) + for line in list(results)[:10] + ] + retrieved_strigs = [ + list(map(lambda e: e.as_py().replace("Z", "+0000"), line)) + for line in list(results)[10:] + ] + + assert retrieved_results == retrieved_strigs diff --git a/test/integ/aio/pandas/test_arrow_pandas_async.py b/test/integ/aio/pandas/test_arrow_pandas_async.py new file mode 100644 index 000000000..d35558bbe --- /dev/null +++ b/test/integ/aio/pandas/test_arrow_pandas_async.py @@ -0,0 +1,1526 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import decimal +import itertools +import random +import time +from datetime import datetime +from decimal import Decimal +from enum import Enum +from unittest import mock + +import numpy +import pytest +import pytz +from numpy.testing import assert_equal + +try: + from snowflake.connector.constants import ( + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, + IterUnit, + ) +except ImportError: + # This is because of olddriver tests + class IterUnit(Enum): + ROW_UNIT = "row" + TABLE_UNIT = "table" + + +try: + from snowflake.connector.options import installed_pandas, pandas, pyarrow +except ImportError: + installed_pandas = False + pandas = None + pyarrow = None + +try: + from snowflake.connector.nanoarrow_arrow_iterator import PyArrowIterator # NOQA + + no_arrow_iterator_ext = False +except ImportError: + no_arrow_iterator_ext = True + +SQL_ENABLE_ARROW = "alter session set python_connector_query_result_format='ARROW';" + +EPSILON = 1e-8 + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_num_one(conn_cnx): + print("Test fetching one single dataframe") + row_count = 50000 + col_count = 2 + random_seed = get_random_seed() + sql_exec = ( + f"select seq4() as c1, uniform(1, 10, random({random_seed})) as c2 from " + 
f"table(generator(rowcount=>{row_count})) order by c1, c2" + ) + await fetch_pandas(conn_cnx, sql_exec, row_count, col_count, "one") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_tinyint(conn_cnx): + cases = ["NULL", 0.11, -0.11, "NULL", 1.27, -1.28, "NULL"] + table = "test_arrow_tiny_int" + column = "(a number(5,2))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_smallint(conn_cnx): + cases = ["NULL", 0, 0.11, -0.11, "NULL", 32.767, -32.768, "NULL"] + table = "test_arrow_small_int" + column = "(a number(5,3))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_scaled_int(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + 0.123456789, + -0.123456789, + 2.147483647, + -2.147483648, + "NULL", + ] + table = "test_arrow_int" + column = "(a number(10,9))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not 
installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_bigint(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "1.23456789E-10", + "-1.23456789E-10", + "2.147483647E-9", + "-2.147483647E-9", + "-1e-9", + "1e-9", + "1e-8", + "-1e-8", + "NULL", + ] + table = "test_arrow_big_int" + column = "(a number(38,18))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", epsilon=EPSILON) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_decimal(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "10000000000000000000000000000000000000", + "12345678901234567890123456789012345678", + "99999999999999999999999999999999999999", + "-1000000000000000000000000000000000000", + "-2345678901234567890123456789012345678", + "-9999999999999999999999999999999999999", + "NULL", + ] + table = "test_arrow_decimal" + column = "(a number(38,0))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="decimal") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_decimal(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "1.0000000000000000000000000000000000000", + "1.2345678901234567890123456789012345678", + "9.9999999999999999999999999999999999999", + "-1.000000000000000000000000000000000000", + 
"-2.345678901234567890123456789012345678", + "-9.999999999999999999999999999999999999", + "NULL", + ] + table = "test_arrow_decimal" + column = "(a number(38,37))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="decimal") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is not installed.", +) +async def test_scaled_decimal_SNOW_133561(conn_cnx): + cases = [ + "NULL", + 0, + "NULL", + "1.2345", + "2.1001", + "2.2001", + "2.3001", + "2.3456", + "-9.999", + "-1.000", + "-3.4567", + "3.4567", + "4.5678", + "5.6789", + "-0.0012", + "NULL", + ] + table = "test_scaled_decimal_SNOW_133561" + column = "(a number(38,10))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="float") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_boolean(conn_cnx): + cases = ["NULL", True, "NULL", False, True, True, "NULL", True, False, "NULL"] + table = "test_arrow_boolean" + column = "(a boolean)" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def 
test_double(conn_cnx): + cases = [ + "NULL", + # SNOW-31249 + "-86.6426540296895", + "3.14159265359", + # SNOW-76269 + "1.7976931348623157E308", + "1.7E308", + "1.7976931348623151E308", + "-1.7976931348623151E308", + "-1.7E308", + "-1.7976931348623157E308", + "NULL", + ] + table = "test_arrow_double" + column = "(a double)" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_semi_struct(conn_cnx): + sql_text = """ + select array_construct(10, 20, 30), + array_construct(null, 'hello', 3::double, 4, 5), + array_construct(), + object_construct('a',1,'b','BBBB', 'c',null), + object_construct('Key_One', parse_json('NULL'), 'Key_Two', null, 'Key_Three', 'null'), + to_variant(3.2), + parse_json('{ "a": null}'), + 100::variant; + """ + res = [ + "[\n" + " 10,\n" + " 20,\n" + " 30\n" + "]", + "[\n" + + " undefined,\n" + + ' "hello",\n' + + " 3.000000000000000e+00,\n" + + " 4,\n" + + " 5\n" + + "]", + "[]", + "{\n" + ' "a": 1,\n' + ' "b": "BBBB"\n' + "}", + "{\n" + ' "Key_One": null,\n' + ' "Key_Three": "null"\n' + "}", + "3.2", + "{\n" + ' "a": null\n' + "}", + "100", + ] + async with conn_cnx() as cnx_table: + # fetch dataframe with new arrow support + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql_text) + df_new = await cursor_table.fetch_pandas_all() + col_new = df_new.iloc[0] + for j, c_new in enumerate(col_new): + assert res[j] == c_new, ( + "{} column: original value is {}, new value is {}, " + "values are not equal".format(j, res[j], c_new) + ) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, 
+ reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_date(conn_cnx): + cases = [ + "NULL", + "2017-01-01", + "2014-01-02", + "2014-01-02", + "1970-01-01", + "1970-01-01", + "NULL", + "1969-12-31", + "0200-02-27", + "NULL", + "0200-02-28", + # "0200-02-29", # day is out of range + # "0000-01-01", # year 0 is out of range + "0001-12-31", + "NULL", + ] + table = "test_arrow_date" + column = "(a date)" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas(conn, sql_text, cases, 1, "one", data_type="date") + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize("scale", [i for i in range(10)]) +async def test_time(conn_cnx, scale): + cases = [ + "NULL", + "00:00:51", + "01:09:03.100000", + "02:23:23.120000", + "03:56:23.123000", + "04:56:53.123400", + "09:01:23.123450", + "11:03:29.123456", + # note: Python's max time precision is microsecond, rest of them will lose precision + # "15:31:23.1234567", + # "19:01:43.12345678", + # "23:59:59.99999999", + "NULL", + ] + table = "test_arrow_time" + column = f"(a time({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, cases, 1, "one", data_type="time", scale=scale + ) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize("scale", [i for i in range(10)]) +async def 
test_timestampntz(conn_cnx, scale): + cases = [ + "NULL", + "1970-01-01 00:00:00", + "1970-01-01 00:00:01", + "1970-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestampntz({scale}))" + + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, cases, 1, "one", data_type="timestamp", scale=scale + ) + await finish(conn, table) + + +@pytest.mark.parametrize( + "timestamp_str", + [ + "'1400-01-01 01:02:03.123456789'::timestamp as low_ts", + "'9999-01-01 01:02:03.123456789789'::timestamp as high_ts", + ], +) +async def test_timestampntz_raises_overflow(conn_cnx, timestamp_str): + async with conn_cnx() as conn: + r = await conn.cursor().execute(f"select {timestamp_str}") + with pytest.raises(OverflowError, match="overflows int64 range."): + await r.fetch_arrow_all() + + +async def test_timestampntz_down_scale(conn_cnx): + async with conn_cnx() as conn: + r = await conn.cursor().execute( + "select '1400-01-01 01:02:03.123456'::timestamp as low_ts, '9999-01-01 01:02:03.123456'::timestamp as high_ts" + ) + table = await r.fetch_arrow_all() + lower_dt = table[0][0].as_py() # type: datetime + assert ( + lower_dt.year, + lower_dt.month, + lower_dt.day, + lower_dt.hour, + lower_dt.minute, + lower_dt.second, + lower_dt.microsecond, + ) == (1400, 1, 1, 1, 2, 3, 123456) + higher_dt = table[1][0].as_py() + 
assert ( + higher_dt.year, + higher_dt.month, + higher_dt.day, + higher_dt.hour, + higher_dt.minute, + higher_dt.second, + higher_dt.microsecond, + ) == (9999, 1, 1, 1, 2, 3, 123456) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "scale, timezone", + itertools.product( + [i for i in range(10)], ["UTC", "America/New_York", "Australia/Sydney"] + ), +) +async def test_timestamptz(conn_cnx, scale, timezone): + cases = [ + "NULL", + "1971-01-01 00:00:00", + "1971-01-11 00:00:01", + "1971-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestamptz({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values, timezone=timezone) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, + sql_text, + cases, + 1, + "one", + data_type="timestamptz", + scale=scale, + timezone=timezone, + ) + await finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "scale, timezone", + itertools.product( + [i for i in range(10)], ["UTC", "America/New_York", "Australia/Sydney"] + ), +) +async def test_timestampltz(conn_cnx, scale, timezone): + cases = [ + "NULL", + "1970-01-01 00:00:00", + "1970-01-01 
00:00:01", + "1970-01-01 00:00:10", + "2014-01-02 16:00:00", + "2014-01-02 12:34:56", + "2017-01-01 12:00:00.123456789", + "2014-01-02 16:00:00.000000001", + "NULL", + "2014-01-02 12:34:57.1", + "1969-12-31 23:59:59.000000001", + "1970-01-01 00:00:00.123412423", + "1970-01-01 00:00:01.000001", + "1969-12-31 11:59:59.001", + # "0001-12-31 11:59:59.11", + # pandas._libs.tslibs.np_datetime.OutOfBoundsDatetime: + # Out of bounds nanosecond timestamp: 1-12-31 11:59:59 + "NULL", + ] + table = "test_arrow_timestamp" + column = f"(a timestampltz({scale}))" + values = ",".join( + [f"({i}, {c})" if c == "NULL" else f"({i}, '{c}')" for i, c in enumerate(cases)] + ) + async with conn_cnx() as conn: + await init(conn, table, column, values, timezone=timezone) + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, + sql_text, + cases, + 1, + "one", + data_type="timestamp", + scale=scale, + timezone=timezone, + ) + await finish(conn, table) + + +@pytest.mark.skipolddriver +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_vector(conn_cnx, is_public_test): + if is_public_test: + pytest.xfail( + reason="This feature hasn't been rolled out for public Snowflake deployments yet." 
+ ) + tests = [ + ( + "vector(int,3)", + [ + "NULL", + "[1,2,3]::vector(int,3)", + ], + ["NULL", numpy.array([1, 2, 3])], + ), + ( + "vector(float,3)", + [ + "NULL", + "[1.3,2.4,3.5]::vector(float,3)", + ], + ["NULL", numpy.array([1.3, 2.4, 3.5], dtype=numpy.float32)], + ), + ] + for vector_type, cases, typed_cases in tests: + table = "test_arrow_vector" + column = f"(a {vector_type})" + values = [f"{i}, {c}" for i, c in enumerate(cases)] + async with conn_cnx() as conn: + await init_with_insert_select(conn, table, column, values) + # Test general fetches + sql_text = f"select a from {table} order by s" + await validate_pandas( + conn, sql_text, typed_cases, 1, method="one", data_type=vector_type + ) + + # Test empty result sets + cur = conn.cursor() + await cur.execute(f"select a from {table} limit 0") + df = await cur.fetch_pandas_all() + assert len(df) == 0 + assert df.dtypes[0] == "object" + + await finish(conn, table) + + +async def validate_pandas( + cnx_table, + sql, + cases, + col_count, + method="one", + data_type="float", + epsilon=None, + scale=0, + timezone=None, +): + """Tests that parameters can be customized. + + Args: + cnx_table: Connection object. + sql: SQL command for execution. + cases: Test cases. + col_count: Number of columns in dataframe. + method: If method is 'batch', we fetch dataframes in batch. If method is 'one', we fetch a single dataframe + containing all data (Default value = 'one'). + data_type: Defines how to compare values (Default value = 'float'). + epsilon: For comparing double values (Default value = None). + scale: For comparing time values with scale (Default value = 0). + timezone: For comparing timestamp ltz (Default value = None). 
+ """ + + row_count = len(cases) + assert col_count != 0, "# of columns should be larger than 0" + + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql) + + # build dataframe + total_rows, total_batches = 0, 0 + start_time = time.time() + + if method == "one": + df_new = await cursor_table.fetch_pandas_all() + total_rows = df_new.shape[0] + else: + for df_new in await cursor_table.fetch_pandas_batches(): + total_rows += df_new.shape[0] + total_batches += 1 + end_time = time.time() + + print(f"new way (fetching {method}) took {end_time - start_time}s") + if method == "batch": + print(f"new way has # of batches : {total_batches}") + await cursor_table.close() + assert ( + total_rows == row_count + ), f"there should be {row_count} rows, but {total_rows} rows" + + # verify the correctness + # only do it when fetch one dataframe + if method == "one": + assert (row_count, col_count) == df_new.shape, ( + "the shape of old dataframe is {}, " + "the shape of new dataframe is {}, " + "shapes are not equal".format((row_count, col_count), df_new.shape) + ) + + for i in range(row_count): + for j in range(col_count): + c_new = df_new.iat[i, j] + if type(cases[i]) is str and cases[i] == "NULL": + assert c_new is None or pandas.isnull(c_new), ( + "{} row, {} column: original value is NULL, " + "new value is {}, values are not equal".format(i, j, c_new) + ) + else: + if data_type == "float": + c_case = float(cases[i]) + elif data_type == "decimal": + c_case = Decimal(cases[i]) + elif data_type == "date": + c_case = datetime.strptime(cases[i], "%Y-%m-%d").date() + elif data_type == "time": + time_str_len = 8 if scale == 0 else 9 + scale + c_case = cases[i].strip()[:time_str_len] + c_new = str(c_new).strip()[:time_str_len] + assert c_new == c_case, ( + "{} row, {} column: original value is {}, " + "new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + break + elif 
data_type.startswith("timestamp"): + time_str_len = 19 if scale == 0 else 20 + scale + if timezone: + c_case = pandas.Timestamp( + cases[i][:time_str_len], tz=timezone + ) + if data_type == "timestamptz": + c_case = c_case.tz_convert("UTC") + else: + c_case = pandas.Timestamp(cases[i][:time_str_len]) + assert c_case == c_new, ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + break + elif data_type.startswith("vector"): + assert numpy.array_equal(cases[i], c_new), ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + continue + else: + c_case = cases[i] + if epsilon is None: + assert c_case == c_new, ( + "{} row, {} column: original value is {}, new value is {}, " + "values are not equal".format(i, j, cases[i], c_new) + ) + else: + assert abs(c_case - c_new) < epsilon, ( + "{} row, {} column: original value is {}, " + "new value is {}, epsilon is {} \ + values are not equal".format( + i, j, cases[i], c_new, epsilon + ) + ) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_num_batch(conn_cnx): + print("Test fetching dataframes in batch") + row_count = 1000000 + col_count = 2 + random_seed = get_random_seed() + sql_exec = ( + f"select seq4() as c1, uniform(1, 10, random({random_seed})) as c2 from " + f"table(generator(rowcount=>{row_count})) order by c1, c2" + ) + await fetch_pandas(conn_cnx, sql_exec, row_count, col_count, "batch") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +@pytest.mark.parametrize( + "result_format", + ["pandas", "arrow"], +) +async def test_empty(conn_cnx, result_format): + print("Test fetch empty dataframe") + async with conn_cnx() as cnx: + cursor = cnx.cursor() + await 
cursor.execute(SQL_ENABLE_ARROW) + await cursor.execute( + "select seq4() as foo, seq4() as bar from table(generator(rowcount=>1)) limit 0" + ) + fetch_all_fn = getattr(cursor, f"fetch_{result_format}_all") + fetch_batches_fn = getattr(cursor, f"fetch_{result_format}_batches") + result = await fetch_all_fn() + if result_format == "pandas": + assert len(list(result)) == 2 + assert list(result)[0] == "FOO" + assert list(result)[1] == "BAR" + else: + assert result is None + + await cursor.execute( + "select seq4() as foo from table(generator(rowcount=>1)) limit 0" + ) + df_count = 0 + async for _ in await fetch_batches_fn(): + df_count += 1 + assert df_count == 0 + + +def get_random_seed(): + random.seed(datetime.now().timestamp()) + return random.randint(0, 10000) + + +async def fetch_pandas(conn_cnx, sql, row_count, col_count, method="one"): + """Tests that parameters can be customized. + + Args: + conn_cnx: Connection object. + sql: SQL command for execution. + row_count: Number of total rows combining all dataframes. + col_count: Number of columns in dataframe. + method: If method is 'batch', we fetch dataframes in batch. If method is 'one', we fetch a single dataframe + containing all data (Default value = 'one'). 
+ """ + assert row_count != 0, "# of rows should be larger than 0" + assert col_count != 0, "# of columns should be larger than 0" + + async with conn_cnx() as conn: + # fetch dataframe by fetching row by row + cursor_row = conn.cursor() + await cursor_row.execute(SQL_ENABLE_ARROW) + await cursor_row.execute(sql) + + # build dataframe + # actually its exec time would be different from `pandas.read_sql()` via sqlalchemy as most people use + # further perf test can be done separately + start_time = time.time() + rows = 0 + if method == "one": + df_old = pandas.DataFrame( + await cursor_row.fetchall(), + columns=[f"c{i}" for i in range(col_count)], + ) + else: + print("use fetchmany") + while True: + dat = await cursor_row.fetchmany(10000) + if not dat: + break + else: + df_old = pandas.DataFrame( + dat, columns=[f"c{i}" for i in range(col_count)] + ) + rows += df_old.shape[0] + end_time = time.time() + print(f"The original way took {end_time - start_time}s") + await cursor_row.close() + + # fetch dataframe with new arrow support + cursor_table = conn.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql) + + # build dataframe + total_rows, total_batches = 0, 0 + start_time = time.time() + if method == "one": + df_new = await cursor_table.fetch_pandas_all() + total_rows = df_new.shape[0] + else: + async for df_new in await cursor_table.fetch_pandas_batches(): + total_rows += df_new.shape[0] + total_batches += 1 + end_time = time.time() + print(f"new way (fetching {method}) took {end_time - start_time}s") + if method == "batch": + print(f"new way has # of batches : {total_batches}") + await cursor_table.close() + assert total_rows == row_count, "there should be {} rows, but {} rows".format( + row_count, total_rows + ) + + # verify the correctness + # only do it when fetch one dataframe + if method == "one": + assert ( + df_old.shape == df_new.shape + ), "the shape of old dataframe is {}, the shape of new dataframe is {}, \ + shapes 
are not equal".format( + df_old.shape, df_new.shape + ) + + for i in range(row_count): + col_old = df_old.iloc[i] + col_new = df_new.iloc[i] + for j, (c_old, c_new) in enumerate(zip(col_old, col_new)): + assert c_old == c_new, ( + f"{i} row, {j} column: old value is {c_old}, new value " + f"is {c_new} values are not equal" + ) + else: + assert ( + rows == total_rows + ), f"the number of rows are not equal {rows} vs {total_rows}" + + +async def init(json_cnx, table, column, values, timezone=None): + cursor_json = json_cnx.cursor() + if timezone is not None: + await cursor_json.execute(f"ALTER SESSION SET TIMEZONE = '{timezone}'") + column_with_seq = column[0] + "s number, " + column[1:] + await cursor_json.execute(f"create or replace table {table} {column_with_seq}") + await cursor_json.execute(f"insert into {table} values {values}") + + +async def init_with_insert_select(json_cnx, table, column, rows, timezone=None): + cursor_json = json_cnx.cursor() + if timezone is not None: + await cursor_json.execute(f"ALTER SESSION SET TIMEZONE = '{timezone}'") + column_with_seq = column[0] + "s number, " + column[1:] + await cursor_json.execute(f"create or replace table {table} {column_with_seq}") + for row in rows: + await cursor_json.execute(f"insert into {table} select {row}") + + +async def finish(json_cnx, table): + cursor_json = json_cnx.cursor() + await cursor_json.execute(f"drop table if exists {table};") + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_arrow_fetch_result_scan(conn_cnx): + async with conn_cnx() as cnx: + cur = cnx.cursor() + await cur.execute("alter session set query_result_format='ARROW_FORCE'") + await cur.execute( + "alter session set python_connector_query_result_format='ARROW_FORCE'" + ) + res = await (await cur.execute("select 1, 2, 3")).fetch_pandas_all() + assert tuple(res) == ("1", "2", "3") + result_scan_res = await ( + await 
cur.execute(f"select * from table(result_scan('{cur.sfqid}'));") + ).fetch_pandas_all() + assert tuple(result_scan_res) == ("1", "2", "3") + + +@pytest.mark.parametrize("query_format", ("JSON", "ARROW")) +@pytest.mark.parametrize("resultscan_format", ("JSON", "ARROW")) +async def test_query_resultscan_combos(conn_cnx, query_format, resultscan_format): + if query_format == "JSON" and resultscan_format == "ARROW": + pytest.xfail("fix not yet released to test deployment") + async with conn_cnx() as cnx: + sfqid = None + results = None + scanned_results = None + async with cnx.cursor() as query_cur: + await query_cur.execute( + "alter session set python_connector_query_result_format='{}'".format( + query_format + ) + ) + await query_cur.execute( + "select seq8(), randstr(1000,random()) from table(generator(rowcount=>100))" + ) + sfqid = query_cur.sfqid + assert query_cur._query_result_format.upper() == query_format + if query_format == "JSON": + results = await query_cur.fetchall() + else: + results = await query_cur.fetch_pandas_all() + async with cnx.cursor() as resultscan_cur: + await resultscan_cur.execute( + "alter session set python_connector_query_result_format='{}'".format( + resultscan_format + ) + ) + await resultscan_cur.execute(f"select * from table(result_scan('{sfqid}'))") + if resultscan_format == "JSON": + scanned_results = await resultscan_cur.fetchall() + else: + scanned_results = await resultscan_cur.fetch_pandas_all() + assert resultscan_cur._query_result_format.upper() == resultscan_format + if isinstance(results, pandas.DataFrame): + results = [tuple(e) for e in results.values.tolist()] + if isinstance(scanned_results, pandas.DataFrame): + scanned_results = [tuple(e) for e in scanned_results.values.tolist()] + assert results == scanned_results + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + (False, numpy.float64), + pytest.param(True, decimal.Decimal, marks=pytest.mark.skipolddriver), + ], +) +async def 
test_number_fetchall_retrieve_type(conn_cnx, use_decimal, expected): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + result_df = await cur.fetch_pandas_all() + a_column = result_df["A"] + assert isinstance(a_column.values[0], expected), type(a_column.values[0]) + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + ( + False, + numpy.float64, + ), + pytest.param(True, decimal.Decimal, marks=pytest.mark.skipolddriver), + ], +) +async def test_number_fetchbatches_retrieve_type( + conn_cnx, use_decimal: bool, expected: type +): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + async for batch in await cur.fetch_pandas_batches(): + a_column = batch["A"] + assert isinstance(a_column.values[0], expected), type( + a_column.values[0] + ) + + +async def test_execute_async_and_fetch_pandas_batches(conn_cnx): + """Test get pandas in an asynchronous way""" + + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute("select 1/2") + res_sync = await cur.fetch_pandas_batches() + + result = await cur.execute_async("select 1/2") + await cur.get_results_from_sfqid(result["queryId"]) + res_async = await cur.fetch_pandas_batches() + + assert res_sync is not None + assert res_async is not None + while True: + try: + r_sync = await res_sync.__anext__() + r_async = await res_async.__anext__() + assert r_sync.values == r_async.values + except StopAsyncIteration: + break + + +async def test_execute_async_and_fetch_arrow_batches(conn_cnx): + """Test fetching result of an asynchronous query as batches of arrow tables""" + + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute("select 1/2") + res_sync = await cur.fetch_arrow_batches() + + result = await cur.execute_async("select 1/2") + await 
cur.get_results_from_sfqid(result["queryId"]) + res_async = await cur.fetch_arrow_batches() + + assert res_sync is not None + assert res_async is not None + while True: + try: + r_sync = await res_sync.__anext__() + r_async = await res_async.__anext__() + assert r_sync == r_async + except StopAsyncIteration: + break + + +async def test_simple_async_pandas(conn_cnx): + """Simple test to that shows the most simple usage of fire and forget. + + This test also makes sure that wait_until_ready function's sleeping is tested and + that some fields are copied over correctly from the original query. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetch_pandas_all()) == 1 + assert cur.rowcount + assert cur.description + + +async def test_simple_async_arrow(conn_cnx): + """Simple test for async fetch_arrow_all""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetch_arrow_all()) == 1 + assert cur.rowcount + assert cur.description + + +@pytest.mark.parametrize( + "use_decimal,expected", + [ + ( + True, + decimal.Decimal, + ), + pytest.param(False, numpy.float64, marks=pytest.mark.xfail), + ], +) +async def test_number_iter_retrieve_type(conn_cnx, use_decimal: bool, expected: type): + async with conn_cnx(arrow_number_to_decimal=use_decimal) as con: + async with con.cursor() as cur: + await cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a") + async for row in cur: + assert isinstance(row[0], expected), type(row[0]) + + +async def test_resultbatches_pandas_functionality(conn_cnx): + """Fetch ArrowResultBatches as pandas dataframes and check its result.""" + rowcount = 100000 + expected_df = pandas.DataFrame(data={"A": 
range(rowcount)}) + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"select seq4() a from table(generator(rowcount => {rowcount}));" + ) + assert cur._result_set.total_row_index() == rowcount + result_batches = await cur.get_result_batches() + assert (await cur.fetch_pandas_all()).index[-1] == rowcount - 1 + assert len(result_batches) > 1 + + iterables = [] + for b in result_batches: + iterables.append( + list(await b.create_iter(iter_unit=IterUnit.TABLE_UNIT, structure="arrow")) + ) + tables = itertools.chain.from_iterable(iterables) + final_df = pyarrow.concat_tables(tables).to_pandas() + assert numpy.array_equal(expected_df, final_df) + + +@pytest.mark.skip("SNOW-1617451 async telemetry support") +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing. or no new telemetry defined - skipolddrive", +) +@pytest.mark.parametrize( + "fetch_method, expected_telemetry_type", + [ + ("one", "client_fetch_pandas_all"), # TelemetryField.PANDAS_FETCH_ALL + ("batch", "client_fetch_pandas_batches"), # TelemetryField.PANDAS_FETCH_BATCHES + ], +) +async def test_pandas_telemetry( + conn_cnx, capture_sf_telemetry, fetch_method, expected_telemetry_type +): + cases = ["NULL", 0.11, -0.11, "NULL", 1.27, -1.28, "NULL"] + table = "test_telemetry" + column = "(a number(5,2))" + values = ",".join([f"({i}, {c})" for i, c in enumerate(cases)]) + async with conn_cnx() as conn, capture_sf_telemetry.patch_connection( + conn, False + ) as telemetry_test: + await init(conn, table, column, values) + sql_text = f"select a from {table} order by s" + + await validate_pandas( + conn, + sql_text, + cases, + 1, + fetch_method, + ) + + occurence = 0 + for t in telemetry_test.records: + if t.message["type"] == expected_telemetry_type: + occurence += 1 + assert occurence == 1 + + await finish(conn, table) + + +@pytest.mark.parametrize("result_format", ["pandas", "arrow"]) +async 
def test_batch_to_pandas_arrow(conn_cnx, result_format): + rowcount = 10 + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute(SQL_ENABLE_ARROW) + await cur.execute( + f"select seq4() as foo, seq4() as bar from table(generator(rowcount=>{rowcount})) order by foo asc" + ) + batches = await cur.get_result_batches() + assert len(batches) == 1 + batch = batches[0] + + # check that size, columns, and FOO column data is correct + if result_format == "pandas": + df = await batch.to_pandas() + assert type(df) is pandas.DataFrame + assert df.shape == (10, 2) + assert all(df.columns == ["FOO", "BAR"]) + assert list(df.FOO) == list(range(rowcount)) + elif result_format == "arrow": + arrow_table = await batch.to_arrow() + assert type(arrow_table) is pyarrow.Table + assert arrow_table.shape == (10, 2) + assert arrow_table.column_names == ["FOO", "BAR"] + assert arrow_table.to_pydict()["FOO"] == list(range(rowcount)) + + +@pytest.mark.internal +@pytest.mark.parametrize("enable_structured_types", [True, False]) +async def test_to_arrow_datatypes(enable_structured_types, conn_cnx): + expected_types = ( + pyarrow.int64(), + pyarrow.float64(), + pyarrow.string(), + pyarrow.date64(), + pyarrow.timestamp("ns"), + pyarrow.string(), + pyarrow.timestamp("ns"), + pyarrow.timestamp("ns"), + pyarrow.timestamp("ns"), + pyarrow.binary(), + pyarrow.time64("ns"), + pyarrow.bool_(), + pyarrow.string(), + pyarrow.string(), + pyarrow.list_(pyarrow.float64(), 5), + ) + + query = """ + select + 1 :: INTEGER as FIXED_type, + 2.0 :: FLOAT as REAL_type, + 'test' :: TEXT as TEXT_type, + '2024-02-28' :: DATE as DATE_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP as TIMESTAMP_type, + '{"foo": "bar"}' :: VARIANT as VARIANT_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_LTZ as TIMESTAMP_LTZ_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_TZ as TIMESTAMP_TZ_type, + '2020-03-12 01:02:03.123456789' :: TIMESTAMP_NTZ as TIMESTAMP_NTZ_type, + '0xAAAA' :: BINARY as 
BINARY_type, + '01:02:03.123456789' :: TIME as TIME_type, + true :: BOOLEAN as BOOLEAN_type, + TO_GEOGRAPHY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOGRAPHY_type, + TO_GEOMETRY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOMETRY_type, + [1,2,3,4,5] :: vector(float, 5) as VECTOR_type, + object_construct('k1', 1, 'k2', 2, 'k3', 3, 'k4', 4, 'k5', 5) :: map(varchar, int) as MAP_type, + object_construct('city', 'san jose', 'population', 0.05) :: object(city varchar, population float) as OBJECT_type, + [1.0, 3.1, 4.5] :: array(float) as ARRAY_type + WHERE 1=0 + """ + + structured_params = { + "ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE", + "IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE", + "FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", + } + + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute(SQL_ENABLE_ARROW) + try: + if enable_structured_types: + for param in structured_params: + await cur.execute(f"alter session set {param}=true") + expected_types += ( + pyarrow.map_(pyarrow.string(), pyarrow.int64()), + pyarrow.struct( + {"city": pyarrow.string(), "population": pyarrow.float64()} + ), + pyarrow.list_(pyarrow.float64()), + ) + else: + expected_types += ( + pyarrow.string(), + pyarrow.string(), + pyarrow.string(), + ) + # Ensure an empty batch to use default typing + # Otherwise arrow will resize types to save space + await cur.execute(query) + batches = cur.get_result_batches() + assert len(batches) == 1 + batch = batches[0] + arrow_table = batch.to_arrow() + for actual, expected in zip(arrow_table.schema, expected_types): + assert ( + actual.type == expected + ), f"Expected {actual.name} :: {actual.type} column to be of type {expected}" + finally: + if enable_structured_types: + for param in structured_params: + await cur.execute(f"alter session unset {param}") + + +async def test_simple_arrow_fetch(conn_cnx): + rowcount = 250_000 + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + 
await cur.execute(SQL_ENABLE_ARROW) + await cur.execute( + f"select seq4() as foo from table(generator(rowcount=>{rowcount})) order by foo asc" + ) + arrow_table = await cur.fetch_arrow_all() + assert arrow_table.shape == (rowcount, 1) + assert arrow_table.to_pydict()["FOO"] == list(range(rowcount)) + + await cur.execute( + f"select seq4() as foo from table(generator(rowcount=>{rowcount})) order by foo asc" + ) + assert ( + len(await cur.get_result_batches()) > 1 + ) # non-trivial number of batches + + # the start and end points of each batch + lo, hi = 0, 0 + async for table in await cur.fetch_arrow_batches(): + assert type(table) is pyarrow.Table # sanity type check + + # check that data is correct + length = len(table) + hi += length + assert table.to_pydict()["FOO"] == list(range(lo, hi)) + lo += length + + assert lo == rowcount + + +async def test_arrow_zero_rows(conn_cnx): + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute(SQL_ENABLE_ARROW) + await cur.execute("select 1::NUMBER(38,0) limit 0") + table = await cur.fetch_arrow_all(force_return_table=True) + # Snowflake will return an integer dtype with maximum bit-length if + # no rows are returned + assert table.schema[0].type == pyarrow.int64() + await cur.execute("select 1::NUMBER(38,0) limit 0") + # test default behavior + assert await cur.fetch_arrow_all(force_return_table=False) is None + + +@pytest.mark.parametrize("fetch_fn_name", ["to_arrow", "to_pandas", "create_iter"]) +@pytest.mark.parametrize("pass_connection", [True, False]) +async def test_sessions_used(conn_cnx, fetch_fn_name, pass_connection): + rowcount = 250_000 + async with conn_cnx() as cnx: + async with cnx.cursor() as cur: + await cur.execute(SQL_ENABLE_ARROW) + await cur.execute( + f"select seq1() from table(generator(rowcount=>{rowcount}))" + ) + batches = await cur.get_result_batches() + assert len(batches) > 1 + batch = batches[-1] + + connection = cnx if pass_connection else None + fetch_fn = 
getattr(batch, fetch_fn_name) + + # check that sessions are used when connection is supplied + with mock.patch( + "snowflake.connector.aio._network.SnowflakeRestful._use_requests_session", + side_effect=cnx._rest._use_requests_session, + ) as get_session_mock: + await fetch_fn(connection=connection) + assert get_session_mock.call_count == (1 if pass_connection else 0) + + +def assert_dtype_equal(a, b): + """Pandas method of asserting the same numpy dtype of variables by computing hash.""" + assert_equal(a, b) + assert_equal( + hash(a), hash(b), "two equivalent types do not hash to the same value !" + ) + + +def assert_pandas_batch_types( + batch: pandas.DataFrame, expected_types: list[type] +) -> None: + assert batch.dtypes is not None + + pandas_dtypes = batch.dtypes + # pd.string is represented as an np.object + # np.dtype string is not the same as pd.string (python) + for pandas_dtype, expected_type in zip(pandas_dtypes, expected_types): + assert_dtype_equal(pandas_dtype.type, numpy.dtype(expected_type).type) + + +async def test_pandas_dtypes(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute( + "select 1::integer, 2.3::double, 'foo'::string, current_timestamp()::timestamp where 1=0" + ) + expected_types = [numpy.int64, float, object, numpy.datetime64] + assert_pandas_batch_types(await cur.fetch_pandas_all(), expected_types) + + batches = await cur.get_result_batches() + assert await batches[0].to_arrow() is not True + assert_pandas_batch_types(await batches[0].to_pandas(), expected_types) + + +async def test_timestamp_tz(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute("select '1990-01-04 10:00:00 +1100'::timestamp_tz as d") + res = await cur.fetchall() + assert res[0][0].tzinfo is 
not None + res_pd = await cur.fetch_pandas_all() + assert res_pd.D.dt.tz is pytz.UTC + res_pa = await cur.fetch_arrow_all() + assert res_pa.field("D").type.tz == "UTC" + + +async def test_arrow_number_to_decimal(conn_cnx): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + }, + arrow_number_to_decimal=True, + ) as cnx: + async with cnx.cursor() as cur: + await cur.execute("select -3.20 as num") + df = await cur.fetch_pandas_all() + val = df.NUM[0] + assert val == Decimal("-3.20") + assert isinstance(val, decimal.Decimal) + + +@pytest.mark.parametrize( + "timestamp_type", + [ + "TIMESTAMP_TZ", + "TIMESTAMP_NTZ", + "TIMESTAMP_LTZ", + ], +) +async def test_time_interval_microsecond(conn_cnx, timestamp_type): + async with conn_cnx( + session_parameters={ + PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force" + } + ) as cnx: + async with cnx.cursor() as cur: + res = await ( + await cur.execute( + f"SELECT TO_{timestamp_type}('2010-06-25 12:15:30.747000')+INTERVAL '8999999999999998 MICROSECONDS'" + ) + ).fetchone() + assert res[0].microsecond == 746998 + res = await ( + await cur.execute( + f"SELECT TO_{timestamp_type}('2010-06-25 12:15:30.747000')+INTERVAL '8999999999999999 MICROSECONDS'" + ) + ).fetchone() + assert res[0].microsecond == 746999 + + +async def test_fetch_with_pandas_nullable_types(conn_cnx): + # use several float values to test nullable types. 
Nullable types can preserve both nan and null in float + sql_text = """ + select 1.0::float, 'NaN'::float, Null::float; + """ + # https://arrow.apache.org/docs/python/pandas.html#nullable-types + dtype_mapping = { + pyarrow.int8(): pandas.Int8Dtype(), + pyarrow.int16(): pandas.Int16Dtype(), + pyarrow.int32(): pandas.Int32Dtype(), + pyarrow.int64(): pandas.Int64Dtype(), + pyarrow.uint8(): pandas.UInt8Dtype(), + pyarrow.uint16(): pandas.UInt16Dtype(), + pyarrow.uint32(): pandas.UInt32Dtype(), + pyarrow.uint64(): pandas.UInt64Dtype(), + pyarrow.bool_(): pandas.BooleanDtype(), + pyarrow.float32(): pandas.Float32Dtype(), + pyarrow.float64(): pandas.Float64Dtype(), + pyarrow.string(): pandas.StringDtype(), + } + + expected_dtypes = pandas.Series( + [pandas.Float64Dtype(), pandas.Float64Dtype(), pandas.Float64Dtype()], + index=["1.0::FLOAT", "'NAN'::FLOAT", "NULL::FLOAT"], + ) + expected_df_to_string = """ 1.0::FLOAT 'NAN'::FLOAT NULL::FLOAT +0 1.0 NaN """ + async with conn_cnx() as cnx_table: + # fetch dataframe with new arrow support + cursor_table = cnx_table.cursor() + await cursor_table.execute(SQL_ENABLE_ARROW) + await cursor_table.execute(sql_text) + # test fetch_pandas_batches + async for df in await cursor_table.fetch_pandas_batches( + types_mapper=dtype_mapping.get + ): + pandas._testing.assert_series_equal(df.dtypes, expected_dtypes) + print(df) + assert df.to_string() == expected_df_to_string + # test fetch_pandas_all + df = await cursor_table.fetch_pandas_all(types_mapper=dtype_mapping.get) + pandas._testing.assert_series_equal(df.dtypes, expected_dtypes) + assert df.to_string() == expected_df_to_string diff --git a/test/integ/aio/pandas/test_logging_async.py b/test/integ/aio/pandas/test_logging_async.py new file mode 100644 index 000000000..9b35d11a8 --- /dev/null +++ b/test/integ/aio/pandas/test_logging_async.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from __future__ import annotations + +import logging + + +async def test_rand_table_log(caplog, conn_cnx, db_parameters): + async with conn_cnx() as conn: + caplog.set_level(logging.DEBUG, "snowflake.connector") + + num_of_rows = 10 + async with conn.cursor() as cur: + await ( + await cur.execute( + "select randstr(abs(mod(random(), 100)), random()) from table(generator(rowcount => {}));".format( + num_of_rows + ) + ) + ).fetchall() + + # make assertions + has_batch_read = has_batch_size = has_chunk_info = has_batch_index = False + for record in caplog.records: + if "Batches read:" in record.msg: + has_batch_read = True + assert "arrow_iterator" in record.filename + assert "__cinit__" in record.funcName + + if "Arrow BatchSize:" in record.msg: + has_batch_size = True + assert "CArrowIterator.cpp" in record.filename + assert "CArrowIterator" in record.funcName + + if "Arrow chunk info:" in record.msg: + has_chunk_info = True + assert "CArrowChunkIterator.cpp" in record.filename + assert "CArrowChunkIterator" in record.funcName + + if "Current batch index:" in record.msg: + has_batch_index = True + assert "CArrowChunkIterator.cpp" in record.filename + assert "next" in record.funcName + + # each of these records appear at least once in records + assert has_batch_read and has_batch_size and has_chunk_info and has_batch_index diff --git a/test/integ/aio/test_async_async.py b/test/integ/aio/test_async_async.py new file mode 100644 index 000000000..8dcdb936d --- /dev/null +++ b/test/integ/aio/test_async_async.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + + from __future__ import annotations + +import asyncio +import logging + +import pytest + +from snowflake.connector import DatabaseError, ProgrammingError +from snowflake.connector.constants import QueryStatus + +# Mark all tests in this file to time out after 2 minutes to prevent hanging forever +pytestmark = pytest.mark.timeout(120) + + +async def test_simple_async(conn_cnx): + """Simple test that shows the simplest usage of fire and forget. + + This test also makes sure that wait_until_ready function's sleeping is tested and + that some fields are copied over correctly from the original query. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + await cur.get_results_from_sfqid(cur.sfqid) + assert len(await cur.fetchall()) == 1 + assert cur.rowcount + assert cur.description + + +async def test_async_result_iteration(conn_cnx): + """Test yielding results of an async query. + + Ensures that wait_until_ready is also called in __iter__() via _prefetch_hook(). + """ + + async def result_generator(query): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async(query) + await cur.get_results_from_sfqid(cur.sfqid) + async for row in cur: + yield row + + gen = result_generator("select count(*) from table(generator(timeLimit => 5))") + assert await anext(gen) + with pytest.raises(StopAsyncIteration): + await anext(gen) + + +async def test_async_exec(conn_cnx): + """Tests whether simple async query execution works. + + Runs a query that takes a few seconds to finish and then totally closes connection + to Snowflake. Then waits enough time for that query to finish, opens a new connection + and fetches results. It also tests QueryStatus related functionality. + + This test tends to hang longer than expected when the testing warehouse is overloaded. 
+ Manually looking at query history reveals that when a full GH actions + Jenkins test load hits one warehouse + it can be queued for 15 seconds, so for now we wait 5 seconds before checking and then we give it another 25 + seconds to finish. + """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + q_id = cur.sfqid + status = await con.get_query_status(q_id) + assert con.is_still_running(status) + await asyncio.sleep(5) + async with conn_cnx() as con: + async with con.cursor() as cur: + for _ in range(25): + # Check up to 25 times, once a second, to see if it's done + status = await con.get_query_status(q_id) + if status == QueryStatus.SUCCESS: + break + await asyncio.sleep(1) + else: + pytest.fail( + f"We should have broke out of this loop, final query status: {status}" + ) + status = await con.get_query_status_throw_if_error(q_id) + assert status == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(q_id) + assert len(await cur.fetchall()) == 1 + + +async def test_async_error(conn_cnx, caplog): + """Tests whether simple async query error retrieval works. + + Runs a query that will fail to execute and then tests that if we tried to get results for the query + then that would raise an exception. It also tests QueryStatus related functionality. 
+ """ + async with conn_cnx() as con: + async with con.cursor() as cur: + sql = "select * from nonexistentTable" + await cur.execute_async(sql) + q_id = cur.sfqid + with pytest.raises(ProgrammingError) as sync_error: + await cur.execute(sql) + while con.is_still_running(await con.get_query_status(q_id)): + await asyncio.sleep(1) + status = await con.get_query_status(q_id) + assert status == QueryStatus.FAILED_WITH_ERROR + assert con.is_an_error(status) + with pytest.raises(ProgrammingError) as e1: + await con.get_query_status_throw_if_error(q_id) + assert sync_error.value.errno != -1 + with pytest.raises(ProgrammingError) as e2: + await cur.get_results_from_sfqid(q_id) + assert e1.value.errno == e2.value.errno == sync_error.value.errno + + sfqid = (await cur.execute_async("SELECT SYSTEM$WAIT(2)"))["queryId"] + await cur.get_results_from_sfqid(sfqid) + async with con.cursor() as cancel_cursor: + # use separate cursor to cancel as execute will overwrite the previous query status + await cancel_cursor.execute(f"SELECT SYSTEM$CANCEL_QUERY('{sfqid}')") + with pytest.raises(DatabaseError) as e3, caplog.at_level(logging.INFO): + await cur.fetchall() + assert ( + "SQL execution canceled" in e3.value.msg + and f"Status of query '{sfqid}' is {QueryStatus.FAILED_WITH_ERROR.name}" + in caplog.text + ) + + +async def test_mix_sync_async(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + # Setup + await cur.execute( + "alter session set CLIENT_TIMESTAMP_TYPE_MAPPING=TIMESTAMP_TZ" + ) + try: + for table in ["smallTable", "uselessTable"]: + await cur.execute( + "create or replace table {} (colA string, colB int)".format( + table + ) + ) + await cur.execute( + "insert into {} values ('row1', 1), ('row2', 2), ('row3', 3)".format( + table + ) + ) + await cur.execute_async("select * from smallTable") + sf_qid1 = cur.sfqid + await cur.execute_async("select * from uselessTable") + sf_qid2 = cur.sfqid + # Wait until the 2 queries finish + while 
con.is_still_running(await con.get_query_status(sf_qid1)): + await asyncio.sleep(1) + while con.is_still_running(await con.get_query_status(sf_qid2)): + await asyncio.sleep(1) + await cur.execute("drop table uselessTable") + assert await cur.fetchall() == [("USELESSTABLE successfully dropped.",)] + await cur.get_results_from_sfqid(sf_qid1) + assert await cur.fetchall() == [("row1", 1), ("row2", 2), ("row3", 3)] + await cur.get_results_from_sfqid(sf_qid2) + assert await cur.fetchall() == [("row1", 1), ("row2", 2), ("row3", 3)] + finally: + for table in ["smallTable", "uselessTable"]: + await cur.execute(f"drop table if exists {table}") + + +async def test_async_qmark(conn_cnx): + """Tests that qmark parameter binding works with async queries.""" + import snowflake.connector + + orig_format = snowflake.connector.paramstyle + snowflake.connector.paramstyle = "qmark" + try: + async with conn_cnx() as con: + async with con.cursor() as cur: + try: + await cur.execute( + "create or replace table qmark_test (aa STRING, bb STRING)" + ) + await cur.execute( + "insert into qmark_test VALUES(?, ?)", ("test11", "test12") + ) + await cur.execute_async("select * from qmark_test") + async_qid = cur.sfqid + async with conn_cnx() as con2: + async with con2.cursor() as cur2: + await cur2.get_results_from_sfqid(async_qid) + assert await cur2.fetchall() == [("test11", "test12")] + finally: + await cur.execute("drop table if exists qmark_test") + finally: + snowflake.connector.paramstyle = orig_format + + +async def test_done_caching(conn_cnx): + """Tests whether get status caching is working as expected.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 5))" + ) + qid1 = cur.sfqid + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 10))" + ) + qid2 = cur.sfqid + assert len(con._async_sfqids) == 2 + await asyncio.sleep(5) + while con.is_still_running(await 
con.get_query_status(qid1)): + await asyncio.sleep(1) + assert await con.get_query_status(qid1) == QueryStatus.SUCCESS + assert len(con._async_sfqids) == 1 + assert len(con._done_async_sfqids) == 1 + await asyncio.sleep(5) + while con.is_still_running(await con.get_query_status(qid2)): + await asyncio.sleep(1) + assert await con.get_query_status(qid2) == QueryStatus.SUCCESS + assert len(con._async_sfqids) == 0 + assert len(con._done_async_sfqids) == 2 + assert await con._all_async_queries_finished() + + +async def test_invalid_uuid_get_status(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + with pytest.raises( + ValueError, match=r"Invalid UUID: 'doesnt exist, dont even look'" + ): + await cur.get_results_from_sfqid("doesnt exist, dont even look") + + +async def test_unknown_sfqid(conn_cnx): + """Tests that no exception is thrown when we attempt to get the status of a non-existent query (NO_DATA is returned instead).""" + async with conn_cnx() as con: + assert ( + await con.get_query_status("12345678-1234-4123-A123-123456789012") + == QueryStatus.NO_DATA + ) + + +async def test_unknown_sfqid_results(conn_cnx): + """Tests that no exception is thrown when we attempt to get the results of a non-existent query.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.get_results_from_sfqid("12345678-1234-4123-A123-123456789012") + + +async def test_not_fetching(conn_cnx): + """Tests whether executing a new query actually cleans up after an async result retrieval. + + If someone tries to retrieve results then the first fetch would have to block. We should not block + if we executed a new query. 
+ """ + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async("select 1") + sf_qid = cur.sfqid + await cur.get_results_from_sfqid(sf_qid) + await cur.execute("select 2") + assert cur._inner_cursor is None + assert cur._prefetch_hook is None + + +async def test_close_connection_with_running_async_queries(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 10))" + ) + await cur.execute_async( + "select count(*) from table(generator(timeLimit => 1))" + ) + assert not (await con._all_async_queries_finished()) + assert len(con._done_async_sfqids) < 2 and con.rest is None + + +async def test_close_connection_with_completed_async_queries(conn_cnx): + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async("select 1") + qid1 = cur.sfqid + await cur.execute_async("select 2") + qid2 = cur.sfqid + while con.is_still_running( + (await con._get_query_status(qid1))[0] + ): # use _get_query_status to avoid caching + await asyncio.sleep(1) + while con.is_still_running((await con._get_query_status(qid2))[0]): + await asyncio.sleep(1) + assert await con._all_async_queries_finished() + assert len(con._done_async_sfqids) == 2 and con.rest is None diff --git a/test/integ/aio/test_converter_async.py b/test/integ/aio/test_converter_async.py index a1f5f8c9f..4ab921672 100644 --- a/test/integ/aio/test_converter_async.py +++ b/test/integ/aio/test_converter_async.py @@ -353,7 +353,7 @@ async def test_date_0001_9999(conn_cnx): async with conn_cnx( converter_class=SnowflakeConverterSnowSQL, support_negative_year=True ) as cnx: - cnx.cursor().execute( + await cnx.cursor().execute( """ ALTER SESSION SET DATE_OUTPUT_FORMAT='YYYY-MM-DD' @@ -388,7 +388,7 @@ async def test_five_or_more_digit_year_date_converter(conn_cnx): async with conn_cnx( converter_class=SnowflakeConverterSnowSQL, support_negative_year=True ) as cnx: - 
cnx.cursor().execute( + await cnx.cursor().execute( """ ALTER SESSION SET DATE_OUTPUT_FORMAT='YYYY-MM-DD' diff --git a/test/integ/aio/test_cursor_async.py b/test/integ/aio/test_cursor_async.py index 674c63599..56b6de936 100644 --- a/test/integ/aio/test_cursor_async.py +++ b/test/integ/aio/test_cursor_async.py @@ -417,7 +417,7 @@ async def test_struct_time(conn, db_parameters): async for rec in c: cnt += int(rec[0]) finally: - c.close() + await c.close() os.environ["TZ"] = "UTC" if not IS_WINDOWS: time.tzset() @@ -510,7 +510,7 @@ async def test_insert_binary_select_with_bytearray(conn, db_parameters): assert count == 1, "wrong number of records were inserted" assert c.rowcount == 1, "wrong number of records were selected" finally: - c.close() + await c.close() cnx2 = snowflake.connector.aio.SnowflakeConnection( user=db_parameters["user"], diff --git a/test/integ/aio/test_multi_statement_async.py b/test/integ/aio/test_multi_statement_async.py new file mode 100644 index 000000000..0968a4256 --- /dev/null +++ b/test/integ/aio/test_multi_statement_async.py @@ -0,0 +1,398 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# + +from test.helpers import ( + _wait_until_query_success_async, + _wait_while_query_running_async, +) + +import pytest + +from snowflake.connector import ProgrammingError, errors +from snowflake.connector.aio import SnowflakeCursor +from snowflake.connector.constants import PARAMETER_MULTI_STATEMENT_COUNT, QueryStatus +from snowflake.connector.util_text import random_string + + +@pytest.fixture(scope="module", params=[False, True]) +def skip_to_last_set(request) -> bool: + return request.param + + +async def test_multi_statement_wrong_count(conn_cnx): + """Tries to send the wrong number of statements.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 1}) as con: + async with con.cursor() as cur: + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 2 did not match the desired statement count 1.", + ): + await cur.execute("select 1; select 2") + + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 2 did not match the desired statement count 1.", + ): + await cur.execute( + "alter session set MULTI_STATEMENT_COUNT=2; select 1;" + ) + + await cur.execute("alter session set MULTI_STATEMENT_COUNT=5") + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 1 did not match the desired statement count 5.", + ): + await cur.execute("select 1;") + + with pytest.raises( + errors.ProgrammingError, + match="Actual statement count 3 did not match the desired statement count 5.", + ): + await cur.execute("select 1; select 2; select 3;") + + +async def _check_multi_statement_results( + cur: SnowflakeCursor, + checks: "list[list[tuple] | function]", + skip_to_last_set: bool, +) -> None: + savedIds = [] + for index, check in enumerate(checks): + if not skip_to_last_set or index == len(checks) - 1: + if callable(check): + assert check(await cur.fetchall()) + else: + assert await cur.fetchall() == check + savedIds.append(cur.sfqid) + assert await cur.nextset() == (cur if 
index < len(checks) - 1 else None) + assert await cur.fetchall() == [] + + assert cur.multi_statement_savedIds[-1 if skip_to_last_set else 0 :] == savedIds + + +async def test_multi_statement_basic(conn_cnx, skip_to_last_set: bool): + """Selects fixed integer data using statement level parameters.""" + async with conn_cnx() as con: + async with con.cursor() as cur: + statement_params = dict() + await cur.execute( + "select 1; select 2; select 'a';", + num_statements=3, + _statement_params=statement_params, + ) + await _check_multi_statement_results( + cur, + checks=[ + [(1,)], + [(2,)], + [("a",)], + ], + skip_to_last_set=skip_to_last_set, + ) + assert len(statement_params) == 0 + + +async def test_insert_select_multi(conn_cnx, db_parameters, skip_to_last_set: bool): + """Naive use of multi-statement to check multiple SQL functions.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + table_name = random_string(5, "test_multi_table_").upper() + await cur.execute( + "use schema {db}.{schema};\n" + "create table {name} (aa int);\n" + "insert into {name}(aa) values(123456),(98765),(65432);\n" + "select aa from {name} order by aa;\n" + "drop table {name};".format( + db=db_parameters["database"], + schema=( + db_parameters["schema"] + if "schema" in db_parameters + else "PUBLIC" + ), + name=table_name, + ) + ) + await _check_multi_statement_results( + cur, + checks=[ + [("Statement executed successfully.",)], + [(f"Table {table_name} successfully created.",)], + [(3,)], + [(65432,), (98765,), (123456,)], + [(f"{table_name} successfully dropped.",)], + ], + skip_to_last_set=skip_to_last_set, + ) + + +@pytest.mark.parametrize("style", ["pyformat", "qmark"]) +async def test_binding_multi(conn_cnx, style: str, skip_to_last_set: bool): + """Tests using pyformat and qmark style bindings with multi-statement""" + test_string = "select {s}; select {s}, {s}; select {s}, {s}, {s};" + async with 
conn_cnx(paramstyle=style) as con: + async with con.cursor() as cur: + sql = test_string.format(s="%s" if style == "pyformat" else "?") + await cur.execute(sql, (10, 20, 30, "a", "b", "c"), num_statements=3) + await _check_multi_statement_results( + cur, + checks=[[(10,)], [(20, 30)], [("a", "b", "c")]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_async_exec_multi(conn_cnx, skip_to_last_set: bool): + """Tests whether async execution query works within a multi-statement""" + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute_async( + "select 1; select 2; select count(*) from table(generator(timeLimit => 1)); select 'b';", + num_statements=4, + ) + q_id = cur.sfqid + assert con.is_still_running(await con.get_query_status(q_id)) + await _wait_while_query_running_async(con, q_id, sleep_time=1) + async with conn_cnx() as con: + async with con.cursor() as cur: + await _wait_until_query_success_async( + con, q_id, num_checks=3, sleep_per_check=1 + ) + assert ( + await con.get_query_status_throw_if_error(q_id) == QueryStatus.SUCCESS + ) + + await cur.get_results_from_sfqid(q_id) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [(2,)], lambda x: x > [(0,)], [("b",)]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_async_error_multi(conn_cnx): + """ + Runs a query that will fail to execute and then tests that if we tried to get results for the query + then that would raise an exception. It also tests QueryStatus related functionality too. 
+ """ + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + sql = "select 1; select * from nonexistentTable" + q_id = (await cur.execute_async(sql)).get("queryId") + with pytest.raises( + ProgrammingError, + match="SQL compilation error:\nObject 'NONEXISTENTTABLE' does not exist or not authorized.", + ) as sync_error: + await cur.execute(sql) + await _wait_while_query_running_async(con, q_id, sleep_time=1) + assert await con.get_query_status(q_id) == QueryStatus.FAILED_WITH_ERROR + with pytest.raises(ProgrammingError) as e1: + await con.get_query_status_throw_if_error(q_id) + assert sync_error.value.errno != -1 + with pytest.raises(ProgrammingError) as e2: + await cur.get_results_from_sfqid(q_id) + assert e1.value.errno == e2.value.errno == sync_error.value.errno + + +async def test_mix_sync_async_multi(conn_cnx, skip_to_last_set: bool): + """Tests sending multiple multi-statement async queries at the same time.""" + async with conn_cnx( + session_parameters={ + PARAMETER_MULTI_STATEMENT_COUNT: 0, + "CLIENT_TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_TZ", + } + ) as con: + async with con.cursor() as cur: + await cur.execute( + "create or replace temp table smallTable (colA string, colB int);" + "create or replace temp table uselessTable (colA string, colB int);" + ) + for table in ["smallTable", "uselessTable"]: + await cur.execute( + f"insert into {table} values('row1', 1);" + f"insert into {table} values('row2', 2);" + f"insert into {table} values('row3', 3);" + ) + await cur.execute_async("select 1; select 'a'; select * from smallTable;") + sf_qid1 = cur.sfqid + await cur.execute_async("select 2; select 'b'; select * from uselessTable") + sf_qid2 = cur.sfqid + # Wait until the 2 queries finish + await _wait_while_query_running_async(con, sf_qid1, sleep_time=1) + await _wait_while_query_running_async(con, sf_qid2, sleep_time=1) + await cur.execute("drop table uselessTable") + assert await cur.fetchall() 
== [("USELESSTABLE successfully dropped.",)] + await cur.get_results_from_sfqid(sf_qid1) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [("a",)], [("row1", 1), ("row2", 2), ("row3", 3)]], + skip_to_last_set=skip_to_last_set, + ) + await cur.get_results_from_sfqid(sf_qid2) + await _check_multi_statement_results( + cur, + checks=[[(2,)], [("b",)], [("row1", 1), ("row2", 2), ("row3", 3)]], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_done_caching_multi(conn_cnx, skip_to_last_set: bool): + """Tests whether get status caching is working as expected.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + await cur.execute_async( + "select 1; select 'a'; select count(*) from table(generator(timeLimit => 2));" + ) + qid1 = cur.sfqid + await cur.execute_async( + "select 2; select 'b'; select count(*) from table(generator(timeLimit => 2));" + ) + qid2 = cur.sfqid + assert len(con._async_sfqids) == 2 + await _wait_while_query_running_async(con, qid1, sleep_time=1) + await _wait_until_query_success_async( + con, qid1, num_checks=3, sleep_per_check=1 + ) + assert await con.get_query_status(qid1) == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(qid1) + await _check_multi_statement_results( + cur, + checks=[[(1,)], [("a",)], lambda x: x > [(0,)]], + skip_to_last_set=skip_to_last_set, + ) + assert len(con._async_sfqids) == 1 + assert len(con._done_async_sfqids) == 1 + await _wait_while_query_running_async(con, qid2, sleep_time=1) + await _wait_until_query_success_async( + con, qid2, num_checks=3, sleep_per_check=1 + ) + assert await con.get_query_status(qid2) == QueryStatus.SUCCESS + await cur.get_results_from_sfqid(qid2) + await _check_multi_statement_results( + cur, + checks=[[(2,)], [("b",)], lambda x: x > [(0,)]], + skip_to_last_set=skip_to_last_set, + ) + assert len(con._async_sfqids) == 0 + assert len(con._done_async_sfqids) == 2 + assert await 
con._all_async_queries_finished() + + +async def test_alter_session_multi(conn_cnx): + """Tests whether multiple alter session queries are detected and stored in the connection.""" + async with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as con: + async with con.cursor() as cur: + sql = ( + "select 1;" + "alter session set autocommit=false;" + "select 'a';" + "alter session set json_indent = 4;" + "alter session set CLIENT_TIMESTAMP_TYPE_MAPPING = 'TIMESTAMP_TZ'" + ) + await cur.execute(sql) + assert con.converter.get_parameter("AUTOCOMMIT") == "false" + assert con.converter.get_parameter("JSON_INDENT") == "4" + assert ( + con.converter.get_parameter("CLIENT_TIMESTAMP_TYPE_MAPPING") + == "TIMESTAMP_TZ" + ) + + +async def test_executemany_multi(conn_cnx, skip_to_last_set: bool): + """Tests executemany with multi-statement optimizations enabled through the num_statements parameter.""" + table1 = random_string(5, "test_executemany_multi_") + table2 = random_string(5, "test_executemany_multi_") + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table {table1} (aa number); create temp table {table2} (bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(%(value1)s); insert into {table2}(bb) values(%(value2)s);", + [ + {"value1": 1234, "value2": 4}, + {"value1": 234, "value2": 34}, + {"value1": 34, "value2": 234}, + {"value1": 4, "value2": 1234}, + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[[(1234,), (234,), (34,), (4,)], [(4,), (34,), (234,), (1234,)]], + skip_to_last_set=skip_to_last_set, + ) + + async with conn_cnx() as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table 
{table1} (aa number); create temp table {table2} (bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(%s); insert into {table2}(bb) values(%s);", + [ + (12345, 4), + (1234, 34), + (234, 234), + (34, 1234), + (4, 12345), + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[ + [(12345,), (1234,), (234,), (34,), (4,)], + [(4,), (34,), (234,), (1234,), (12345,)], + ], + skip_to_last_set=skip_to_last_set, + ) + + +async def test_executmany_qmark_multi(conn_cnx, skip_to_last_set: bool): + """Tests executemany with multi-statement optimization with qmark style.""" + table1 = random_string(5, "test_executemany_qmark_multi_") + table2 = random_string(5, "test_executemany_qmark_multi_") + async with conn_cnx(paramstyle="qmark") as con: + async with con.cursor() as cur: + await cur.execute( + f"create temp table {table1}(aa number); create temp table {table2}(bb number);", + num_statements=2, + ) + await cur.executemany( + f"insert into {table1}(aa) values(?); insert into {table2}(bb) values(?);", + [ + [1234, 4], + [234, 34], + [34, 234], + [4, 1234], + ], + num_statements=2, + ) + assert (await cur.fetchone())[0] == 1 + while await cur.nextset(): + assert (await cur.fetchone())[0] == 1 + await cur.execute( + f"select aa from {table1}; select bb from {table2};", num_statements=2 + ) + await _check_multi_statement_results( + cur, + checks=[ + [(1234,), (234,), (34,), (4,)], + [(4,), (34,), (234,), (1234,)], + ], + skip_to_last_set=skip_to_last_set, + ) diff --git a/test/unit/aio/test_connection_async_unit.py b/test/unit/aio/test_connection_async_unit.py index 36bbf159b..86a5fd89d 100644 --- a/test/unit/aio/test_connection_async_unit.py +++ b/test/unit/aio/test_connection_async_unit.py @@ -233,8 
+233,8 @@ async def test_negative_custom_auth(auth_class): async def test_missing_default_connection(monkeypatch, tmp_path): - connections_file = tmp_path / "connections.toml" - config_file = tmp_path / "config.toml" + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" with monkeypatch.context() as m: m.delenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME", raising=False) m.delenv("SNOWFLAKE_CONNECTIONS", raising=False) @@ -252,8 +252,8 @@ async def test_missing_default_connection(monkeypatch, tmp_path): async def test_missing_default_connection_conf_file(monkeypatch, tmp_path): connection_name = random_string(5) - connections_file = tmp_path / "connections.toml" - config_file = tmp_path / "config.toml" + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" config_file.write_text( dedent( f"""\ @@ -278,8 +278,8 @@ async def test_missing_default_connection_conf_file(monkeypatch, tmp_path): async def test_missing_default_connection_conn_file(monkeypatch, tmp_path): - connections_file = tmp_path / "connections.toml" - config_file = tmp_path / "config.toml" + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" connections_file.write_text( dedent( """\ @@ -308,8 +308,8 @@ async def test_missing_default_connection_conn_file(monkeypatch, tmp_path): async def test_missing_default_connection_conf_conn_file(monkeypatch, tmp_path): connection_name = random_string(5) - connections_file = tmp_path / "connections.toml" - config_file = tmp_path / "config.toml" + connections_file = tmp_path / "aio_connections.toml" + config_file = tmp_path / "aio_config.toml" config_file.write_text( dedent( f"""\ @@ -381,7 +381,7 @@ async def test_handle_timeout(mockSessionRequest, next_action): @pytest.mark.skip("SNOW-1572226 authentication support") async def test_private_key_file_reading(tmp_path: Path): - key_file = tmp_path / "key.pem" + key_file = tmp_path / 
"aio_key.pem" private_key = rsa.generate_private_key( backend=default_backend(), public_exponent=65537, key_size=2048 @@ -422,7 +422,7 @@ async def test_private_key_file_reading(tmp_path: Path): @pytest.mark.skip("SNOW-1572226 authentication support") async def test_encrypted_private_key_file_reading(tmp_path: Path): - key_file = tmp_path / "key.pem" + key_file = tmp_path / "aio_key.pem" private_key_password = token_urlsafe(25) private_key = rsa.generate_private_key( backend=default_backend(), public_exponent=65537, key_size=2048 diff --git a/test/unit/aio/test_cursor_async_unit.py b/test/unit/aio/test_cursor_async_unit.py new file mode 100644 index 000000000..ec2363573 --- /dev/null +++ b/test/unit/aio/test_cursor_async_unit.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import asyncio +import unittest.mock +from unittest.mock import MagicMock, patch + +import pytest + +from snowflake.connector.aio import SnowflakeConnection, SnowflakeCursor +from snowflake.connector.errors import ServiceUnavailableError + +try: + from snowflake.connector.constants import FileTransferType +except ImportError: + from enum import Enum + + class FileTransferType(Enum): + GET = "get" + PUT = "put" + + +class FakeConnection(SnowflakeConnection): + def __init__(self): + self._log_max_query_length = 0 + self._reuse_results = None + + +@pytest.mark.parametrize( + "sql,_type", + ( + ("", None), + ("select 1;", None), + ("PUT file:///tmp/data/mydata.csv @my_int_stage;", FileTransferType.PUT), + ("GET @%mytable file:///tmp/data/;", FileTransferType.GET), + ("/**/PUT file:///tmp/data/mydata.csv @my_int_stage;", FileTransferType.PUT), + ("/**/ GET @%mytable file:///tmp/data/;", FileTransferType.GET), + pytest.param( + "/**/\n" + + "\t/*/get\t*/\t/**/\n" * 10000 + + "\t*/get @~/test.csv file:///tmp\n", + None, + id="long_incorrect", + ), + pytest.param( + "/**/\n" + "\t/*/put\t*/\t/**/\n" * 10000 + "put 
file:///tmp/data.csv @~", + FileTransferType.PUT, + id="long_correct", + ), + ), +) +def test_get_filetransfer_type(sql, _type): + assert SnowflakeCursor.get_file_transfer_type(sql) == _type + + +def test_cursor_attribute(): + fake_conn = FakeConnection() + cursor = SnowflakeCursor(fake_conn) + assert cursor.lastrowid is None + + +@patch("snowflake.connector.aio._cursor.SnowflakeCursor._SnowflakeCursor__cancel_query") +async def test_cursor_execute_timeout(mockCancelQuery): + async def mock_cmd_query(*args, **kwargs): + await asyncio.sleep(10) + raise ServiceUnavailableError() + + fake_conn = FakeConnection() + fake_conn.cmd_query = mock_cmd_query + fake_conn._rest = unittest.mock.AsyncMock() + fake_conn._paramstyle = MagicMock() + fake_conn._next_sequence_counter = unittest.mock.AsyncMock() + + cursor = SnowflakeCursor(fake_conn) + + with pytest.raises(ServiceUnavailableError): + await cursor.execute( + command="SELECT * FROM nonexistent", + timeout=1, + ) + + # query cancel request should be sent upon timeout + assert mockCancelQuery.called diff --git a/test/unit/aio/test_ocsp.py b/test/unit/aio/test_ocsp.py index 8de8f641a..90cbcc3cb 100644 --- a/test/unit/aio/test_ocsp.py +++ b/test/unit/aio/test_ocsp.py @@ -14,6 +14,7 @@ import platform import ssl import time +from contextlib import asynccontextmanager from os import environ, path from unittest import mock @@ -70,16 +71,18 @@ def overwrite_ocsp_cache(tmpdir): THIS_DIR = path.dirname(path.realpath(__file__)) +@asynccontextmanager async def _asyncio_connect(url, timeout=5): loop = asyncio.get_event_loop() - _, protocol = await loop.create_connection( + transport, protocol = await loop.create_connection( functools.partial(aiohttp.client_proto.ResponseHandler, loop), host=url, port=443, ssl=ssl.create_default_context(), ssl_handshake_timeout=timeout, ) - return protocol + yield protocol + transport.close() @pytest.fixture(autouse=True) @@ -123,8 +126,8 @@ async def test_ocsp(): SnowflakeOCSP.clear_cache() ocsp = 
SFOCSP() for url in TARGET_HOSTS: - connection = await _asyncio_connect(url, timeout=5) - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" async def test_ocsp_wo_cache_server(): @@ -132,8 +135,8 @@ async def test_ocsp_wo_cache_server(): SnowflakeOCSP.clear_cache() ocsp = SFOCSP(use_ocsp_cache_server=False) for url in TARGET_HOSTS: - connection = await _asyncio_connect(url, timeout=5) - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" async def test_ocsp_wo_cache_file(): @@ -151,8 +154,10 @@ async def test_ocsp_wo_cache_file(): try: ocsp = SFOCSP() for url in TARGET_HOSTS: - connection = await _asyncio_connect(url, timeout=5) - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async with _asyncio_connect(url, timeout=5) as connection: + assert await ocsp.validate( + url, connection + ), f"Failed to validate: {url}" finally: del environ["SF_OCSP_RESPONSE_CACHE_DIR"] OCSPCache.reset_cache_dir() @@ -169,12 +174,11 @@ async def test_ocsp_fail_open_w_single_endpoint(): ocsp = SFOCSP(use_ocsp_cache_server=False) - connection = await _asyncio_connect("snowflake.okta.com") - try: - assert await ocsp.validate( - "snowflake.okta.com", connection - ), "Failed to validate: {}".format("snowflake.okta.com") + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Failed to validate: {}".format("snowflake.okta.com") finally: del environ["SF_OCSP_TEST_MODE"] del environ["SF_TEST_OCSP_URL"] @@ -195,10 +199,10 @@ async def test_ocsp_fail_close_w_single_endpoint(): OCSPCache.del_cache_file() ocsp = SFOCSP(use_ocsp_cache_server=False, use_fail_open=False) - 
connection = await _asyncio_connect("snowflake.okta.com") with pytest.raises(RevocationCheckError) as ex: - await ocsp.validate("snowflake.okta.com", connection) + async with _asyncio_connect("snowflake.okta.com") as connection: + await ocsp.validate("snowflake.okta.com", connection) try: assert ( @@ -219,11 +223,11 @@ async def test_ocsp_bad_validity(): OCSPCache.del_cache_file() ocsp = SFOCSP(use_ocsp_cache_server=False) - connection = await _asyncio_connect("snowflake.okta.com") + async with _asyncio_connect("snowflake.okta.com") as connection: - assert await ocsp.validate( - "snowflake.okta.com", connection - ), "Connection should have passed with fail open" + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Connection should have passed with fail open" del environ["SF_OCSP_TEST_MODE"] del environ["SF_TEST_OCSP_FORCE_BAD_RESPONSE_VALIDITY"] @@ -233,10 +237,10 @@ async def test_ocsp_single_endpoint(): SnowflakeOCSP.clear_cache() ocsp = SFOCSP() ocsp.OCSP_CACHE_SERVER.NEW_DEFAULT_CACHE_SERVER_BASE_URL = "https://snowflake.preprod3.us-west-2-dev.external-zone.snowflakecomputing.com:8085/ocsp/" - connection = await _asyncio_connect("snowflake.okta.com") - assert await ocsp.validate( - "snowflake.okta.com", connection - ), "Failed to validate: {}".format("snowflake.okta.com") + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + "snowflake.okta.com", connection + ), "Failed to validate: {}".format("snowflake.okta.com") del environ["SF_OCSP_ACTIVATE_NEW_ENDPOINT"] @@ -247,8 +251,8 @@ async def test_ocsp_by_post_method(): SnowflakeOCSP.clear_cache() ocsp = SFOCSP(use_post_method=True) for url in TARGET_HOSTS: - connection = await _asyncio_connect("snowflake.okta.com") - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" async def 
test_ocsp_with_file_cache(tmpdir): @@ -260,8 +264,8 @@ async def test_ocsp_with_file_cache(tmpdir): SnowflakeOCSP.clear_cache() ocsp = SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) for url in TARGET_HOSTS: - connection = await _asyncio_connect("snowflake.okta.com") - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" async def test_ocsp_with_bogus_cache_files( @@ -298,10 +302,10 @@ async def test_ocsp_with_bogus_cache_files( SnowflakeOCSP.clear_cache() ocsp = SFOCSP() for hostname in target_hosts: - connection = await _asyncio_connect("snowflake.okta.com") - assert await ocsp.validate( - hostname, connection - ), f"Failed to validate: {hostname}" + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + hostname, connection + ), f"Failed to validate: {hostname}" async def test_ocsp_with_outdated_cache(tmpdir, random_ocsp_response_validation_cache): @@ -355,10 +359,10 @@ async def _store_cache_in_file(tmpdir, target_hosts=None): ocsp_response_cache_uri="file://" + filename, use_ocsp_cache_server=False ) for hostname in target_hosts: - connection = await _asyncio_connect("snowflake.okta.com") - assert await ocsp.validate( - hostname, connection - ), f"Failed to validate: {hostname}" + async with _asyncio_connect("snowflake.okta.com") as connection: + assert await ocsp.validate( + hostname, connection + ), f"Failed to validate: {hostname}" assert path.exists(filename), "OCSP response cache file" return filename, target_hosts @@ -368,8 +372,8 @@ async def test_ocsp_with_invalid_cache_file(): SnowflakeOCSP.clear_cache() # reset the memory cache ocsp = SFOCSP(ocsp_response_cache_uri="NEVER_EXISTS") for url in TARGET_HOSTS[0:1]: - connection = await _asyncio_connect(url) - assert await ocsp.validate(url, connection), f"Failed to validate: {url}" + async 
with _asyncio_connect(url) as connection: + assert await ocsp.validate(url, connection), f"Failed to validate: {url}" @mock.patch( @@ -432,6 +436,6 @@ async def _validate_certs_using_ocsp(url, cache_file_name): except OSError: pass - connection = await _asyncio_connect(url) - ocsp = SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) - await ocsp.validate(url, connection) + async with _asyncio_connect(url) as connection: + ocsp = SFOCSP(ocsp_response_cache_uri="file://" + cache_file_name) + await ocsp.validate(url, connection) diff --git a/test/unit/aio/test_renew_session_async.py b/test/unit/aio/test_renew_session_async.py new file mode 100644 index 000000000..205bbcac3 --- /dev/null +++ b/test/unit/aio/test_renew_session_async.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +import logging +from test.unit.mock_utils import mock_connection +from unittest.mock import Mock, PropertyMock + +from snowflake.connector.aio._network import SnowflakeRestful + + +async def test_renew_session(): + OLD_SESSION_TOKEN = "old_session_token" + OLD_MASTER_TOKEN = "old_master_token" + NEW_SESSION_TOKEN = "new_session_token" + NEW_MASTER_TOKEN = "new_master_token" + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + type(connection)._probe_connection = PropertyMock(return_value=False) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._token = OLD_SESSION_TOKEN + rest._master_token = OLD_MASTER_TOKEN + + # inject a fake method (success) + async def fake_request_exec(**_): + return { + "success": True, + "data": { + "sessionToken": NEW_SESSION_TOKEN, + "masterToken": NEW_MASTER_TOKEN, + }, + } + + rest._request_exec = fake_request_exec + + await rest._renew_session() + assert not rest._connection.errorhandler.called # no error + assert rest.master_token == 
NEW_MASTER_TOKEN + assert rest.token == NEW_SESSION_TOKEN + + # inject a fake method (failure) + async def fake_request_exec(**_): + return {"success": False, "message": "failed to renew session", "code": 987654} + + rest._request_exec = fake_request_exec + + await rest._renew_session() + assert rest._connection.errorhandler.called # error + + # no master token + del rest._master_token + await rest._renew_session() + assert rest._connection.errorhandler.called # error + + +async def test_mask_token_when_renew_session(caplog): + caplog.set_level(logging.DEBUG) + OLD_SESSION_TOKEN = "old_session_token" + OLD_MASTER_TOKEN = "old_master_token" + NEW_SESSION_TOKEN = "new_session_token" + NEW_MASTER_TOKEN = "new_master_token" + connection = mock_connection() + connection.errorhandler = Mock(return_value=None) + type(connection)._probe_connection = PropertyMock(return_value=False) + + rest = SnowflakeRestful( + host="testaccount.snowflakecomputing.com", port=443, connection=connection + ) + rest._token = OLD_SESSION_TOKEN + rest._master_token = OLD_MASTER_TOKEN + + # inject a fake method (success) + async def fake_request_exec(**_): + return { + "success": True, + "data": { + "sessionToken": NEW_SESSION_TOKEN, + "masterToken": NEW_MASTER_TOKEN, + }, + } + + rest._request_exec = fake_request_exec + + # no secrets recorded when renew succeed + await rest._renew_session() + assert "new_session_token" not in caplog.text + assert "new_master_token" not in caplog.text + assert "old_session_token" not in caplog.text + assert "old_master_token" not in caplog.text + + async def fake_request_exec(**_): + return {"success": False, "message": "failed to renew session", "code": 987654} + + rest._request_exec = fake_request_exec + + # no secrets recorded when renew failed + await rest._renew_session() + assert "new_session_token" not in caplog.text + assert "new_master_token" not in caplog.text + assert "old_session_token" not in caplog.text + assert "old_master_token" not in 
caplog.text diff --git a/tox.ini b/tox.ini index dd51911c6..f4924e7a8 100644 --- a/tox.ini +++ b/tox.ini @@ -38,7 +38,6 @@ setenv = !unit-!integ: SNOWFLAKE_TEST_TYPE = (unit or integ) unit: SNOWFLAKE_TEST_TYPE = unit and not aio integ: SNOWFLAKE_TEST_TYPE = integ and not aio - aio: SNOWFLAKE_TEST_TYPE = aio parallel: SNOWFLAKE_PYTEST_OPTS = {env:SNOWFLAKE_PYTEST_OPTS:} -n auto # Add common parts into pytest command SNOWFLAKE_PYTEST_COV_LOCATION = {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}-{env:cloud_provider:dev}.xml @@ -62,10 +61,10 @@ passenv = commands = # Test environments # Note: make sure to have a default env and all the other special ones - !pandas-!sso-!lambda-!extras: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not sso and not pandas and not lambda" {posargs:} test - pandas: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and pandas" {posargs:} test - sso: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and sso" {posargs:} test - lambda: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and lambda" {posargs:} test + !pandas-!sso-!lambda-!extras: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not sso and not pandas and not lambda and not aio" {posargs:} test + pandas: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and pandas and not aio" {posargs:} test + sso: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and sso and not aio" {posargs:} test + lambda: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and lambda and not aio" {posargs:} test extras: python -m test.extras.run {posargs:} [testenv:olddriver] @@ -100,8 +99,11 @@ commands = python -c 'import snowflake.connector.result_batch' [testenv:aio] -basepython = 3.10 description = Run aio tests +extras= + development + aio + pandas commands = {env:SNOWFLAKE_PYTEST_CMD} -m "aio" -vvv {posargs:} test [testenv:coverage]