diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py index 30e340c0b8b8..5899e08650a9 100644 --- a/salt/transport/ipc.py +++ b/salt/transport/ipc.py @@ -5,10 +5,12 @@ # Import Python libs from __future__ import absolute_import, print_function, unicode_literals +import errno import logging import socket import weakref import time +import sys # Import 3rd-party libs import msgpack @@ -83,6 +85,11 @@ def _done_callback(self, future): self.set_exception(exc) +class IPCExceptionProxy(object): + def __init__(self, orig_info): + self.orig_info = orig_info + + class IPCServer(object): ''' A Tornado IPC server very similar to Tornado's TCPServer class @@ -237,31 +244,7 @@ class IPCClient(object): case it is used as the port for a tcp localhost connection. ''' - - # Create singleton map between two sockets - instance_map = weakref.WeakKeyDictionary() - - def __new__(cls, socket_path, io_loop=None): - io_loop = io_loop or tornado.ioloop.IOLoop.current() - if io_loop not in IPCClient.instance_map: - IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary() - loop_instance_map = IPCClient.instance_map[io_loop] - - # FIXME - key = six.text_type(socket_path) - - client = loop_instance_map.get(key) - if client is None: - log.debug('Initializing new IPCClient for path: %s', key) - client = object.__new__(cls) - # FIXME - client.__singleton_init__(io_loop=io_loop, socket_path=socket_path) - loop_instance_map[key] = client - else: - log.debug('Re-using IPCClient for %s', key) - return client - - def __singleton_init__(self, socket_path, io_loop=None): + def __init__(self, socket_path, io_loop=None): ''' Create a new IPC client @@ -280,10 +263,6 @@ def __singleton_init__(self, socket_path, io_loop=None): encoding = 'utf-8' self.unpacker = msgpack.Unpacker(encoding=encoding) - def __init__(self, socket_path, io_loop=None): - # Handled by singleton __new__ - pass - def connected(self): return self.stream is not None and not self.stream.closed() @@ -332,9 +311,8 @@ def _connect(self, timeout=None): if self.stream is None: with salt.utils.asynchronous.current_ioloop(self.io_loop): self.stream = IOStream( - socket.socket(sock_type, socket.SOCK_STREAM), + socket.socket(sock_type, socket.SOCK_STREAM) ) - try: log.trace('IPCClient: Connecting to socket: %s', self.socket_path) yield self.stream.connect(sock_addr) @@ -354,7 +332,16 @@ def _connect(self, timeout=None): yield tornado.gen.sleep(1) def __del__(self): - self.close() + try: + self.close() + except socket.error as exc: + if exc.errno != errno.EBADF: + # If its not a bad file descriptor error, raise + raise + except TypeError: + # This is raised when Python's GC has collected objects which + # would be needed when calling self.close() + pass def close(self): ''' @@ -368,16 +355,6 @@ def close(self): if self.stream is not None and not self.stream.closed(): self.stream.close() - # Remove the entry from the instance map so - # that a closed entry may not be reused. - # This forces this operation even if the reference - # count of the entry has not yet gone to zero. - if self.io_loop in IPCClient.instance_map: - loop_instance_map = IPCClient.instance_map[self.io_loop] - key = six.text_type(self.socket_path) - if key in loop_instance_map: - del loop_instance_map[key] - class IPCMessageClient(IPCClient): ''' @@ -591,12 +568,13 @@ class IPCMessageSubscriberService(IPCClient): To use this refer to IPCMessageSubscriber documentation. ''' - def __singleton_init__(self, socket_path, io_loop=None): - super(IPCMessageSubscriberService, self).__singleton_init__( + def __init__(self, socket_path, io_loop=None): + super(IPCMessageSubscriberService, self).__init__( socket_path, io_loop=io_loop) self.saved_data = [] self._read_in_progress = Lock() self.handlers = weakref.WeakSet() + self.read_stream_future = None def _subscribe(self, handler): self.handlers.add(handler) @@ -624,16 +602,16 @@ def _read(self, timeout, callback=None): if timeout is None: timeout = 5 - read_stream_future = None + self.read_stream_future = None while self._has_subscribers(): - if read_stream_future is None: - read_stream_future = self.stream.read_bytes(4096, partial=True) + if self.read_stream_future is None: + self.read_stream_future = self.stream.read_bytes(4096, partial=True) try: wire_bytes = yield FutureWithTimeout(self.io_loop, - read_stream_future, + self.read_stream_future, timeout) - read_stream_future = None + self.read_stream_future = None self.unpacker.feed(wire_bytes) msgs = [msg['body'] for msg in self.unpacker] @@ -648,6 +626,7 @@ def _read(self, timeout, callback=None): break except Exception as exc: log.error('Exception occurred in Subscriber while handling stream: %s', exc) + exc = IPCExceptionProxy(sys.exc_info()) self._feed_subscribers([exc]) break @@ -672,7 +651,7 @@ def read(self, handler, timeout=None): except Exception as exc: log.error('Exception occurred while Subscriber connecting: %s', exc) yield tornado.gen.sleep(1) - self._read(timeout) + yield self._read(timeout) def close(self): ''' @@ -680,8 +659,11 @@ def close(self): Sockets and filehandles should be closed explicitly, to prevent leaks. ''' - if not self._closing: - super(IPCMessageSubscriberService, self).close() + super(IPCMessageSubscriberService, self).close() + if self.read_stream_future is not None and self.read_stream_future.done(): + exc = self.read_stream_future.exception() + if exc and not isinstance(exc, tornado.iostream.StreamClosedError): + log.error("Read future returned exception %r", exc) def __del__(self): if IPCMessageSubscriberService in globals(): @@ -755,8 +737,8 @@ def read_async(self, callback, timeout=None): raise tornado.gen.Return(None) if data is None: break - elif isinstance(data, Exception): - raise data + elif isinstance(data, IPCExceptionProxy): + six.reraise(*data.orig_info) elif callback: self.service.io_loop.spawn_callback(callback, data) else: @@ -775,6 +757,7 @@ def read_sync(self, timeout=None): def close(self): self.service.unsubscribe(self) + self.service.close() def __del__(self): self.close() diff --git a/tests/unit/transport/test_ipc.py b/tests/unit/transport/test_ipc.py index 0ca8ebff8640..939c4958318d 100644 --- a/tests/unit/transport/test_ipc.py +++ b/tests/unit/transport/test_ipc.py @@ -86,13 +86,14 @@ class IPCMessageClient(BaseIPCReqCase): ''' def _get_channel(self): - channel = salt.transport.ipc.IPCMessageClient( - socket_path=self.socket_path, - io_loop=self.io_loop, - ) - channel.connect(callback=self.stop) - self.wait() - return channel + if not hasattr(self, 'channel') or self.channel is None: + self.channel = salt.transport.ipc.IPCMessageClient( + socket_path=self.socket_path, + io_loop=self.io_loop, + ) + self.channel.connect(callback=self.stop) + self.wait() + return self.channel def setUp(self): super(IPCMessageClient, self).setUp() @@ -106,6 +107,8 @@ def tearDown(self): if exc.errno != errno.EBADF: # If its not a bad file descriptor error, raise raise + finally: + self.channel = None def test_basic_send(self): msg = {'foo': 'bar', 'stop': True}