SNOW-1572300: async cursor coverage (#2062)
sfc-gh-aling committed Sep 30, 2024
1 parent 1465274 commit 9d43032
Showing 20 changed files with 2,851 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
@@ -378,7 +378,7 @@ jobs:
       - name: Install tox
         run: python -m pip install tox>=4
       - name: Run tests
-        run: python -m tox run -e `echo py${PYTHON_VERSION/\./}-aio-ci`
+        run: python -m tox run -e aio
         env:
           PYTHON_VERSION: ${{ matrix.python-version }}
           cloud_provider: ${{ matrix.cloud-provider }}
12 changes: 8 additions & 4 deletions src/snowflake/connector/aio/_connection.py
@@ -265,10 +265,13 @@ async def _all_async_queries_finished(self) -> bool:
         async def async_query_check_helper(
             sfq_id: str,
         ) -> bool:
-            nonlocal found_unfinished_query
-            return found_unfinished_query or self.is_still_running(
-                await self.get_query_status(sfq_id)
-            )
+            try:
+                nonlocal found_unfinished_query
+                return found_unfinished_query or self.is_still_running(
+                    await self.get_query_status(sfq_id)
+                )
+            except asyncio.CancelledError:
+                pass

         tasks = [
             asyncio.create_task(async_query_check_helper(sfqid)) for sfqid in queries
@@ -279,6 +282,7 @@ async def async_query_check_helper(
                 break
         for task in tasks:
             task.cancel()
+        await asyncio.gather(*tasks)
         return not found_unfinished_query

     async def _authenticate(self, auth_instance: AuthByPlugin):
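Why the new `await asyncio.gather(*tasks)` matters: `Task.cancel()` only requests cancellation, and a cancelled task that is never awaited can be garbage-collected while still pending, making asyncio log "Task was destroyed but it is pending!". Because the helper now swallows `asyncio.CancelledError`, the gather completes cleanly. A minimal sketch of the pattern (toy names, not connector code):

```python
import asyncio

async def probe() -> bool:
    try:
        await asyncio.sleep(10)  # stand-in for a slow status check
        return True
    except asyncio.CancelledError:
        return False  # swallow cancellation so gather() returns normally

async def main() -> None:
    tasks = [asyncio.create_task(probe()) for _ in range(3)]
    await asyncio.sleep(0)  # let each task start and reach its await
    for task in tasks:
        task.cancel()
    # Without this await, the cancelled tasks are never reaped and asyncio
    # may warn "Task was destroyed but it is pending!" at shutdown.
    print(await asyncio.gather(*tasks))  # -> [False, False, False]

asyncio.run(main())
```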
162 changes: 146 additions & 16 deletions src/snowflake/connector/aio/_cursor.py
@@ -6,9 +6,11 @@

 import asyncio
 import collections
+import logging
 import re
 import signal
 import sys
+import typing
 import uuid
 from logging import getLogger
 from types import TracebackType
@@ -30,8 +32,15 @@
     create_batches_from_response,
 )
 from snowflake.connector.aio._result_set import ResultSet, ResultSetIterator
-from snowflake.connector.constants import PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
-from snowflake.connector.cursor import DESC_TABLE_RE
+from snowflake.connector.constants import (
+    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
+    QueryStatus,
+)
+from snowflake.connector.cursor import (
+    ASYNC_NO_DATA_MAX_RETRY,
+    ASYNC_RETRY_PATTERN,
+    DESC_TABLE_RE,
+)
 from snowflake.connector.cursor import DictCursor as DictCursorSync
 from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
 from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
@@ -43,7 +52,7 @@
     ER_INVALID_VALUE,
     ER_NOT_POSITIVE_SIZE,
 )
-from snowflake.connector.errors import BindUploadError
+from snowflake.connector.errors import BindUploadError, DatabaseError
 from snowflake.connector.file_transfer_agent import SnowflakeProgressPercentage
 from snowflake.connector.telemetry import TelemetryField
 from snowflake.connector.time_util import get_time_millis
@@ -65,9 +74,11 @@ def __init__(
     ):
         super().__init__(connection, use_dict_result)
         # the following fixes type hint
-        self._connection: SnowflakeConnection = connection
+        self._connection = typing.cast("SnowflakeConnection", self._connection)
+        self._inner_cursor = typing.cast(SnowflakeCursor, self._inner_cursor)
         self._lock_canceling = asyncio.Lock()
         self._timebomb: asyncio.Task | None = None
+        self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None

     def __aiter__(self):
         return self
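The `typing.cast` calls re-type attributes that the sync base class already assigned, instead of re-binding them; `cast` is a no-op at runtime and only informs the type checker. A toy sketch of the idiom (class names invented for illustration):

```python
import typing

class SyncConnection: ...
class AsyncConnection(SyncConnection): ...

class SyncCursor:
    def __init__(self, connection: SyncConnection) -> None:
        self._connection = connection

class AsyncCursor(SyncCursor):
    def __init__(self, connection: AsyncConnection) -> None:
        super().__init__(connection)
        # No-op at runtime; tells the checker the attribute is the async subtype.
        self._connection = typing.cast(AsyncConnection, self._connection)
```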
@@ -87,6 +98,18 @@ async def __anext__(self):
     async def __aenter__(self):
         return self

+    def __enter__(self):
+        # async cursor does not support sync context manager
+        raise TypeError(
+            "'SnowflakeCursor' object does not support the context manager protocol"
+        )
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # async cursor does not support sync context manager
+        raise TypeError(
+            "'SnowflakeCursor' object does not support the context manager protocol"
+        )
+
     def __del__(self):
         # do nothing in async, __del__ is unreliable
         pass
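With `__enter__`/`__exit__` raising `TypeError`, a sync `with` block fails fast instead of silently skipping async cleanup; only `async with` works. A usage sketch (connection setup elided):

```python
async def demo(conn) -> None:
    # OK: the async cursor implements __aenter__/__aexit__.
    async with conn.cursor() as cur:
        await cur.execute("select 1")
        print(await cur.fetchone())

    # Raises TypeError: 'SnowflakeCursor' object does not support the
    # context manager protocol.
    with conn.cursor() as cur:
        pass
```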
@@ -337,6 +360,7 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
             self._total_rowcount += updated_rows

     async def _init_multi_statement_results(self, data: dict) -> None:
+        # TODO: async telemetry SNOW-1572217
         # self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE)
         self.multi_statement_savedIds = data["resultIds"].split(",")
         self._multi_statement_resultIds = collections.deque(
@@ -357,7 +381,45 @@ async def _init_multi_statement_results(self, data: dict) -> None:
     async def _log_telemetry_job_data(
         self, telemetry_field: TelemetryField, value: Any
     ) -> None:
-        raise NotImplementedError("Telemetry is not supported in async.")
+        # TODO: async telemetry SNOW-1572217
+        pass
+
+    async def _preprocess_pyformat_query(
+        self,
+        command: str,
+        params: Sequence[Any] | dict[Any, Any] | None = None,
+    ) -> str:
+        # pyformat/format paramstyle
+        # client side binding
+        processed_params = self._connection._process_params_pyformat(params, self)
+        # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
+        # TODO: async telemetry support
+        # if params is not None and len(params) == 0:
+        #     await self._log_telemetry_job_data(
+        #         TelemetryField.EMPTY_SEQ_INTERPOLATION,
+        #         (
+        #             TelemetryData.TRUE
+        #             if self.connection._interpolate_empty_sequences
+        #             else TelemetryData.FALSE
+        #         ),
+        #     )
+        if logger.getEffectiveLevel() <= logging.DEBUG:
+            logger.debug(
+                f"binding: [{self._format_query_for_log(command)}] "
+                f"with input=[{params}], "
+                f"processed=[{processed_params}]",
+            )
+        if (
+            self.connection._interpolate_empty_sequences
+            and processed_params is not None
+        ) or (
+            not self.connection._interpolate_empty_sequences
+            and len(processed_params) > 0
+        ):
+            query = command % processed_params
+        else:
+            query = command
+        return query

     async def abort_query(self, qid: str) -> bool:
         url = f"/queries/{qid}/abort-request"
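`_preprocess_pyformat_query` implements client-side binding: the connection's converter escapes and quotes each parameter, and the result is interpolated into the SQL text with Python's `%` operator. A rough stand-in for what `_process_params_pyformat` produces (the real converter handles many more types and proper escaping):

```python
def process_params(params: dict) -> dict:
    # Toy quoting: wrap strings in single quotes, doubling embedded quotes.
    def quote(v):
        if isinstance(v, str):
            return "'" + v.replace("'", "''") + "'"
        return str(v)
    return {k: quote(v) for k, v in params.items()}

command = "insert into t (name, num) values (%(name)s, %(num)s)"
print(command % process_params({"name": "O'Brien", "num": 42}))
# insert into t (name, num) values ('O''Brien', 42)
```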
@@ -387,6 +449,10 @@ async def callproc(self, procname: str, args=tuple()):
         await self.execute(command, args)
         return args

+    @property
+    def connection(self) -> SnowflakeConnection:
+        return self._connection
+
     async def close(self):
         """Closes the cursor object.
@@ -471,7 +537,7 @@ async def execute(
         }

         if self._connection.is_pyformat:
-            query = self._preprocess_pyformat_query(command, params)
+            query = await self._preprocess_pyformat_query(command, params)
         else:
             # qmark and numeric paramstyle
             query = command
@@ -538,7 +604,7 @@ async def execute(
                 self._connection.converter.set_parameter(param, value)

         if "resultIds" in data:
-            self._init_multi_statement_results(data)
+            await self._init_multi_statement_results(data)
             return self
         else:
             self.multi_statement_savedIds = []
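Several hunks in this file fix the same class of bug: calling a coroutine method without `await`. The call merely creates a coroutine object, the body never runs, and Python warns `RuntimeWarning: coroutine ... was never awaited`. A minimal reproduction:

```python
import asyncio

async def init_results() -> None:
    print("initialized")

async def main() -> None:
    init_results()        # bug: nothing happens; RuntimeWarning when the
                          # abandoned coroutine object is garbage-collected
    await init_results()  # fix: prints "initialized"

asyncio.run(main())
```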
@@ -707,7 +773,7 @@ async def executemany(
             command = command + "; "
         if self._connection.is_pyformat:
             processed_queries = [
-                self._preprocess_pyformat_query(command, params)
+                await self._preprocess_pyformat_query(command, params)
                 for params in seqparams
             ]
             query = "".join(processed_queries)
@@ -752,7 +818,7 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
     async def fetchone(self) -> dict | tuple | None:
         """Fetches one row."""
         if self._prefetch_hook is not None:
-            self._prefetch_hook()
+            await self._prefetch_hook()
         if self._result is None and self._result_set is not None:
             self._result: ResultSetIterator = await self._result_set._create_iter()
             self._result_state = ResultState.VALID
@@ -804,7 +870,7 @@ async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
     async def fetchall(self) -> list[tuple] | list[dict]:
         """Fetches all of the results."""
         if self._prefetch_hook is not None:
-            self._prefetch_hook()
+            await self._prefetch_hook()
         if self._result is None and self._result_set is not None:
             self._result: ResultSetIterator = await self._result_set._create_iter(
                 is_fetch_all=True
@@ -822,9 +888,10 @@ async def fetchall(self) -> list[tuple] | list[dict]:
     async def fetch_arrow_batches(self) -> AsyncIterator[Table]:
         self.check_can_use_arrow_resultset()
         if self._prefetch_hook is not None:
-            self._prefetch_hook()
+            await self._prefetch_hook()
         if self._query_result_format != "arrow":
             raise NotSupportedError
+        # TODO: async telemetry SNOW-1572217
         # self._log_telemetry_job_data(
         #     TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
         # )
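`fetch_arrow_batches` remains an async iterator, so results stream one batch at a time; consumption looks roughly like this (cursor acquisition elided):

```python
async def count_rows(cur) -> int:
    total = 0
    # Each yielded item is a pyarrow.Table holding one result batch.
    async for table in cur.fetch_arrow_batches():
        total += table.num_rows
    return total
```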
@@ -848,9 +915,10 @@ async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
         self.check_can_use_arrow_resultset()

         if self._prefetch_hook is not None:
-            self._prefetch_hook()
+            await self._prefetch_hook()
         if self._query_result_format != "arrow":
             raise NotSupportedError
+        # TODO: async telemetry SNOW-1572217
         # self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
         return await self._result_set._fetch_arrow_all(
             force_return_table=force_return_table
Expand All @@ -860,7 +928,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
"""Fetches a single Arrow Table."""
self.check_can_use_pandas()
if self._prefetch_hook is not None:
self._prefetch_hook()
await self._prefetch_hook()
if self._query_result_format != "arrow":
raise NotSupportedError
# TODO: async telemetry
@@ -872,7 +940,7 @@ async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
     async def fetch_pandas_all(self, **kwargs: Any) -> DataFrame:
         self.check_can_use_pandas()
         if self._prefetch_hook is not None:
-            self._prefetch_hook()
+            await self._prefetch_hook()
         if self._query_result_format != "arrow":
             raise NotSupportedError
         # # TODO: async telemetry
@@ -917,8 +985,70 @@ async def get_result_batches(self) -> list[ResultBatch] | None:
         return self._result_set.batches

     async def get_results_from_sfqid(self, sfqid: str) -> None:
-        """Gets the results from previously ran query."""
-        raise NotImplementedError("Not implemented in async")
+        """Gets the results from a previously run query. This method differs from
+        ``SnowflakeCursor.query_result`` in that it monitors the ``sfqid`` until it
+        is no longer running, and then retrieves the results.
+        """
+
+        async def wait_until_ready() -> None:
+            """Makes sure the query has finished executing, then retrieves the results."""
+            no_data_counter = 0
+            retry_pattern_pos = 0
+            while True:
+                status, status_resp = await self.connection._get_query_status(sfqid)
+                self.connection._cache_query_status(sfqid, status)
+                if not self.connection.is_still_running(status):
+                    break
+                if status == QueryStatus.NO_DATA:  # pragma: no cover
+                    no_data_counter += 1
+                    if no_data_counter > ASYNC_NO_DATA_MAX_RETRY:
+                        raise DatabaseError(
+                            "Cannot retrieve data on the status of this query. No information returned "
+                            "from server for query '{}'".format(sfqid)
+                        )
+                await asyncio.sleep(
+                    0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos]
+                )  # Same wait as JDBC
+                # If we can advance in ASYNC_RETRY_PATTERN then do so
+                if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
+                    retry_pattern_pos += 1
+            if status != QueryStatus.SUCCESS:
+                logger.info(f"Status of query '{sfqid}' is {status.name}")
+                self.connection._process_error_query_status(
+                    sfqid,
+                    status_resp,
+                    error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
+                    error_cls=DatabaseError,
+                )
+            await self._inner_cursor.execute(
+                f"select * from table(result_scan('{sfqid}'))"
+            )
+            self._result = self._inner_cursor._result
+            self._query_result_format = self._inner_cursor._query_result_format
+            self._total_rowcount = self._inner_cursor._total_rowcount
+            self._description = self._inner_cursor._description
+            self._result_set = self._inner_cursor._result_set
+            self._result_state = ResultState.VALID
+            self._rownumber = 0
+            # Unset this function, so that we don't block anymore
+            self._prefetch_hook = None
+
+            if (
+                self._inner_cursor._total_rowcount == 1
+                and await self._inner_cursor.fetchall()
+                == [("Multiple statements executed successfully.",)]
+            ):
+                url = f"/queries/{sfqid}/result"
+                ret = await self._connection.rest.request(url=url, method="get")
+                if "data" in ret and "resultIds" in ret["data"]:
+                    await self._init_multi_statement_results(ret["data"])

+        await self.connection.get_query_status_throw_if_error(
+            sfqid
+        )  # Trigger an exception if query failed
+        klass = self.__class__
+        self._inner_cursor = klass(self.connection)
+        self._sfqid = sfqid
+        self._prefetch_hook = wait_until_ready

     async def query_result(self, qid: str) -> SnowflakeCursor:
         url = f"/queries/{qid}/result"
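The `_prefetch_hook` indirection keeps `get_results_from_sfqid` itself non-blocking: it only arms `wait_until_ready`, and the first fetch call awaits the hook, which polls the query status (sleeping `0.5 * ASYNC_RETRY_PATTERN[pos]` seconds per round, advancing through the pattern as the JDBC driver does) and then loads the results via `result_scan`. A hedged usage sketch, assuming the aio connector mirrors the sync connector's `execute_async`/`sfqid` flow:

```python
async def fetch_async_query(conn, sql: str):
    cur = conn.cursor()
    await cur.execute_async(sql)   # assumed to mirror the sync API
    qid = cur.sfqid                # query ID of the async query

    monitor = conn.cursor()
    await monitor.get_results_from_sfqid(qid)  # arms the prefetch hook
    # fetchone() first awaits the hook: it polls the status with the retry
    # pattern, then runs result_scan('<qid>') to populate the cursor.
    return await monitor.fetchone()
```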
4 changes: 2 additions & 2 deletions src/snowflake/connector/aio/_network.py
@@ -136,7 +136,7 @@ async def close(self):
"""Closes all active and idle sessions in this session pool."""
if self._active_sessions:
logger.debug(f"Closing {len(self._active_sessions)} active sessions")
for s in itertools.chain(self._active_sessions, self._idle_sessions):
for s in itertools.chain(set(self._active_sessions), set(self._idle_sessions)):
try:
await s.close()
except Exception as e:
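Wrapping the pools in `set(...)` iterates over snapshots: closing a session can remove it from `_active_sessions`/`_idle_sessions` while the loop is still running, and mutating a set during iteration raises `RuntimeError: Set changed size during iteration`. The failure mode in isolation:

```python
import itertools

active = {"s1", "s2"}
idle = {"s3"}

# for s in itertools.chain(active, idle):  # RuntimeError if the body
#     active.discard(s)                    # mutates the sets being iterated

for s in itertools.chain(set(active), set(idle)):  # snapshots are safe
    active.discard(s)
    idle.discard(s)

print(active, idle)  # set() set()
```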
@@ -289,7 +289,7 @@ async def _token_request(self, request_type):
                 token=header_token,
             )
         if ret.get("success") and ret.get("data", {}).get("sessionToken"):
-            logger.debug("success: %s", ret)
+            logger.debug("success: %s", SecretDetector.mask_secrets(str(ret)))
             await self.update_tokens(
                 ret["data"]["sessionToken"],
                 ret["data"].get("masterToken"),