From de60905b8b264254643e6bbc30bb8dab944ddbb7 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Mon, 23 Dec 2024 17:41:03 +0300 Subject: [PATCH] Fix auth credentials --- ydb/_topic_reader/topic_reader_asyncio.py | 5 ++++- ydb/_topic_writer/topic_writer_asyncio.py | 5 ++++- ydb/aio/credentials.py | 15 ++++++++++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 6833492d..351efb9a 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -516,7 +516,10 @@ async def _read_messages_loop(self): async def _update_token_loop(self): while True: await asyncio.sleep(self._update_token_interval) - await self._update_token(token=self._get_token_function()) + token = self._get_token_function() + if asyncio.iscoroutine(token): + token = await token + await self._update_token(token=token) async def _update_token(self, token: str): await self._update_token_event.wait() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index d759072c..869808f7 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -686,7 +686,10 @@ def write(self, messages: List[InternalMessage]): async def _update_token_loop(self): while True: await asyncio.sleep(self._update_token_interval) - await self._update_token(token=self._get_token_function()) + token = self._get_token_function() + if asyncio.iscoroutine(token): + token = await token + await self._update_token(token=token) async def _update_token(self, token: str): await self._update_token_event.wait() diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 08db1fd0..5e581c01 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -1,11 +1,14 @@ -import time - import abc import asyncio import logging -from ydb import issues, credentials +import time + +from ydb import credentials +from ydb import issues logger = logging.getLogger(__name__) +YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" + class _OneToManyValue(object): @@ -64,6 +67,12 @@ def __init__(self): async def _make_token_request(self): pass + async def get_auth_token(self) -> str: + for header, token in await self.auth_metadata(): + if header == YDB_AUTH_TICKET_HEADER: + return token + return "" + async def _refresh(self): current_time = time.time() self._log_refresh_start(current_time)