Skip to content

Commit

Permalink
Merge pull request #19 from ydb-platform/support-node-id
Browse files Browse the repository at this point in the history
allow to refer endpoints by node id
  • Loading branch information
gridnevvvit authored Mar 23, 2022
2 parents 862505a + 8cb1d72 commit 792d8d7
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 2.2.0 ##

* allow to refer endpoints by node id

## 2.1.0 ##

* add compression support to ydb sdk
Expand Down
2 changes: 1 addition & 1 deletion ydb/_session_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def initialize_session(rpc_state, response_pb, session_state, session):
issues._process_response(response_pb.operation)
message = _apis.ydb_table.CreateSessionResult()
response_pb.operation.result.Unpack(message)
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint)
session_state.set_id(message.session_id).attach_endpoint(rpc_state.endpoint_key)
return session


Expand Down
15 changes: 12 additions & 3 deletions ydb/aio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
YDB_DATABASE_HEADER,
YDB_TRACE_ID_HEADER,
YDB_REQUEST_TYPE_HEADER,
EndpointKey,
)
from ydb.driver import DriverConfig
from ydb.settings import BaseRequestSettings
Expand Down Expand Up @@ -71,8 +72,8 @@ class _RpcState(RpcState):
"_trailing_metadata",
)

def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str):
super().__init__(stub_instance, rpc_name, endpoint)
def __init__(self, stub_instance: Any, rpc_name: str, endpoint: str, endpoint_key):
super().__init__(stub_instance, rpc_name, endpoint, endpoint_key)

async def __call__(self, *args, **kwargs):
resp = self.rpc(*args, **kwargs)
Expand Down Expand Up @@ -105,6 +106,8 @@ class Connection:
"lock",
"calls",
"closing",
"endpoint_key",
"node_id",
)

