From 952f7b4f179595b07711934691d2e15643929cc5 Mon Sep 17 00:00:00 2001 From: Vitalii Gridnev Date: Sun, 20 Feb 2022 21:01:47 +0300 Subject: [PATCH 1/3] fix tests --- .gitignore | 2 ++ tests/aio/test_async_iter_stream.py | 4 ++-- ydb/connection.py | 4 ++-- ydb/convert.py | 12 ++++++------ ydb/driver.py | 7 +++---- ydb/ydb_version.py | 2 +- 6 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..95f23d1c --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +ydb.egg-info/ diff --git a/tests/aio/test_async_iter_stream.py b/tests/aio/test_async_iter_stream.py index 799cb7aa..c8810499 100644 --- a/tests/aio/test_async_iter_stream.py +++ b/tests/aio/test_async_iter_stream.py @@ -63,7 +63,7 @@ async def test_read_shard_table(driver, database): data_by_shard_id = {} with session.transaction() as tx: - max_value = 2 ** 64 + max_value = 2**64 shard_key_bound = max_value >> 3 data = [] @@ -75,7 +75,7 @@ async def test_read_shard_table(driver, database): table_row = { "Key1": shard_id * shard_key_bound + idx, "Key2": idx + 1000, - "Value": str(idx ** 4), + "Value": str(idx**4), } data_by_shard_id[shard_id].append(table_row) data.append(table_row) diff --git a/ydb/connection.py b/ydb/connection.py index 573b276e..e3af3d5d 100644 --- a/ydb/connection.py +++ b/ydb/connection.py @@ -176,7 +176,7 @@ def _construct_channel_options(driver_config, endpoint_options=None): :param endpoint_options: Endpoint options :return: A channel initialization options """ - _max_message_size = 64 * 10 ** 6 + _max_message_size = 64 * 10**6 _default_connect_options = [ ("grpc.max_receive_message_length", _max_message_size), ("grpc.max_send_message_length", _max_message_size), @@ -269,7 +269,7 @@ def _cancel_callback(f): return self.rendezvous, self.result_future -_nanos_in_second = 10 ** 9 +_nanos_in_second = 10**9 def _set_duration(duration_value, seconds_float): diff --git a/ydb/convert.py b/ydb/convert.py index 0b0176e0..2be209bf 100644 --- a/ydb/convert.py +++ b/ydb/convert.py @@ -7,11 +7,11 @@ _SHIFT_BIT_COUNT = 64 -_SHIFT = 2 ** 64 -_SIGN_BIT = 2 ** 63 -_DecimalNanRepr = 10 ** 35 + 1 -_DecimalInfRepr = 10 ** 35 -_DecimalSignedInfRepr = -(10 ** 35) +_SHIFT = 2**64 +_SIGN_BIT = 2**63 +_DecimalNanRepr = 10**35 + 1 +_DecimalInfRepr = 10**35 +_DecimalSignedInfRepr = -(10**35) _primitive_type_by_id = {} @@ -49,7 +49,7 @@ def _pb_to_decimal(type_pb, value_pb, table_client_settings): elif int128_value == _DecimalSignedInfRepr: return decimal.Decimal("-Inf") return decimal.Decimal(int128_value) / decimal.Decimal( - 10 ** type_pb.decimal_type.scale + 10**type_pb.decimal_type.scale ) diff --git a/ydb/driver.py b/ydb/driver.py index e14f9017..58f0a995 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -import ydb - from . import credentials as credentials_impl, table, scheme, pool +from . import tracing import six import os @@ -41,7 +40,7 @@ def parse_connection_string(connection_string): def default_credentials(credentials=None, tracer=None): - tracer = tracer if tracer is not None else ydb.Tracer(None) + tracer = tracer if tracer is not None else tracing.Tracer(None) with tracer.trace("Driver.default_credentials") as ctx: if credentials is not None: ctx.trace({"credentials.prepared": True}) @@ -158,7 +157,7 @@ def __init__( self.grpc_keep_alive_timeout = grpc_keep_alive_timeout self.table_client_settings = table_client_settings self.primary_user_agent = primary_user_agent - self.tracer = tracer if tracer is not None else ydb.Tracer(None) + self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name self.discovery_request_timeout = discovery_request_timeout diff --git a/ydb/ydb_version.py b/ydb/ydb_version.py index 0f7ea410..127c148a 100644 --- a/ydb/ydb_version.py +++ b/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "1.1.15" +VERSION = "2.1.0" From bd7a4e3ec52d221f20ee8adfeac0dcd2a9d8833d Mon Sep 17 00:00:00 2001 From: Vitalii Gridnev Date: Sun, 20 Feb 2022 21:01:47 +0300 Subject: [PATCH 2/3] fix tests --- .github/workflows/main.yaml | 4 +++- ydb/driver.py | 3 --- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 4be7ce54..c3338498 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -1,6 +1,8 @@ name: Python package -on: [push] +on: + push: + pull_request: jobs: build: diff --git a/ydb/driver.py b/ydb/driver.py index 58f0a995..f04fb3ad 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -116,7 +116,6 @@ def __init__( grpc_lb_policy_name="round_robin", discovery_request_timeout=10, ): - # type:(str, str, str, str, Any, ydb.Credentials, bool, bytes, bytes, bytes, float, ydb.TableClientSettings, list, str, ydb.Tracer) -> None """ A driver config to initialize a driver instance @@ -234,8 +233,6 @@ def __init__( credentials=None, **kwargs ): - # type:(DriverConfig, str, str, str, bytes, ydb.AbstractCredentials, **Any) -> None - """ Constructs a driver instance to be used in table and scheme clients. It encapsulates endpoints discovery mechanism and provides ability to execute RPCs From 489b74329a470bf65300c7b8dafa57e883e61590 Mon Sep 17 00:00:00 2001 From: Vitalii Gridnev Date: Sun, 20 Feb 2022 23:48:13 +0300 Subject: [PATCH 3/3] add compression support to ydb sdk --- CHANGELOG.md | 4 ++++ test-requirements.txt | 2 +- tests/aio/test_connection_pool.py | 19 +++++++++++++++++++ ydb/connection.py | 26 ++++++++++++++++++++++---- ydb/driver.py | 12 ++++++++++++ ydb/settings.py | 6 ++++++ 6 files changed, 64 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 139ce6f7..841e4596 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 2.1.0 ## + +* add compression support to ydb sdk + ## 1.1.16 ## * alias `kikimr.public.sdk.python.client` is deprecated. use `import ydb` instead. diff --git a/test-requirements.txt b/test-requirements.txt index 041f7684..ee7b7f8d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -10,7 +10,7 @@ docker==5.0.0 docker-compose==1.29.2 dockerpty==0.4.1 docopt==0.6.2 -grpcio==1.38.0 +grpcio>=1.38.0 idna==3.2 importlib-metadata==4.6.1 iniconfig==1.1.1 diff --git a/tests/aio/test_connection_pool.py b/tests/aio/test_connection_pool.py index 534bf31b..221f1e39 100644 --- a/tests/aio/test_connection_pool.py +++ b/tests/aio/test_connection_pool.py @@ -21,6 +21,25 @@ async def test_async_call(endpoint, database): await driver.stop() +@pytest.mark.asyncio +async def test_gzip_compression(endpoint, database): + driver_config = ydb.DriverConfig( + endpoint, + database, + credentials=ydb.construct_credentials_from_environ(), + root_certificates=ydb.load_ydb_root_certificate(), + compression=ydb.RPCCompression.Gzip, + ) + + driver = Driver(driver_config=driver_config) + + await driver.scheme_client.make_directory( + "/local/lol", + settings=ydb.BaseRequestSettings().with_compression(ydb.RPCCompression.Deflate), + ) + await driver.stop() + + @pytest.mark.asyncio async def test_other_credentials(endpoint, database): driver = Driver(endpoint=endpoint, database=database) diff --git a/ydb/connection.py b/ydb/connection.py index e3af3d5d..1500f4c8 100644 --- a/ydb/connection.py +++ b/ydb/connection.py @@ -300,14 +300,22 @@ def channel_factory( logger.debug("Channel options: {}".format(options)) if driver_config.root_certificates is None and not driver_config.secure_channel: - return channel_provider.insecure_channel(endpoint, options) + return channel_provider.insecure_channel( + endpoint, options, compression=getattr(driver_config, "compression", None) + ) + root_certificates = driver_config.root_certificates if root_certificates is None: root_certificates = default_pem.load_default_pem() credentials = grpc.ssl_channel_credentials( root_certificates, driver_config.private_key, driver_config.certificate_chain ) - return channel_provider.secure_channel(endpoint, credentials, options) + return channel_provider.secure_channel( + endpoint, + credentials, + options, + compression=getattr(driver_config, "compression", None), + ) class Connection(object): @@ -405,7 +413,12 @@ def future( rpc_state, timeout, metadata = self._prepare_call( stub, rpc_name, request, settings ) - rendezvous, result_future = rpc_state.future(request, timeout, metadata) + rendezvous, result_future = rpc_state.future( + request, + timeout, + metadata, + compression=getattr(settings, "compression", None), + ) rendezvous.add_done_callback( lambda resp_future: _on_response_callback( rpc_state, @@ -443,7 +456,12 @@ def __call__( stub, rpc_name, request, settings ) try: - response = rpc_state(request, timeout, metadata) + response = rpc_state( + request, + timeout, + metadata, + compression=getattr(settings, "compression", None), + ) _log_response(rpc_state, response) return ( response diff --git a/ydb/driver.py b/ydb/driver.py index f04fb3ad..da300373 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -3,6 +3,7 @@ from . import tracing import six import os +import grpc if six.PY2: Any = None @@ -39,6 +40,14 @@ def parse_connection_string(connection_string): return p.scheme + "://" + p.netloc, database[0] +class RPCCompression: + """Indicates the compression method to be used for an RPC.""" + + NoCompression = grpc.Compression.NoCompression + Deflate = grpc.Compression.Deflate + Gzip = grpc.Compression.Gzip + + def default_credentials(credentials=None, tracer=None): tracer = tracer if tracer is not None else tracing.Tracer(None) with tracer.trace("Driver.default_credentials") as ctx: @@ -94,6 +103,7 @@ class DriverConfig(object): "tracer", "grpc_lb_policy_name", "discovery_request_timeout", + "compression", ) def __init__( @@ -115,6 +125,7 @@ def __init__( tracer=None, grpc_lb_policy_name="round_robin", discovery_request_timeout=10, + compression=None, ): """ A driver config to initialize a driver instance @@ -159,6 +170,7 @@ def __init__( self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name self.discovery_request_timeout = discovery_request_timeout + self.compression = compression def set_database(self, database): self.database = database diff --git a/ydb/settings.py b/ydb/settings.py index 96094c49..e1e1f0f2 100644 --- a/ydb/settings.py +++ b/ydb/settings.py @@ -9,6 +9,7 @@ class BaseRequestSettings(object): "cancel_after", "operation_timeout", "tracer", + "compression", ) def __init__(self): @@ -20,6 +21,11 @@ def __init__(self): self.timeout = None self.cancel_after = None self.operation_timeout = None + self.compression = None + + def with_compression(self, compression): + self.compression = compression + return self def with_trace_id(self, trace_id): """