Skip to content

Commit

Permalink
add option to resend queued message
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Jan 17, 2025
1 parent 4300511 commit f15033b
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions google/cloud/bigquery_storage_v1/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
client: big_query_write.BigQueryWriteClient,
initial_request_template: gapic_types.AppendRowsRequest,
metadata: Sequence[Tuple[str, str]] = (),
resend: bool = True, # resend failed message
):
"""Construct a stream manager.
Expand All @@ -87,9 +88,10 @@ def __init__(
self._closing = threading.Lock()
self._closed = False
self._close_callbacks = []
self._futures_queue = queue.Queue()
self._queue = queue.Queue()
self._inital_request_template = initial_request_template
self._metadata = metadata
self._resend = resend

# if self._hibernating == True, this means the connection was closed by
# the server. The stream will try to reconnect if the message queue is
Expand Down Expand Up @@ -175,7 +177,7 @@ def _open(
request.trace_id = f"python-writer:{package_version.__version__}"

inital_response_future = AppendRowsFuture(self)
self._futures_queue.put(inital_response_future)
self._queue.put((request, inital_response_future))

self._rpc = bidi.BidiRpc(
self._client.append_rows,
Expand Down Expand Up @@ -271,7 +273,7 @@ def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
# future to the queue so that when the response comes, the callback can
# pull it off and notify completion.
future = AppendRowsFuture(self)
self._futures_queue.put(future)
self._queue.put((request, future))
self._rpc.send(request)
return future

Expand All @@ -287,7 +289,7 @@ def _on_response(self, response: gapic_types.AppendRowsResponse):

# Since we have 1 response per request, if we get here from a response
# callback, the queue should never be empty.
future: AppendRowsFuture = self._futures_queue.get_nowait()
future: AppendRowsFuture = self._queue.get_nowait()[1]
if response.error.code:
exc = exceptions.from_grpc_status(
response.error.code, response.error.message, response=response
Expand Down Expand Up @@ -316,7 +318,6 @@ def _shutdown(self, reason: Optional[Exception] = None):
The reason to close the stream. If ``None``, this is
considered an "intentional" shutdown.
"""
# breakpoint()
with self._closing:
if self._closed:
return
Expand All @@ -334,11 +335,11 @@ def _shutdown(self, reason: Optional[Exception] = None):

# We know that no new items will be added to the queue because
# we've marked the stream as closed.
while not self._futures_queue.empty():
while not self._queue.empty():
# Mark each future as failed. Since the consumer thread has
# stopped (or at least is attempting to stop), we won't get
# response callbacks to populate the remaining futures.
future = self._futures_queue.get_nowait()
future = self._queue.get_nowait()[1]
if reason is None:
exc = bqstorage_exceptions.StreamClosedError(
"Stream closed before receiving a response."
Expand Down Expand Up @@ -371,19 +372,36 @@ def _hibernate(self, reason: Optional[Exception] = None):
_LOGGER.debug("Finished stopping manager.")

# Register error on the future corresponding to this error message
future = self._futures_queue.get_nowait()
future.set_exception(reason)
if not self._resend:
future = self._queue.get_nowait()[1]
future.set_exception(reason)

# Mark self._hibernating as True for future reopening
self._hibernating = True

return

def _shutdown_or_hibernate(self, reason: Optional[Exception] = None):
def _retry(self):
    """Resend every request that is still waiting for a response.

    The pending ``(request, future)`` pairs are detached from
    ``self._queue`` *before* anything is resent. ``send()`` registers each
    resent request (paired with a new future) on ``self._queue`` itself, so
    draining the live queue while calling ``send()`` would never reach an
    empty queue.

    Clears ``self._hibernating`` and leaves ``self._queue`` holding one
    ``(request, new_future)`` pair per resent request, in the original
    order.
    """
    # Swap in a fresh queue first so that send() appends to the new queue
    # rather than to the one being drained.
    pending = self._queue
    self._queue = queue.Queue()
    self._hibernating = False

    while not pending.empty():
        # NOTE(review): the future originally handed to the caller is
        # dropped here and never resolved; only the future created by
        # send() will see the response. Confirm this is intended.
        request, _stale_future = pending.get_nowait()
        self.send(request)

    return

def _shutdown_or_hibernate_or_retry(self, reason: Optional[Exception] = None):
    """Pick the recovery path after the RPC has terminated.

    Args:
        reason:
            The error the RPC ended with, if any. An
            ``exceptions.Aborted`` is treated as retriable: the stream
            hibernates and, when ``self._resend`` is set, the queued
            requests are resent. Anything else shuts the stream down
            completely.
    """
    # Non-retriable (or absent) reason: shut down and stop here.
    if not isinstance(reason, exceptions.Aborted):
        self._shutdown(reason)
        return

    # Retriable error: hibernate, then optionally resend what is queued.
    self._hibernate(reason)
    if self._resend:
        self._retry()

Expand All @@ -401,7 +419,9 @@ def _on_rpc_done(self, future):
_LOGGER.info("RPC termination has signaled streaming pull manager shutdown.")
error = _wrap_as_exception(future)
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME, target=self._shutdown_or_hibernate, kwargs={"reason": error}
name=_RPC_ERROR_THREAD_NAME,
target=self._shutdown_or_hibernate_or_retry,
kwargs={"reason": error},
)
thread.daemon = True
thread.start()
Expand Down

0 comments on commit f15033b

Please sign in to comment.