diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index e880a6c23b..c046d459bc 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -164,7 +164,10 @@ def _try_reply(self, fl_ctx: FLContext): if rc == ReturnCode.OK: # reply sent successfully! self.reply_time = time.time() - ReliableMessage.info(fl_ctx, f"sent reply successfully in {time_spent} secs") + ReliableMessage.debug(fl_ctx, f"sent reply successfully in {time_spent} secs") + + # release the receiver kept by the ReliableMessage! + ReliableMessage.release_request_receiver(self, fl_ctx) else: ReliableMessage.error( fl_ctx, f"failed to send reply in {time_spent} secs: {rc=}; will wait for requester to query" @@ -279,7 +282,7 @@ def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext): elif op == OP_QUERY: receiver = cls._req_receivers.get(tx_id) if not receiver: - cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received!") + cls.error(fl_ctx, f"received query but the request ({rm_topic=}) is not received or already done!") return _status_reply(STATUS_NOT_RECEIVED) # meaning the request wasn't received else: return receiver.process(request, fl_ctx) @@ -301,6 +304,22 @@ def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext): receiver.process(request) return make_reply(ReturnCode.OK) + @classmethod + def release_request_receiver(cls, receiver: _RequestReceiver, fl_ctx: FLContext): + """Release the specified _RequestReceiver from the receiver table. + This is to be called after the received request is finished. + + Args: + receiver: the _RequestReceiver to be released + fl_ctx: the FL Context + + Returns: None + + """ + with cls._tx_lock: + cls._req_receivers.pop(receiver.tx_id, None) + cls.debug(fl_ctx, f"released request receiver of TX {receiver.tx_id}") + @classmethod def enable(cls, fl_ctx: FLContext): """Enable ReliableMessage. This method can be called multiple times, but only the 1st call has effect. @@ -345,7 +364,7 @@ def _monitor_req_receivers(cls): now = time.time() for tx_id, receiver in cls._req_receivers.items(): assert isinstance(receiver, _RequestReceiver) - if receiver.rcv_time and now - receiver.rcv_time > 4 * receiver.tx_timeout: + if receiver.rcv_time and now - receiver.rcv_time > receiver.tx_timeout: cls._logger.info(f"detected expired request receiver {tx_id}") expired_receivers.append(tx_id)