Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compression support to ydb sdk #17

Merged
merged 3 commits into from
Feb 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
name: Python package

on: [push]
on:
push:
pull_request:

jobs:
build:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__
ydb.egg-info/
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 2.1.0 ##

* add compression support to ydb sdk

## 1.1.16 ##

* alias `kikimr.public.sdk.python.client` is deprecated. use `import ydb` instead.
Expand Down
2 changes: 1 addition & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ docker==5.0.0
docker-compose==1.29.2
dockerpty==0.4.1
docopt==0.6.2
grpcio==1.38.0
grpcio>=1.38.0
idna==3.2
importlib-metadata==4.6.1
iniconfig==1.1.1
Expand Down
4 changes: 2 additions & 2 deletions tests/aio/test_async_iter_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def test_read_shard_table(driver, database):

data_by_shard_id = {}
with session.transaction() as tx:
max_value = 2 ** 64
max_value = 2**64
shard_key_bound = max_value >> 3
data = []

Expand All @@ -75,7 +75,7 @@ async def test_read_shard_table(driver, database):
table_row = {
"Key1": shard_id * shard_key_bound + idx,
"Key2": idx + 1000,
"Value": str(idx ** 4),
"Value": str(idx**4),
}
data_by_shard_id[shard_id].append(table_row)
data.append(table_row)
Expand Down
19 changes: 19 additions & 0 deletions tests/aio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ async def test_async_call(endpoint, database):
await driver.stop()


@pytest.mark.asyncio
async def test_gzip_compression(endpoint, database):
driver_config = ydb.DriverConfig(
endpoint,
database,
credentials=ydb.construct_credentials_from_environ(),
root_certificates=ydb.load_ydb_root_certificate(),
compression=ydb.RPCCompression.Gzip,
)

driver = Driver(driver_config=driver_config)

await driver.scheme_client.make_directory(
"/local/lol",
settings=ydb.BaseRequestSettings().with_compression(ydb.RPCCompression.Deflate),
)
await driver.stop()


