diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index c7af92de..c960cb5a 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -155,7 +155,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> return await self._loop.run_in_executor(None, connect_partial) except Exception: # we attempt a force refresh, then throw the error - instance.force_refresh() + await instance.force_refresh() raise def __enter__(self) -> "Connector": diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index dec485dd..62b686d5 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -187,7 +187,7 @@ async def _refresh_operation(self, delay: int) -> RefreshResult: return refresh_result - def force_refresh(self) -> None: + async def force_refresh(self) -> None: """ Schedules a new refresh operation immediately to be used for future connection attempts. @@ -196,8 +196,9 @@ def force_refresh(self) -> None: if not self._refresh_in_progress.is_set(): self._next.cancel() self._next = self._schedule_refresh(0) - # block all sequential connection attempts on the next refresh result - self._current = self._next + # block all sequential connection attempts on the next refresh result if current is invalid + if not await _is_valid(self._current): + self._current = self._next async def connection_info(self) -> Tuple[str, ssl.SSLContext]: """ diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 4cceb49b..58b0bec3 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -195,7 +195,7 @@ async def test_force_refresh_cancels_pending_refresh() -> None: # shouldn't be set pending_refresh = instance._next assert instance._refresh_in_progress.is_set() is False - instance.force_refresh() + await instance.force_refresh() # pending_refresh has to be awaited for it to raised as cancelled with pytest.raises(asyncio.CancelledError): assert await pending_refresh