Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow to refer endpoints by node id #19

Merged
merged 1 commit into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 2.2.0 ##

* allow to refer endpoints by node id

## 2.1.0 ##

* add compression support to ydb sdk
Expand Down
2 changes: 1 addition & 1 deletion ydb/_session_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def initialize_session(rpc_state, response_pb, session_state, session):
issues._process_response(response_pb.operation)
message = _apis.ydb_table.CreateSessionResult()
response_pb.operation.result.Unpack(message)
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint)
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint_key)
return session


Expand Down
15 changes: 12 additions & 3 deletions ydb/aio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
YDB_DATABASE_HEADER,
YDB_TRACE_ID_HEADER,
YDB_REQUEST_TYPE_HEADER,
EndpointKey,
)
from ydb.driver import DriverConfig
from ydb.settings import BaseRequestSettings
Expand Down Expand Up @@ -71,8 +72,8 @@ class _RpcState(RpcState):
"_trailing_metadata",
)

def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str):
super().__init__(stub_instance, rpc_name, endpoint)
def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str, endpoint_key):
super().__init__(stub_instance, rpc_name, endpoint, endpoint_key)

async def __call__(self, *args, **kwargs):
resp = self.rpc(*args, **kwargs)
Expand Down Expand Up @@ -105,6 +106,8 @@ class Connection:
"lock",
"calls",
"closing",
"endpoint_key",
"node_id",
)

