Skip to content

Commit

Permalink
feat: reconnect write stream if disconnected by server
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Jan 17, 2025
1 parent 72d5611 commit 4300511
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion google/cloud/bigquery_storage_v1/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def __init__(
self._inital_request_template = initial_request_template
self._metadata = metadata

# if self._hibernating == True, this means the connection was closed by
# the server. The stream will try to reconnect if the message queue is
# non-empty, or when the customer sends another request.
self._hibernating = False

# Only one call to `send()` should attempt to open the RPC.
self._opening = threading.Lock()

Expand Down Expand Up @@ -311,6 +316,7 @@ def _shutdown(self, reason: Optional[Exception] = None):
The reason to close the stream. If ``None``, this is
considered an "intentional" shutdown.
"""
# breakpoint()
with self._closing:
if self._closed:
return
Expand Down Expand Up @@ -344,6 +350,43 @@ def _shutdown(self, reason: Optional[Exception] = None):
for callback in self._close_callbacks:
callback(self, reason)

def _hibernate(self, reason: Optional[Exception] = None):
# If the connection is shut down by the server for retriable reasons,
# such as idle connection, we shut down the grpc connection and delete
# the consumer. However, we preserve futures queue, and if the queue is
# not empty, or if there is a new message to be sent, it tries to create
# a new gRPC connection and corresponding consumer to resume the process.
# breakpoint()

# Stop consumer
if self.is_active:
_LOGGER.debug("Stopping consumer.")
self._consumer.stop()
self._consumer = None

# Close RPC connection
if self._rpc is not None:
self._rpc.close()
# self._closed = True
_LOGGER.debug("Finished stopping manager.")

# Register error on the future corresponding to this error message
future = self._futures_queue.get_nowait()
future.set_exception(reason)

# Mark self._hibernating as True for future reopening
self._hibernating = True

return

def _shutdown_or_hibernate(self, reason: Optional[Exception] = None):
# Hibernate if a retriable error is received, otherwise, shut down
# completely.
if isinstance(reason, exceptions.Aborted):
self._hibernate(reason)
else:
self._shutdown(reason)

def _on_rpc_done(self, future):
"""Triggered whenever the underlying RPC terminates without recovery.
Expand All @@ -358,7 +401,7 @@ def _on_rpc_done(self, future):
_LOGGER.info("RPC termination has signaled streaming pull manager shutdown.")
error = _wrap_as_exception(future)
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME, target=self._shutdown, kwargs={"reason": error}
name=_RPC_ERROR_THREAD_NAME, target=self._shutdown_or_hibernate, kwargs={"reason": error}
)
thread.daemon = True
thread.start()
Expand Down

0 comments on commit 4300511

Please sign in to comment.