@pytest.mark.asyncio
async def test_other_credentials(endpoint, database):
driver = Driver(endpoint=endpoint, database=database)
Expand Down
30 changes: 24 additions & 6 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _construct_channel_options(driver_config, endpoint_options=None):
:param endpoint_options: Endpoint options
:return: A channel initialization options
"""
_max_message_size = 64 * 10 ** 6
_max_message_size = 64 * 10**6
_default_connect_options = [
("grpc.max_receive_message_length", _max_message_size),
("grpc.max_send_message_length", _max_message_size),
Expand Down Expand Up @@ -269,7 +269,7 @@ def _cancel_callback(f):
return self.rendezvous, self.result_future


_nanos_in_second = 10 ** 9
_nanos_in_second = 10**9


def _set_duration(duration_value, seconds_float):
Expand Down Expand Up @@ -300,14 +300,22 @@ def channel_factory(
logger.debug("Channel options: {}".format(options))

if driver_config.root_certificates is None and not driver_config.secure_channel:
return channel_provider.insecure_channel(endpoint, options)
return channel_provider.insecure_channel(
endpoint, options, compression=getattr(driver_config, "compression", None)
)

root_certificates = driver_config.root_certificates
if root_certificates is None:
root_certificates = default_pem.load_default_pem()
credentials = grpc.ssl_channel_credentials(
root_certificates, driver_config.private_key, driver_config.certificate_chain
)
return channel_provider.secure_channel(endpoint, credentials, options)
return channel_provider.secure_channel(
endpoint,
credentials,
options,
compression=getattr(driver_config, "compression", None),
)


class Connection(object):
Expand Down Expand Up @@ -405,7 +413,12 @@ def future(
rpc_state, timeout, metadata = self._prepare_call(
stub, rpc_name, request, settings
)
rendezvous, result_future = rpc_state.future(request, timeout, metadata)
rendezvous, result_future = rpc_state.future(
request,
timeout,
metadata,
compression=getattr(settings, "compression", None),
)
rendezvous.add_done_callback(
lambda resp_future: _on_response_callback(
rpc_state,
Expand Down Expand Up @@ -443,7 +456,12 @@ def __call__(
stub, rpc_name, request, settings
)
try:
response = rpc_state(request, timeout, metadata)
response = rpc_state(
request,
timeout,
metadata,
compression=getattr(settings, "compression", None),
)
_log_response(rpc_state, response)
return (
response
Expand Down
12 changes: 6 additions & 6 deletions ydb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@


_SHIFT_BIT_COUNT = 64
_SHIFT = 2 ** 64
_SIGN_BIT = 2 ** 63
_DecimalNanRepr = 10 ** 35 + 1
_DecimalInfRepr = 10 ** 35
_DecimalSignedInfRepr = -(10 ** 35)
_SHIFT = 2**64
_SIGN_BIT = 2**63
_DecimalNanRepr = 10**35 + 1
_DecimalInfRepr = 10**35
_DecimalSignedInfRepr = -(10**35)
_primitive_type_by_id = {}


Expand Down Expand Up @@ -49,7 +49,7 @@ def _pb_to_decimal(type_pb, value_pb, table_client_settings):
elif int128_value == _DecimalSignedInfRepr:
return decimal.Decimal("-Inf")
return decimal.Decimal(int128_value) / decimal.Decimal(
10 ** type_pb.decimal_type.scale
10**type_pb.decimal_type.scale
)


Expand Down
22 changes: 15 additions & 7 deletions ydb/driver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
import ydb

from . import credentials as credentials_impl, table, scheme, pool
from . import tracing
import six
import os
import grpc

if six.PY2:
Any = None
Expand Down Expand Up @@ -40,8 +40,16 @@ def parse_connection_string(connection_string):
return p.scheme + "://" + p.netloc, database[0]


class RPCCompression:
"""Indicates the compression method to be used for an RPC."""

NoCompression = grpc.Compression.NoCompression
Deflate = grpc.Compression.Deflate
Gzip = grpc.Compression.Gzip


def default_credentials(credentials=None, tracer=None):
tracer = tracer if tracer is not None else ydb.Tracer(None)
tracer = tracer if tracer is not None else tracing.Tracer(None)
with tracer.trace("Driver.default_credentials") as ctx:
if credentials is not None:
ctx.trace({"credentials.prepared": True})
Expand Down Expand Up @@ -95,6 +103,7 @@ class DriverConfig(object):
"tracer",
"grpc_lb_policy_name",
"discovery_request_timeout",
"compression",
)

def __init__(
Expand All @@ -116,8 +125,8 @@ def __init__(
tracer=None,
grpc_lb_policy_name="round_robin",
discovery_request_timeout=10,
compression=None,
):
# type:(str, str, str, str, Any, ydb.Credentials, bool, bytes, bytes, bytes, float, ydb.TableClientSettings, list, str, ydb.Tracer) -> None
"""
A driver config to initialize a driver instance

Expand Down Expand Up @@ -158,9 +167,10 @@ def __init__(
self.grpc_keep_alive_timeout = grpc_keep_alive_timeout
self.table_client_settings = table_client_settings
self.primary_user_agent = primary_user_agent
self.tracer = tracer if tracer is not None else ydb.Tracer(None)
self.tracer = tracer if tracer is not None else tracing.Tracer(None)
self.grpc_lb_policy_name = grpc_lb_policy_name
self.discovery_request_timeout = discovery_request_timeout
self.compression = compression

def set_database(self, database):
self.database = database
Expand Down Expand Up @@ -235,8 +245,6 @@ def __init__(
credentials=None,
**kwargs
):
# type:(DriverConfig, str, str, str, bytes, ydb.AbstractCredentials, **Any) -> None

"""
Constructs a driver instance to be used in table and scheme clients.
It encapsulates endpoints discovery mechanism and provides ability to execute RPCs
Expand Down
6 changes: 6 additions & 0 deletions ydb/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class BaseRequestSettings(object):
"cancel_after",
"operation_timeout",
"tracer",
"compression",
)

def __init__(self):
Expand All @@ -20,6 +21,11 @@ def __init__(self):
self.timeout = None
self.cancel_after = None
self.operation_timeout = None
self.compression = None

def with_compression(self, compression):
self.compression = compression
return self

def with_trace_id(self, trace_id):
"""
Expand Down
2 changes: 1 addition & 1 deletion ydb/ydb_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "1.1.15"
VERSION = "2.1.0"