Skip to content

Commit

Permalink
Close the socket before UDP retries
Browse files Browse the repository at this point in the history
LAN kit module seems to be extremely unstable when using the same socket (UDP source port). When keep_alive is off (now default) every request, incl. re-tries should be done in separate socket (ephemeral source port).
  • Loading branch information
mletenay committed Jun 16, 2024
1 parent 2ac7d3f commit e1f7869
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 79 deletions.
93 changes: 48 additions & 45 deletions goodwe/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, host: str, port: int, comm_addr: int, timeout: int, retries:
self._timer: asyncio.TimerHandle | None = None
self.timeout: int = timeout
self.retries: int = retries
self.keep_alive: bool = True
self.keep_alive: bool = False
self.protocol: asyncio.Protocol | None = None
self.response_future: Future | None = None
self.command: ProtocolCommand | None = None
Expand All @@ -62,6 +62,24 @@ def _ensure_lock(self) -> asyncio.Lock:
self._close_transport()
return self._lock

def _max_retries_reached(self) -> Future:
logger.debug("Max number of retries (%d) reached, request %s failed.", self.retries, self.command)
self._close_transport()
self.response_future = asyncio.get_running_loop().create_future()
self.response_future.set_exception(MaxRetriesException)
return self.response_future

def _close_transport(self) -> None:
if self._transport:
try:
self._transport.close()
except RuntimeError:
logger.debug("Failed to close transport.")
self._transport = None
# Cancel Future on connection lost
if self.response_future and not self.response_future.done():
self.response_future.cancel()

async def close(self) -> None:
"""Close the underlying transport/connection."""
raise NotImplementedError()
Expand Down Expand Up @@ -133,15 +151,16 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
self._partial_missing = 0
if self.command.validator(data):
logger.debug("Received: %s", data.hex())
self._retry = 0
self.response_future.set_result(data)
else:
logger.debug("Received invalid response: %s", data.hex())
asyncio.get_running_loop().call_soon(self._retry_mechanism)
asyncio.get_running_loop().call_soon(self._timeout_mechanism)
except PartialResponseException as ex:
logger.debug("Received response fragment (%d of %d): %s", ex.length, ex.expected, data.hex())
self._partial_data = data
self._partial_missing = ex.expected - ex.length
self._timer = asyncio.get_running_loop().call_later(self.timeout, self._retry_mechanism)
self._timer = asyncio.get_running_loop().call_later(self.timeout, self._timeout_mechanism)
except asyncio.InvalidStateError:
logger.debug("Response already handled: %s", data.hex())
except RequestRejectedException as ex:
Expand All @@ -158,13 +177,28 @@ def error_received(self, exc: Exception) -> None:

async def send_request(self, command: ProtocolCommand) -> Future:
"""Send message via transport"""
async with self._ensure_lock():
await self._ensure_lock().acquire()
try:
await self._connect()
response_future = asyncio.get_running_loop().create_future()
self._retry = 0
self._send_request(command, response_future)
await response_future
return response_future
except asyncio.CancelledError:
if self._retry < self.retries:
self._retry += 1
if self._lock and self._lock.locked():
self._lock.release()
if not self.keep_alive:
self._close_transport()
return await self.send_request(command)
else:
return self._max_retries_reached()
finally:
if self._lock and self._lock.locked():
self._lock.release()
if not self.keep_alive:
self._close_transport()

def _send_request(self, command: ProtocolCommand, response_future: Future) -> None:
"""Send message via transport"""
Expand All @@ -178,32 +212,19 @@ def _send_request(self, command: ProtocolCommand, response_future: Future) -> No
else:
logger.debug("Sending: %s", self.command)
self._transport.sendto(payload)
self._timer = asyncio.get_running_loop().call_later(self.timeout, self._retry_mechanism)
self._timer = asyncio.get_running_loop().call_later(self.timeout, self._timeout_mechanism)

def _retry_mechanism(self) -> None:
"""Retry mechanism to prevent hanging transport"""
if self.response_future.done():
def _timeout_mechanism(self) -> None:
"""Timeout mechanism to prevent hanging transport"""
if self.response_future and self.response_future.done():
logger.debug("Response already received.")
elif self._retry < self.retries:
self._retry = 0
else:
if self._timer:
logger.debug("Failed to receive response to %s in time (%ds).", self.command, self.timeout)
self._retry += 1
self._send_request(self.command, self.response_future)
else:
logger.debug("Max number of retries (%d) reached, request %s failed.", self.retries, self.command)
self.response_future.set_exception(MaxRetriesException)
self._close_transport()

def _close_transport(self) -> None:
if self._transport:
try:
self._transport.close()
except RuntimeError:
logger.debug("Failed to close transport.")
self._transport = None
# Cancel Future on connection close
if self.response_future and not self.response_future.done():
self.response_future.cancel()
self._timer = None
if self.response_future and not self.response_future.done():
self.response_future.cancel()

