diff --git a/google/cloud/bigquery_storage_v1/writer.py b/google/cloud/bigquery_storage_v1/writer.py index a8c447bb..3302ddf5 100644 --- a/google/cloud/bigquery_storage_v1/writer.py +++ b/google/cloud/bigquery_storage_v1/writer.py @@ -91,6 +91,11 @@ def __init__( self._inital_request_template = initial_request_template self._metadata = metadata + # if self._hibernating == True, this means the connection was closed by + # the server. The stream will try to reconnect if the message queue is + # non-empty, or when the customer sends another request. + self._hibernating = False + # Only one call to `send()` should attempt to open the RPC. self._opening = threading.Lock() @@ -311,6 +316,7 @@ def _shutdown(self, reason: Optional[Exception] = None): The reason to close the stream. If ``None``, this is considered an "intentional" shutdown. """ + # breakpoint() with self._closing: if self._closed: return @@ -344,6 +350,43 @@ def _shutdown(self, reason: Optional[Exception] = None): for callback in self._close_callbacks: callback(self, reason) + def _hibernate(self, reason: Optional[Exception] = None): + # If the connection is shut down by the server for retriable reasons, + # such as idle connection, we shut down the grpc connection and delete + # the consumer. However, we preserve futures queue, and if the queue is + # not empty, or if there is a new message to be sent, it tries to create + # a new gRPC connection and corresponding consumer to resume the process. + # breakpoint() + + # Stop consumer + if self.is_active: + _LOGGER.debug("Stopping consumer.") + self._consumer.stop() + self._consumer = None + + # Close RPC connection + if self._rpc is not None: + self._rpc.close() + # self._closed = True + _LOGGER.debug("Finished stopping manager.") + + # Register error on the future corresponding to this error message + future = self._futures_queue.get_nowait() + future.set_exception(reason) + + # Mark self._hibernating as True for future reopening + self._hibernating = True + + return + + def _shutdown_or_hibernate(self, reason: Optional[Exception] = None): + # Hibernate if a retriable error is received, otherwise, shut down + # completely. + if isinstance(reason, exceptions.Aborted): + self._hibernate(reason) + else: + self._shutdown(reason) + def _on_rpc_done(self, future): """Triggered whenever the underlying RPC terminates without recovery. @@ -358,7 +401,7 @@ def _on_rpc_done(self, future): _LOGGER.info("RPC termination has signaled streaming pull manager shutdown.") error = _wrap_as_exception(future) thread = threading.Thread( - name=_RPC_ERROR_THREAD_NAME, target=self._shutdown, kwargs={"reason": error} + name=_RPC_ERROR_THREAD_NAME, target=self._shutdown_or_hibernate, kwargs={"reason": error} ) thread.daemon = True thread.start()