diff --git a/CHANGELOG.md b/CHANGELOG.md index 139ce6f7..841e4596 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 2.1.0 ## + +* add compression support to ydb sdk + ## 1.1.16 ## * alias `kikimr.public.sdk.python.client` is deprecated. use `import ydb` instead. diff --git a/test-requirements.txt b/test-requirements.txt index 041f7684..ee7b7f8d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -10,7 +10,7 @@ docker==5.0.0 docker-compose==1.29.2 dockerpty==0.4.1 docopt==0.6.2 -grpcio==1.38.0 +grpcio>=1.38.0 idna==3.2 importlib-metadata==4.6.1 iniconfig==1.1.1 diff --git a/tests/aio/test_connection_pool.py b/tests/aio/test_connection_pool.py index 534bf31b..221f1e39 100644 --- a/tests/aio/test_connection_pool.py +++ b/tests/aio/test_connection_pool.py @@ -21,6 +21,25 @@ async def test_async_call(endpoint, database): await driver.stop() +@pytest.mark.asyncio +async def test_gzip_compression(endpoint, database): + driver_config = ydb.DriverConfig( + endpoint, + database, + credentials=ydb.construct_credentials_from_environ(), + root_certificates=ydb.load_ydb_root_certificate(), + compression=ydb.RPCCompression.Gzip, + ) + + driver = Driver(driver_config=driver_config) + + await driver.scheme_client.make_directory( + "/local/lol", + settings=ydb.BaseRequestSettings().with_compression(ydb.RPCCompression.Deflate), + ) + await driver.stop() + + @pytest.mark.asyncio async def test_other_credentials(endpoint, database): driver = Driver(endpoint=endpoint, database=database) diff --git a/ydb/connection.py b/ydb/connection.py index e3af3d5d..1500f4c8 100644 --- a/ydb/connection.py +++ b/ydb/connection.py @@ -300,14 +300,22 @@ def channel_factory( logger.debug("Channel options: {}".format(options)) if driver_config.root_certificates is None and not driver_config.secure_channel: - return channel_provider.insecure_channel(endpoint, options) + return channel_provider.insecure_channel( + endpoint, options, compression=getattr(driver_config, "compression", None) + ) + root_certificates = driver_config.root_certificates if root_certificates is None: root_certificates = default_pem.load_default_pem() credentials = grpc.ssl_channel_credentials( root_certificates, driver_config.private_key, driver_config.certificate_chain ) - return channel_provider.secure_channel(endpoint, credentials, options) + return channel_provider.secure_channel( + endpoint, + credentials, + options, + compression=getattr(driver_config, "compression", None), + ) class Connection(object): @@ -405,7 +413,12 @@ def future( rpc_state, timeout, metadata = self._prepare_call( stub, rpc_name, request, settings ) - rendezvous, result_future = rpc_state.future(request, timeout, metadata) + rendezvous, result_future = rpc_state.future( + request, + timeout, + metadata, + compression=getattr(settings, "compression", None), + ) rendezvous.add_done_callback( lambda resp_future: _on_response_callback( rpc_state, @@ -443,7 +456,12 @@ def __call__( stub, rpc_name, request, settings ) try: - response = rpc_state(request, timeout, metadata) + response = rpc_state( + request, + timeout, + metadata, + compression=getattr(settings, "compression", None), + ) _log_response(rpc_state, response) return ( response diff --git a/ydb/driver.py b/ydb/driver.py index f04fb3ad..da300373 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -3,6 +3,7 @@ from . import tracing import six import os +import grpc if six.PY2: Any = None @@ -39,6 +40,14 @@ def parse_connection_string(connection_string): return p.scheme + "://" + p.netloc, database[0] +class RPCCompression: + """Indicates the compression method to be used for an RPC.""" + + NoCompression = grpc.Compression.NoCompression + Deflate = grpc.Compression.Deflate + Gzip = grpc.Compression.Gzip + + def default_credentials(credentials=None, tracer=None): tracer = tracer if tracer is not None else tracing.Tracer(None) with tracer.trace("Driver.default_credentials") as ctx: @@ -94,6 +103,7 @@ class DriverConfig(object): "tracer", "grpc_lb_policy_name", "discovery_request_timeout", + "compression", ) def __init__( @@ -115,6 +125,7 @@ def __init__( tracer=None, grpc_lb_policy_name="round_robin", discovery_request_timeout=10, + compression=None, ): """ A driver config to initialize a driver instance @@ -159,6 +170,7 @@ def __init__( self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name self.discovery_request_timeout = discovery_request_timeout + self.compression = compression def set_database(self, database): self.database = database diff --git a/ydb/settings.py b/ydb/settings.py index 96094c49..e1e1f0f2 100644 --- a/ydb/settings.py +++ b/ydb/settings.py @@ -9,6 +9,7 @@ class BaseRequestSettings(object): "cancel_after", "operation_timeout", "tracer", + "compression", ) def __init__(self): @@ -20,6 +21,11 @@ def __init__(self): self.timeout = None self.cancel_after = None self.operation_timeout = None + self.compression = None + + def with_compression(self, compression): + self.compression = compression + return self def with_trace_id(self, trace_id): """