diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py index 69032ff352c6..c23af2956647 100644 --- a/salt/transport/ipc.py +++ b/salt/transport/ipc.py @@ -10,7 +10,7 @@ import socket import weakref import time -import threading +import sys # Import 3rd-party libs import msgpack @@ -85,6 +85,11 @@ def _done_callback(self, future): self.set_exception(exc) +class IPCExceptionProxy(object): + def __init__(self, orig_info): + self.orig_info = orig_info + + class IPCServer(object): ''' A Tornado IPC server very similar to Tornado's TCPServer class @@ -244,36 +249,7 @@ class IPCClient(object): case it is used as the port for a tcp localhost connection. ''' - - # Create singleton map between two sockets - instance_map = weakref.WeakKeyDictionary() - - def __new__(cls, socket_path, io_loop=None): - io_loop = io_loop or tornado.ioloop.IOLoop.current() - if io_loop not in IPCClient.instance_map: - IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary() - loop_instance_map = IPCClient.instance_map[io_loop] - - # FIXME - key = six.text_type(socket_path) - - client = loop_instance_map.get(key) - if client is None: - log.debug('Initializing new IPCClient for path: %s', key) - client = object.__new__(cls) - # FIXME - client.__singleton_init__(io_loop=io_loop, socket_path=socket_path) - client._instance_key = key - loop_instance_map[key] = client - client._refcount = 1 - client._refcount_lock = threading.RLock() - else: - log.debug('Re-using IPCClient for %s', key) - with client._refcount_lock: - client._refcount += 1 - return client - - def __singleton_init__(self, socket_path, io_loop=None): + def __init__(self, socket_path, io_loop=None): ''' Create a new IPC client @@ -292,10 +268,6 @@ def __singleton_init__(self, socket_path, io_loop=None): encoding = 'utf-8' self.unpacker = msgpack.Unpacker(encoding=encoding) - def __init__(self, socket_path, io_loop=None): - # Handled by singleton __new__ - pass - def connected(self): return self.stream is not None and not self.stream.closed() @@ -367,16 +339,11 @@ def _connect(self, timeout=None): def __del__(self): try: - with self._refcount_lock: - # Make sure we actually close no matter if something - # went wrong with our ref counting - self._refcount = 1 - try: - self.close() - except socket.error as exc: - if exc.errno != errno.EBADF: - # If its not a bad file descriptor error, raise - raise + self.close() + except socket.error as exc: + if exc.errno != errno.EBADF: + # If its not a bad file descriptor error, raise + raise except TypeError: # This is raised when Python's GC has collected objects which # would be needed when calling self.close() @@ -391,16 +358,6 @@ def close(self): if self._closing: return - if self._refcount > 1: - # Decrease refcount - with self._refcount_lock: - self._refcount -= 1 - log.debug( - 'This is not the last %s instance. Not closing yet.', - self.__class__.__name__ - ) - return - self._closing = True log.debug('Closing %s instance', self.__class__.__name__) @@ -408,17 +365,6 @@ def close(self): if self.stream is not None and not self.stream.closed(): self.stream.close() - # Remove the entry from the instance map so - # that a closed entry may not be reused. - # This forces this operation even if the reference - # count of the entry has not yet gone to zero. - if self.io_loop in self.__class__.instance_map: - loop_instance_map = self.__class__.instance_map[self.io_loop] - if self._instance_key in loop_instance_map: - del loop_instance_map[self._instance_key] - if not loop_instance_map: - del self.__class__.instance_map[self.io_loop] - class IPCMessageClient(IPCClient): ''' @@ -637,12 +583,13 @@ class IPCMessageSubscriberService(IPCClient): To use this refer to IPCMessageSubscriber documentation. ''' - def __singleton_init__(self, socket_path, io_loop=None): - super(IPCMessageSubscriberService, self).__singleton_init__( + def __init__(self, socket_path, io_loop=None): + super(IPCMessageSubscriberService, self).__init__( socket_path, io_loop=io_loop) self.saved_data = [] self._read_in_progress = Lock() self.handlers = weakref.WeakSet() + self.read_stream_future = None def _subscribe(self, handler): self.handlers.add(handler) @@ -670,16 +617,16 @@ def _read(self, timeout, callback=None): if timeout is None: timeout = 5 - read_stream_future = None + self.read_stream_future = None while self._has_subscribers(): - if read_stream_future is None: - read_stream_future = self.stream.read_bytes(4096, partial=True) + if self.read_stream_future is None: + self.read_stream_future = self.stream.read_bytes(4096, partial=True) try: wire_bytes = yield FutureWithTimeout(self.io_loop, - read_stream_future, + self.read_stream_future, timeout) - read_stream_future = None + self.read_stream_future = None self.unpacker.feed(wire_bytes) msgs = [msg['body'] for msg in self.unpacker] @@ -694,6 +641,7 @@ def _read(self, timeout, callback=None): break except Exception as exc: log.error('Exception occurred in Subscriber while handling stream: %s', exc) + exc = IPCExceptionProxy(sys.exc_info()) self._feed_subscribers([exc]) break @@ -718,7 +666,7 @@ def read(self, handler, timeout=None): except Exception as exc: log.error('Exception occurred while Subscriber connecting: %s', exc) yield tornado.gen.sleep(1) - self._read(timeout) + yield self._read(timeout) def close(self): ''' @@ -726,8 +674,11 @@ def close(self): Sockets and filehandles should be closed explicitly, to prevent leaks. ''' - if not self._closing: - super(IPCMessageSubscriberService, self).close() + super(IPCMessageSubscriberService, self).close() + if self.read_stream_future is not None and self.read_stream_future.done(): + exc = self.read_stream_future.exception() + if exc and not isinstance(exc, tornado.iostream.StreamClosedError): + log.error("Read future returned exception %r", exc) def __del__(self): if IPCMessageSubscriberService in globals(): @@ -801,8 +752,8 @@ def read_async(self, callback, timeout=None): raise tornado.gen.Return(None) if data is None: break - elif isinstance(data, Exception): - raise data + elif isinstance(data, IPCExceptionProxy): + six.reraise(*data.orig_info) elif callback: self.service.io_loop.spawn_callback(callback, data) else: diff --git a/tests/unit/transport/test_ipc.py b/tests/unit/transport/test_ipc.py index 5cc5a70ee89e..3f5ad99f8a70 100644 --- a/tests/unit/transport/test_ipc.py +++ b/tests/unit/transport/test_ipc.py @@ -86,13 +86,14 @@ class IPCMessageClient(BaseIPCReqCase): ''' def _get_channel(self): - channel = salt.transport.ipc.IPCMessageClient( - socket_path=self.socket_path, - io_loop=self.io_loop, - ) - channel.connect(callback=self.stop) - self.wait() - return channel + if not hasattr(self, 'channel') or self.channel is None: + self.channel = salt.transport.ipc.IPCMessageClient( + socket_path=self.socket_path, + io_loop=self.io_loop, + ) + self.channel.connect(callback=self.stop) + self.wait() + return self.channel def setUp(self): super(IPCMessageClient, self).setUp() @@ -107,6 +108,8 @@ def tearDown(self): if exc.errno != errno.EBADF: # If its not a bad file descriptor error, raise raise + finally: + self.channel = None def test_singleton(self): channel = self._get_channel() @@ -120,23 +123,6 @@ def test_singleton(self): self.wait() self.assertEqual(self.payloads[0], msg) - def test_last_singleton_instance_closes(self): - channel = self._get_channel() - msg = {'foo': 'bar', 'stop': True} - log.debug('Sending msg1') - self.channel.send(msg) - self.wait() - self.assertEqual(self.payloads[0], msg) - channel.close() - # Since this is a singleton, and only the last singleton instance - # should actually close the connection, the next code should still - # work and not timeout - msg = {'bar': 'foo', 'stop': True} - log.debug('Sending msg2') - self.channel.send(msg) - self.wait() - self.assertEqual(self.payloads[1], msg) - def test_basic_send(self): msg = {'foo': 'bar', 'stop': True} self.channel.send(msg)