def __init__(
Expand All @@ -115,6 +118,10 @@ def __init__(
):
global _stubs_list
self.endpoint = endpoint
self.endpoint_key = EndpointKey(
self.endpoint, getattr(endpoint_options, "node_id", None)
)
self.node_id = getattr(endpoint_options, "node_id", None)
self._channel = channel_factory(
self.endpoint, driver_config, grpc.aio, endpoint_options=endpoint_options
)
Expand All @@ -141,7 +148,9 @@ async def _prepare_call(
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
rpc_state = _RpcState(
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
)
logger.debug("%s: creating call state", rpc_state)

if self.closing:
Expand Down
14 changes: 13 additions & 1 deletion ydb/aio/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@ async def get(self, preferred_endpoint=None, fast_fail=False, wait_timeout=10):
else:
await asyncio.wait_for(self._event.wait(), timeout=wait_timeout)

if preferred_endpoint is not None and preferred_endpoint in self.connections:
if (
preferred_endpoint is not None
and preferred_endpoint.node_id in self.connections_by_node_id
):
return self.connections_by_node_id[preferred_endpoint.node_id]

if (
preferred_endpoint is not None
and preferred_endpoint.endpoint in self.connections
):
return self.connections[preferred_endpoint]

for conn_lst in self.conn_lst_order:
Expand All @@ -52,6 +61,8 @@ def add(self, connection, preferred=False):

if preferred:
self.preferred[connection.endpoint] = connection

self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection

self._event.set()
Expand All @@ -66,6 +77,7 @@ def complete_discovery(self, error):
self._fast_fail_event.set()

def remove(self, connection):
self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
Expand Down
27 changes: 23 additions & 4 deletions ydb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,11 @@ def _get_request_timeout(settings):


class EndpointOptions(object):
__slots__ = ("ssl_target_name_override",)
__slots__ = ("ssl_target_name_override", "node_id")

def __init__(self, ssl_target_name_override=None):
def __init__(self, ssl_target_name_override=None, node_id=None):
self.ssl_target_name_override = ssl_target_name_override
self.node_id = node_id


def _construct_channel_options(driver_config, endpoint_options=None):
Expand Down Expand Up @@ -223,16 +224,18 @@ class _RpcState(object):
"endpoint",
"rendezvous",
"metadata_kv",
"endpoint_key",
)

def __init__(self, stub_instance, rpc_name, endpoint):
def __init__(self, stub_instance, rpc_name, endpoint, endpoint_key):
"""Stores all RPC related data"""
self.rpc_name = rpc_name
self.rpc = getattr(stub_instance, rpc_name)
self.request_id = uuid.uuid4()
self.endpoint = endpoint
self.rendezvous = None
self.metadata_kv = None
self.endpoint_key = endpoint_key

def __str__(self):
return "RpcState(%s, %s, %s)" % (self.rpc_name, self.request_id, self.endpoint)
Expand Down Expand Up @@ -318,6 +321,14 @@ def channel_factory(
)


class EndpointKey(object):
__slots__ = ("endpoint", "node_id")

def __init__(self, endpoint, node_id):
self.endpoint = endpoint
self.node_id = node_id


class Connection(object):
__slots__ = (
"endpoint",
Expand All @@ -330,6 +341,8 @@ class Connection(object):
"lock",
"calls",
"closing",
"endpoint_key",
"node_id",
)

def __init__(self, endpoint, driver_config=None, endpoint_options=None):
Expand All @@ -341,6 +354,10 @@ def __init__(self, endpoint, driver_config=None, endpoint_options=None):
"""
global _stubs_list
self.endpoint = endpoint
self.node_id = getattr(endpoint_options, "node_id", None)
self.endpoint_key = EndpointKey(
endpoint, getattr(endpoint_options, "node_id", None)
)
self._channel = channel_factory(
self.endpoint, driver_config, endpoint_options=endpoint_options
)
Expand Down Expand Up @@ -368,7 +385,9 @@ def _prepare_call(self, stub, rpc_name, request, settings):
)
_set_server_timeouts(request, settings, timeout)
self._prepare_stub_instance(stub)
rpc_state = _RpcState(self._stub_instances[stub], rpc_name, self.endpoint)
rpc_state = _RpcState(
self._stub_instances[stub], rpc_name, self.endpoint, self.endpoint_key
)
logger.debug("%s: creating call state", rpc_state)
with self.lock:
if self.closing:
Expand Down
13 changes: 11 additions & 2 deletions ydb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, use_all_nodes=False, tracer=tracing.Tracer(None)):
self.tracer = tracer
self.lock = threading.RLock()
self.connections = collections.OrderedDict()
self.connections_by_node_id = collections.OrderedDict()
self.outdated = collections.OrderedDict()
self.subscriptions = set()
self.preferred = collections.OrderedDict()
Expand All @@ -39,6 +40,8 @@ def add(self, connection, preferred=False):
with self.lock:
if preferred:
self.preferred[connection.endpoint] = connection

self.connections_by_node_id[connection.node_id] = connection
self.connections[connection.endpoint] = connection
subscriptions = list(self.subscriptions)
self.subscriptions.clear()
Expand Down Expand Up @@ -128,9 +131,14 @@ def get(self, preferred_endpoint=None):
with self.lock:
if (
preferred_endpoint is not None
and preferred_endpoint in self.connections
and preferred_endpoint.node_id in self.connections_by_node_id
):
return self.connections_by_node_id[preferred_endpoint.node_id]

if (
preferred_endpoint is not None
and preferred_endpoint.endpoint in self.connections
):
tracing.trace(self.tracer, {"found_preferred_endpoint": True})
return self.connections[preferred_endpoint]

for conn_lst in self.conn_lst_order:
Expand All @@ -146,6 +154,7 @@ def get(self, preferred_endpoint=None):

def remove(self, connection):
with self.lock:
self.connections_by_node_id.pop(connection.node_id, None)
self.preferred.pop(connection.endpoint, None)
self.connections.pop(connection.endpoint, None)
self.outdated.pop(connection.endpoint, None)
Expand Down
18 changes: 10 additions & 8 deletions ydb/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class EndpointInfo(object):
"ipv4_addrs",
"ipv6_addrs",
"ssl_target_name_override",
"node_id",
)

def __init__(self, endpoint_info):
Expand All @@ -30,19 +31,20 @@ def __init__(self, endpoint_info):
self.ipv4_addrs = tuple(endpoint_info.ip_v4)
self.ipv6_addrs = tuple(endpoint_info.ip_v6)
self.ssl_target_name_override = endpoint_info.ssl_target_name_override
self.node_id = endpoint_info.node_id

def endpoints_with_options(self):
ssl_target_name_override = None
if self.ssl:
if self.ssl_target_name_override:
endpoint_options = conn_impl.EndpointOptions(
self.ssl_target_name_override
)
ssl_target_name_override = self.ssl_target_name_override
elif self.ipv6_addrs or self.ipv4_addrs:
endpoint_options = conn_impl.EndpointOptions(self.address)
else:
endpoint_options = None
else:
endpoint_options = None
ssl_target_name_override = self.address

endpoint_options = conn_impl.EndpointOptions(
ssl_target_name_override=ssl_target_name_override, node_id=self.node_id
)

if self.ipv6_addrs or self.ipv4_addrs:
for ipv6addr in self.ipv6_addrs:
yield ("ipv6:[%s]:%s" % (ipv6addr, self.port), endpoint_options)
Expand Down
2 changes: 1 addition & 1 deletion ydb/ydb_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "2.1.0"
VERSION = "2.2.0"

0 comments on commit 792d8d7

Please sign in to comment.