diff --git a/snapcast/control/__init__.py b/snapcast/control/__init__.py index 8aa5181..d838848 100644 --- a/snapcast/control/__init__.py +++ b/snapcast/control/__init__.py @@ -5,8 +5,8 @@ @asyncio.coroutine -def create_server(loop, host, port=CONTROL_PORT): +def create_server(loop, host, port=CONTROL_PORT, reconnect=False): """Server factory.""" - server = Snapserver(loop, host, port) + server = Snapserver(loop, host, port, reconnect) yield from server.start() return server diff --git a/snapcast/control/server.py b/snapcast/control/server.py index 0d230ad..7a476cc 100755 --- a/snapcast/control/server.py +++ b/snapcast/control/server.py @@ -37,6 +37,8 @@ STREAM_ONUPDATE = 'Stream.OnUpdate' +SERVER_RECONNECT_DELAY = 5 + _EVENTS = [SERVER_ONUPDATE, CLIENT_ONVOLUMECHANGED, CLIENT_ONLATENCYCHANGED, CLIENT_ONNAMECHANGED, CLIENT_ONCONNECT, CLIENT_ONDISCONNECT, GROUP_ONMUTE, GROUP_ONSTREAMCHANGED, STREAM_ONUPDATE] @@ -51,10 +53,11 @@ class Snapserver(object): """Represents a snapserver.""" # pylint: disable=too-many-instance-attributes - def __init__(self, loop, host, port=CONTROL_PORT): + def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False): """Initialize.""" self._loop = loop self._port = port + self._reconnect = reconnect self._clients = {} self._streams = {} self._groups = {} @@ -79,13 +82,30 @@ def __init__(self, loop, host, port=CONTROL_PORT): def start(self): """Initiate server connection.""" - _, self._protocol = yield from self._loop.create_connection( - lambda: SnapcastProtocol(self._callbacks), self._host, self._port) + yield from self._do_connect() _LOGGER.info('connected to snapserver on %s:%s', self._host, self._port) status = yield from self.status() self.synchronize(status) self._on_server_connect() + @asyncio.coroutine + def _do_connect(self): + """Perform the connection to the server.""" + _, self._protocol = yield from self._loop.create_connection( + lambda: SnapcastProtocol(self._callbacks), self._host, self._port) + + def _reconnect_cb(self): + """Callback to reconnect to the server.""" + @asyncio.coroutine + def try_reconnect(): + """Actual coroutine ro try to reconnect or reschedule.""" + try: + yield from self._do_connect() + except IOError: + self._loop.call_later(SERVER_RECONNECT_DELAY, + self._reconnect_cb) + asyncio.ensure_future(try_reconnect()) + @asyncio.coroutine def _transact(self, method, params=None): """Wrap requests.""" @@ -202,8 +222,11 @@ def _on_server_connect(self): def _on_server_disconnect(self, exception): """Handle server disconnection.""" + self._protocol = None if self._on_disconnect_callback_func and callable(self._on_disconnect_callback_func): self._on_disconnect_callback_func(exception) + if self._reconnect: + self._reconnect_cb() def _on_server_update(self, data): """Handle server update."""