SNOW-1654538: asyncio download timeout setting (#2063)
sfc-gh-aling committed Sep 27, 2024
1 parent 11bcfc9 commit 1465274
Showing 2 changed files with 58 additions and 62 deletions.
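
At a glance, the commit adds module-level DOWNLOAD_TIMEOUT and MAX_DOWNLOAD_RETRY constants to the async result-batch module, wires the timeout into the chunk download path, and adds a test that patches both values. Since _result_batch is private, there is no public configuration knob; the only way to tune these values today is to rebind the module globals. A rough, unsupported sketch (constant names and defaults taken from the diff below):

```python
# Unsupported sketch: rebinding private module globals, for experimentation only.
import snowflake.connector.aio._result_batch as _result_batch

_result_batch.DOWNLOAD_TIMEOUT = 7  # seconds; None defers to the aiohttp session timeout
_result_batch.MAX_DOWNLOAD_RETRY = 5
```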
32 changes: 20 additions & 12 deletions src/snowflake/connector/aio/_result_batch.py
@@ -29,12 +29,7 @@
     get_http_retryable_error,
     is_retryable_http_code,
 )
-from snowflake.connector.result_batch import (
-    MAX_DOWNLOAD_RETRY,
-    SSE_C_AES,
-    SSE_C_ALGORITHM,
-    SSE_C_KEY,
-)
+from snowflake.connector.result_batch import SSE_C_AES, SSE_C_ALGORITHM, SSE_C_KEY
 from snowflake.connector.result_batch import ArrowResultBatch as ArrowResultBatchSync
 from snowflake.connector.result_batch import DownloadMetrics
 from snowflake.connector.result_batch import JSONResultBatch as JSONResultBatchSync
@@ -52,8 +47,13 @@

 logger = getLogger(__name__)

+# We intentionally redefine DOWNLOAD_TIMEOUT and MAX_DOWNLOAD_RETRY for the async version:
+# sync and async downloads differ in nature and may require separate tuning. Note that
+# _result_batch is currently a private module, so these values are not exposed to users.
+DOWNLOAD_TIMEOUT = None
+MAX_DOWNLOAD_RETRY = 10


 # TODO: consolidate this with the sync version
 def create_batches_from_response(
     cursor: SnowflakeCursor,
     _format: str,
@@ -212,19 +212,27 @@ async def download_chunk(http_session):
             return response, content, encoding

         content, encoding = None, None
-        for retry in range(MAX_DOWNLOAD_RETRY):
+        for retry in range(max(MAX_DOWNLOAD_RETRY, 1)):
             try:
-                # TODO: feature parity with download timeout setting, in sync it's set to 7s
-                # but in async we schedule multiple tasks at the same time so some tasks might
-                # take longer than 7s to finish which is expected
+
                 async with TimerContextManager() as download_metric:
                     logger.debug(f"started downloading result batch id: {self.id}")
                     chunk_url = self._remote_chunk_info.url
                     request_data = {
                         "url": chunk_url,
                         "headers": self._chunk_headers,
-                        # "timeout": DOWNLOAD_TIMEOUT,
                     }
+                    # The download timeout differs from the sync version, which uses an
+                    # empirical 7 seconds. That value is hard to measure in async: we
+                    # maximize network throughput by downloading multiple chunks at the
+                    # same time, whereas the sync version's overall throughput is bounded
+                    # by the number of prefetch threads -- with asyncio we see a large
+                    # improvement in download performance. If DOWNLOAD_TIMEOUT is unset,
+                    # the aiohttp session timeout from the connection config applies.
+                    if DOWNLOAD_TIMEOUT:
+                        request_data["timeout"] = aiohttp.ClientTimeout(
+                            total=DOWNLOAD_TIMEOUT
+                        )
                     # Try to reuse a connection if possible
                     if connection and connection._rest is not None:
                         async with connection._rest._use_requests_session() as session:
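
For orientation, the hunk above amounts to an optional per-request timeout layered on a bounded retry loop. The sketch below is a simplified stand-in, not the connector's actual code; fetch_chunk, the backoff, and the error handling are hypothetical, while the aiohttp.ClientTimeout override and the max(MAX_DOWNLOAD_RETRY, 1) guard mirror the diff.

```python
# Simplified sketch of the download pattern in this diff (hypothetical helper,
# not the connector's real implementation).
import asyncio

import aiohttp

DOWNLOAD_TIMEOUT = None  # seconds; None defers to the aiohttp session timeout
MAX_DOWNLOAD_RETRY = 10


async def fetch_chunk(session: aiohttp.ClientSession, url: str, headers: dict) -> bytes:
    request_kwargs = {"url": url, "headers": headers}
    if DOWNLOAD_TIMEOUT:
        # Per-request override of the session-level timeout, as in the diff above.
        request_kwargs["timeout"] = aiohttp.ClientTimeout(total=DOWNLOAD_TIMEOUT)
    last_error = None
    for retry in range(max(MAX_DOWNLOAD_RETRY, 1)):  # always attempt at least once
        try:
            async with session.get(**request_kwargs) as response:
                response.raise_for_status()
                return await response.read()
        except (aiohttp.ClientError, asyncio.TimeoutError) as error:
            last_error = error
            await asyncio.sleep(2**retry)  # simple exponential backoff between tries
    raise last_error
```

Leaving DOWNLOAD_TIMEOUT at None is deliberate: with many chunks downloading concurrently, a fixed per-request deadline like the sync connector's 7 seconds would trip on tasks that are merely queued behind others, so the session-level timeout is the default backstop.
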
88 changes: 38 additions & 50 deletions test/integ/aio/test_cursor_async.py
@@ -13,13 +13,13 @@
 import pickle
 import time
 from datetime import date, datetime, timezone
-from typing import TYPE_CHECKING, NamedTuple
 from unittest import mock

 import pytest
 import pytz

 import snowflake.connector
+import snowflake.connector.aio
 from snowflake.connector import (
     InterfaceError,
     NotSupportedError,
@@ -30,64 +30,31 @@
     errors,
 )
 from snowflake.connector.aio import DictCursor, SnowflakeCursor
+from snowflake.connector.aio._result_batch import (
+    ArrowResultBatch,
+    JSONResultBatch,
+    ResultBatch,
+)
 from snowflake.connector.compat import IS_WINDOWS

-try:
-    from snowflake.connector.cursor import ResultMetadata
-except ImportError:
-
-    class ResultMetadata(NamedTuple):
-        name: str
-        type_code: int
-        display_size: int
-        internal_size: int
-        precision: int
-        scale: int
-        is_nullable: bool
-
-
-import snowflake.connector.aio
+from snowflake.connector.constants import (
+    FIELD_ID_TO_NAME,
+    PARAMETER_MULTI_STATEMENT_COUNT,
+    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
+    QueryStatus,
+)
+from snowflake.connector.cursor import ResultMetadata
 from snowflake.connector.description import CLIENT_VERSION
 from snowflake.connector.errorcode import (
     ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT,
+    ER_NO_ARROW_RESULT,
+    ER_NO_PYARROW,
+    ER_NO_PYARROW_SNOWSQL,
     ER_NOT_POSITIVE_SIZE,
 )
 from snowflake.connector.errors import Error
 from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
 from snowflake.connector.telemetry import TelemetryField

-try:
-    from snowflake.connector.util_text import random_string
-except ImportError:
-    from ..randomize import random_string
-
-try:
-    from snowflake.connector.aio._result_batch import ArrowResultBatch, JSONResultBatch
-    from snowflake.connector.constants import (
-        FIELD_ID_TO_NAME,
-        PARAMETER_MULTI_STATEMENT_COUNT,
-        PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
-    )
-    from snowflake.connector.errorcode import (
-        ER_NO_ARROW_RESULT,
-        ER_NO_PYARROW,
-        ER_NO_PYARROW_SNOWSQL,
-    )
-except ImportError:
-    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = None
-    ER_NO_ARROW_RESULT = None
-    ER_NO_PYARROW = None
-    ER_NO_PYARROW_SNOWSQL = None
-    ArrowResultBatch = JSONResultBatch = None
-    FIELD_ID_TO_NAME = {}
-
-if TYPE_CHECKING:  # pragma: no cover
-    from snowflake.connector.result_batch import ResultBatch
-
-try:  # pragma: no cover
-    from snowflake.connector.constants import QueryStatus
-except ImportError:
-    QueryStatus = None
+from snowflake.connector.util_text import random_string


 @pytest.fixture
@@ -1824,3 +1791,24 @@ async def test_decoding_utf8_for_json_result(conn_cnx):
     )
     with pytest.raises(Error):
         await result_batch._load("À".encode("latin1"), "latin1")
+
+
+async def test_fetch_download_timeout_setting(conn_cnx):
+    with mock.patch.multiple(
+        "snowflake.connector.aio._result_batch",
+        DOWNLOAD_TIMEOUT=0.001,
+        MAX_DOWNLOAD_RETRY=2,
+    ):
+        sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
+        async with conn_cnx() as con, con.cursor() as cur:
+            with pytest.raises(asyncio.TimeoutError):
+                await (await cur.execute(sql)).fetchall()
+
+    with mock.patch.multiple(
+        "snowflake.connector.aio._result_batch",
+        DOWNLOAD_TIMEOUT=10,
+        MAX_DOWNLOAD_RETRY=1,
+    ):
+        sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
+        async with conn_cnx() as con, con.cursor() as cur:
+            assert len(await (await cur.execute(sql)).fetchall()) == 100000
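
The first patched block forces a 1 ms total timeout with two retries, so fetching a multi-chunk result (GENERATOR(ROWCOUNT => 100000) returns enough rows to spill into separately downloaded chunks) is expected to raise asyncio.TimeoutError; the second block uses a generous 10 s timeout and asserts all 100000 rows arrive. The patching itself is plain unittest.mock; a minimal illustration of the pattern, where the asserted defaults (None and 10) come from this commit:

```python
# Minimal illustration of the patching pattern used by the test above:
# mock.patch.multiple temporarily rebinds module attributes, then restores them.
from unittest import mock

import snowflake.connector.aio._result_batch as result_batch

with mock.patch.multiple(
    "snowflake.connector.aio._result_batch",
    DOWNLOAD_TIMEOUT=0.001,
    MAX_DOWNLOAD_RETRY=2,
):
    # Code that reads the globals at call time sees the patched values here.
    assert result_batch.DOWNLOAD_TIMEOUT == 0.001
    assert result_batch.MAX_DOWNLOAD_RETRY == 2

# On exit the defaults from this commit (None and 10) are restored.
assert result_batch.DOWNLOAD_TIMEOUT is None
assert result_batch.MAX_DOWNLOAD_RETRY == 10
```

This works because download_chunk reads DOWNLOAD_TIMEOUT through the module namespace at call time; a from-import of the constant at the call site would capture the original value and ignore the patch.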