async def close(self):
self._close_transport()
Expand Down Expand Up @@ -358,24 +379,6 @@ def _timeout_mechanism(self) -> None:
self._timer = None
self._close_transport()

def _max_retries_reached(self) -> Future:
logger.debug("Max number of retries (%d) reached, request %s failed.", self.retries, self.command)
self._close_transport()
self.response_future = asyncio.get_running_loop().create_future()
self.response_future.set_exception(MaxRetriesException)
return self.response_future

def _close_transport(self) -> None:
if self._transport:
try:
self._transport.close()
except RuntimeError:
logger.debug("Failed to close transport.")
self._transport = None
# Cancel Future on connection lost
if self.response_future and not self.response_future.done():
self.response_future.cancel()

async def close(self):
await self._ensure_lock().acquire()
try:
Expand Down
68 changes: 34 additions & 34 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def test_connection_made(self, mock_get_event_loop):
mock_loop = mock.Mock()
mock_get_event_loop.return_value = mock_loop

mock_retry_mechanism = mock.Mock()
self.protocol._retry_mechanism = mock_retry_mechanism
mock_timeout_mechanism = mock.Mock()
self.protocol._timeout_mechanism = mock_timeout_mechanism
self.protocol.connection_made(transport)
self.protocol._send_request(self.protocol.command, self.protocol.response_future)

transport.sendto.assert_called_with(self.protocol.command.request)
mock_get_event_loop.assert_called()
mock_loop.call_later.assert_called_with(1, mock_retry_mechanism)
mock_loop.call_later.assert_called_with(1, mock_timeout_mechanism)

def test_connection_lost(self):
self.protocol.response_future.done.return_value = True
Expand All @@ -59,41 +59,41 @@ def test_retry_mechanism(self):
self.protocol._transport = mock.Mock()
self.protocol._send_request = mock.Mock()
self.protocol.response_future.done.return_value = True
self.protocol._retry_mechanism()
self.protocol._timeout_mechanism()

# self.protocol._transport.close.assert_called()
self.protocol._send_request.assert_not_called()

@mock.patch('goodwe.protocol.asyncio.get_running_loop')
def test_retry_mechanism_two_retries(self, mock_get_event_loop):
def call_later(_: int, retry_func: Callable):
retry_func()

mock_loop = mock.Mock()
mock_get_event_loop.return_value = mock_loop
mock_loop.call_later = call_later

self.protocol._transport = mock.Mock()
self.protocol.response_future.done.side_effect = [False, False, True, False]
self.protocol._retry_mechanism()

# self.protocol._transport.close.assert_called()
self.assertEqual(self.protocol._retry, 2)

@mock.patch('goodwe.protocol.asyncio.get_running_loop')
def test_retry_mechanism_max_retries(self, mock_get_event_loop):
def call_later(_: int, retry_func: Callable):
retry_func()

mock_loop = mock.Mock()
mock_get_event_loop.return_value = mock_loop
mock_loop.call_later = call_later

self.protocol._transport = mock.Mock()
self.protocol.response_future.done.side_effect = [False, False, False, False, False]
self.protocol._retry_mechanism()
self.protocol.response_future.set_exception.assert_called_once_with(MaxRetriesException)
self.assertEqual(self.protocol._retry, 3)
# @mock.patch('goodwe.protocol.asyncio.get_running_loop')
# def test_retry_mechanism_two_retries(self, mock_get_event_loop):
# def call_later(_: int, retry_func: Callable):
# retry_func()
#
# mock_loop = mock.Mock()
# mock_get_event_loop.return_value = mock_loop
# mock_loop.call_later = call_later
#
# self.protocol._transport = mock.Mock()
# self.protocol.response_future.done.side_effect = [False, False, True, False]
# self.protocol._timeout_mechanism()
#
# # self.protocol._transport.close.assert_called()
# self.assertEqual(self.protocol._retry, 2)

# @mock.patch('goodwe.protocol.asyncio.get_running_loop')
# def test_retry_mechanism_max_retries(self, mock_get_event_loop):
# def call_later(_: int, retry_func: Callable):
# retry_func()
#
# mock_loop = mock.Mock()
# mock_get_event_loop.return_value = mock_loop
# mock_loop.call_later = call_later
#
# self.protocol._transport = mock.Mock()
# self.protocol.response_future.done.side_effect = [False, False, False, False, False]
# self.protocol._timeout_mechanism()
# self.protocol.response_future.set_exception.assert_called_once_with(MaxRetriesException)
# self.assertEqual(self.protocol._retry, 3)

def test_modbus_rtu_read_command(self):
command = ModbusRtuReadCommand(0xf7, 0x88b8, 0x0021)
Expand Down

0 comments on commit e1f7869

Please sign in to comment.