def __init__(
Expand All @@ -115,6 +118,10 @@ def __init__(
):
global _stubs_list
self.endpoint = endpoint
self.endpoint_key = EndpointKey(
self.endpoint, getattr(endpoint_options, "node_id", None)
)
self.node_id = getattr(endpoint_options, "node_id", None)
self._channel = channel_factory(
self.endpoint, driver_config, grpc.aio, endpoint_options=endpoint_options
)
Expand All @@ -141,7 +148,9 @@ async def _prepare_call(
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
rpc_state = _RpcState(
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
)
logger.debug("%s: creating call state", rpc_state)

if self.closing:
Expand Down
14 changes: 13 additions & 1 deletion ydb/aio/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@ async def get(self, preferred_endpoint=None, fast_fail=False, wait_timeout=10):
else:
await asyncio.wait_for(self._event.wait(), timeout=wait_timeout)

if preferred_endpoint is not None and preferred_endpoint in self.connections:
if (
preferred_endpoint is not None
and preferred_endpoint.node_id in self.connections_by_node_id
):
return self.connections_by_node_id[preferred_endpoint.node_id]

if (
preferred_endpoint is not None
and preferred_endpoint.endpoint in self.connections
):
return self.connections[preferred_endpoint]

for conn_lst in self.conn_lst_order:
Expand All @@ -52,6 +61,8 @@ def add(self, connection, preferred=False):

if preferred:
self.preferred[connection.endpoint] = connection

self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection

self._event.set()
Expand All @@ -66,6 +77,7 @@ def complete_discovery(self, error):
self._fast_fail_event.set()

def remove(self, connection):
self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
Expand Down
27 changes: 23 additions & 4 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,11 @@ def _get_request_timeout(settings):


class EndpointOptions(object):
__slots__ = ("ssl_target_name_override",)
__slots__ = ("ssl_target_name_override", "node_id")

def __init__(self, ssl_target_name_override=None):
def __init__(self, ssl_target_name_override=None, node_id=None):
self.ssl_target_name_override = ssl_target_name_override
self.node_id = node_id


def _construct_channel_options(driver_config, endpoint_options=None):
Expand Down Expand Up @@ -223,16 +224,18 @@ class _RpcState(object):
"endpoint",
"rendezvous",
"metadata_kv",
"endpoint_key",
)

def __init__(self, stub_instance, rpc_name, endpoint):
def __init__(self, stub_instance, rpc_name, endpoint, endpoint_key):
"""Stores all RPC related data"""
self.rpc_name = rpc_name
self.rpc = getattr(stub_instance, rpc_name)
self.request_id = uuid.uuid4()
self.endpoint = endpoint
self.rendezvous = None
self.metadata_kv = None
self.endpoint_key = endpoint_key

def __str__(self):
return "RpcState(%s, %s, %s)" % (self.rpc_name, self.request_id, self.endpoint)
Expand Down Expand Up @@ -318,6 +321,14 @@ def channel_factory(
)


class EndpointKey(object):
__slots__ = ("endpoint", "node_id")

def __init__(self, endpoint, node_id):
self.endpoint = endpoint
self.node_id = node_id


class Connection(object):
__slots__ = (
"endpoint",
Expand All @@ -330,6 +341,8 @@ class Connection(object):
"lock",
"calls",
"closing",
"endpoint_key",
"node_id",
)

def __init__(self, endpoint, driver_config=None, endpoint_options=None):
Expand All @@ -341,6 +354,10 @@ def __init__(self, endpoint, driver_config=None, endpoint_options=None):
"""
global _stubs_list
self.endpoint = endpoint
self.node_id = getattr(endpoint_options, "node_id", None)
self.endpoint_key = EndpointKey(
endpoint, getattr(endpoint_options, "node_id", None)
)
self._channel = channel_factory(
self.endpoint, driver_config, endpoint_options=endpoint_options
)
Expand Down Expand Up @@ -368,7 +385,9 @@ def _prepare_call(self, stub, rpc_name, request, settings):
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
rpc_state = _RpcState(
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
)
logger.debug("%s: creating call state", rpc_state)
with self.lock:
if self.closing:
Expand Down
13 changes: 11 additions & 2 deletions ydb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, use_all_nodes=False, tracer=tracing.Tracer(None)):
self.tracer = tracer
self.lock = threading.RLock()
self.connections = collections.OrderedDict()
self.connections_by_node_id = collections.OrderedDict()
self.outdated = collections.OrderedDict()
self.subscriptions = set()
self.preferred = collections.OrderedDict()
Expand All @@ -39,6 +40,8 @@ def add(self, connection, preferred=False):
with self.lock:
if preferred:
self.preferred[connection.endpoint] = connection

self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection
subscriptions = list(self.subscriptions)
self.subscriptions.clear()
Expand Down Expand Up @@ -128,9 +131,14 @@ def get(self, preferred_endpoint=None):
with self.lock:
if (
preferred_endpoint is not None
and preferred_endpoint in self.connections
and preferred_endpoint.node_id in self.connections_by_node_id
):
return self.connections_by_node_id[preferred_endpoint.node_id]

if (
preferred_endpoint is not None
and preferred_endpoint.endpoint in self.connections
):
tracing.trace(self.tracer, {"found_preferred_endpoint": True})
return self.connections[preferred_endpoint]

for conn_lst in self.conn_lst_order:
Expand All @@ -146,6 +154,7 @@ def get(self, preferred_endpoint=None):

def remove(self, connection):
with self.lock:
self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
Expand Down
18 changes: 10 additions & 8 deletions ydb/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class EndpointInfo(object):
"ipv4_addrs",
"ipv6_addrs",
"ssl_target_name_override",
"node_id",
)

def __init__(self, endpoint_info):
Expand All @@ -30,19 +31,20 @@ def __init__(self, endpoint_info):
self.ipv4_addrs = tuple(endpoint_info.ip_v4)
self.ipv6_addrs = tuple(endpoint_info.ip_v6)
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
self.node_id = endpoint_info.node_id

def endpoints_with_options(self):
ssl_target_name_override = None
if self.ssl:
if self.ssl_target_name_override:
endpoint_options = conn_impl.EndpointOptions(
self.ssl_target_name_override
)
ssl_target_name_override = self.ssl_target_name_override
elif self.ipv6_addrs or self.ipv4_addrs:
endpoint_options = conn_impl.EndpointOptions(self.address)
else:
endpoint_options = None
else:
endpoint_options = None
ssl_target_name_override = self.address

endpoint_options = conn_impl.EndpointOptions(
ssl_target_name_override=ssl_target_name_override, node_id=self.node_id
)

if self.ipv6_addrs or self.ipv4_addrs:
for ipv6addr in self.ipv6_addrs:
yield ("ipv6:[%s]:%s" % (ipv6addr, self.port), endpoint_options)
Expand Down
2 changes: 1 addition & 1 deletion ydb/ydb_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "2.1.0"
VERSION = "2.2.0"