Skip to content

Commit

Permalink
Merge pull request #17 from gridnevvvit/add-compression-support
Browse files Browse the repository at this point in the history
add compression support to ydb sdk
  • Loading branch information
gridnevvvit authored Feb 20, 2022
2 parents c6ce18a + 489b743 commit c47b488
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 2.1.0 ##

* add compression support to ydb sdk

## 1.1.16 ##

* alias `kikimr.public.sdk.python.client` is deprecated. use `import ydb` instead.
Expand Down
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ docker==5.0.0
docker-compose==1.29.2
dockerpty==0.4.1
docopt==0.6.2
grpcio==1.38.0
grpcio>=1.38.0
idna==3.2
importlib-metadata==4.6.1
iniconfig==1.1.1
Expand Down
19 changes: 19 additions & 0 deletions tests/aio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ async def test_async_call(endpoint, database):
await driver.stop()


@pytest.mark.asyncio
async def test_gzip_compression(endpoint, database):
driver_config = ydb.DriverConfig(
endpoint,
database,
credentials=ydb.construct_credentials_from_environ(),
root_certificates=ydb.load_ydb_root_certificate(),
compression=ydb.RPCCompression.Gzip,
)

driver = Driver(driver_config=driver_config)

await driver.scheme_client.make_directory(
"/local/lol",
settings=ydb.BaseRequestSettings().with_compression(ydb.RPCCompression.Deflate),
)
await driver.stop()


@pytest.mark.asyncio
async def test_other_credentials(endpoint, database):
driver = Driver(endpoint=endpoint, database=database)
Expand Down
26 changes: 22 additions & 4 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,22 @@ def channel_factory(
logger.debug("Channel options: {}".format(options))

if driver_config.root_certificates is None and not driver_config.secure_channel:
return channel_provider.insecure_channel(endpoint, options)
return channel_provider.insecure_channel(
endpoint, options, compression=getattr(driver_config, "compression", None)
)

root_certificates = driver_config.root_certificates
if root_certificates is None:
root_certificates = default_pem.load_default_pem()
credentials = grpc.ssl_channel_credentials(
root_certificates, driver_config.private_key, driver_config.certificate_chain
)
return channel_provider.secure_channel(endpoint, credentials, options)
return channel_provider.secure_channel(
endpoint,
credentials,
options,
compression=getattr(driver_config, "compression", None),
)


class Connection(object):
Expand Down Expand Up @@ -405,7 +413,12 @@ def future(
rpc_state, timeout, metadata = self._prepare_call(
stub, rpc_name, request, settings
)
rendezvous, result_future = rpc_state.future(request, timeout, metadata)
rendezvous, result_future = rpc_state.future(
request,
timeout,
metadata,
compression=getattr(settings, "compression", None),
)
rendezvous.add_done_callback(
lambda resp_future: _on_response_callback(
rpc_state,
Expand Down Expand Up @@ -443,7 +456,12 @@ def __call__(
stub, rpc_name, request, settings
)
try:
response = rpc_state(request, timeout, metadata)
response = rpc_state(
request,
timeout,
metadata,
compression=getattr(settings, "compression", None),
)
_log_response(rpc_state, response)
return (
response
Expand Down
12 changes: 12 additions & 0 deletions ydb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import tracing
import six
import os
import grpc

if six.PY2:
Any = None
Expand Down Expand Up @@ -39,6 +40,14 @@ def parse_connection_string(connection_string):
return p.scheme + "://" + p.netloc, database[0]


class RPCCompression:
"""Indicates the compression method to be used for an RPC."""

NoCompression = grpc.Compression.NoCompression
Deflate = grpc.Compression.Deflate
Gzip = grpc.Compression.Gzip


def default_credentials(credentials=None, tracer=None):
tracer = tracer if tracer is not None else tracing.Tracer(None)
with tracer.trace("Driver.default_credentials") as ctx:
Expand Down Expand Up @@ -94,6 +103,7 @@ class DriverConfig(object):
"tracer",
"grpc_lb_policy_name",
"discovery_request_timeout",
"compression",
)

def __init__(
Expand All @@ -115,6 +125,7 @@ def __init__(
tracer=None,
grpc_lb_policy_name="round_robin",
discovery_request_timeout=10,
compression=None,
):
"""
A driver config to initialize a driver instance
Expand Down Expand Up @@ -159,6 +170,7 @@ def __init__(
self.tracer = tracer if tracer is not None else tracing.Tracer(None)
self.grpc_lb_policy_name = grpc_lb_policy_name
self.discovery_request_timeout = discovery_request_timeout
self.compression = compression

def set_database(self, database):
self.database = database
Expand Down
6 changes: 6 additions & 0 deletions ydb/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class BaseRequestSettings(object):
"cancel_after",
"operation_timeout",
"tracer",
"compression",
)

def __init__(self):
Expand All @@ -20,6 +21,11 @@ def __init__(self):
self.timeout = None
self.cancel_after = None
self.operation_timeout = None
self.compression = None

def with_compression(self, compression):
self.compression = compression
return self

def with_trace_id(self, trace_id):
"""
Expand Down

0 comments on commit c47b488

Please sign in to comment.