Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify IPCClient and prevent corrupt messages #52445

Merged
merged 7 commits into from
Apr 10, 2019
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 29 additions & 44 deletions salt/transport/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

# Import Python libs
from __future__ import absolute_import, print_function, unicode_literals
import errno
import logging
import socket
import weakref
import time
import sys

# Import 3rd-party libs
import msgpack
Expand Down Expand Up @@ -83,6 +85,11 @@ def _done_callback(self, future):
self.set_exception(exc)


class IPCExceptionProxy(object):
def __init__(self, orig_info):
self.orig_info = orig_info


class IPCServer(object):
'''
A Tornado IPC server very similar to Tornado's TCPServer class
Expand Down Expand Up @@ -237,31 +244,7 @@ class IPCClient(object):
case it is used as the port for a tcp
localhost connection.
'''

# Create singleton map between two sockets
instance_map = weakref.WeakKeyDictionary()

def __new__(cls, socket_path, io_loop=None):
io_loop = io_loop or tornado.ioloop.IOLoop.current()
if io_loop not in IPCClient.instance_map:
IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary()
loop_instance_map = IPCClient.instance_map[io_loop]

# FIXME
key = six.text_type(socket_path)

client = loop_instance_map.get(key)
if client is None:
log.debug('Initializing new IPCClient for path: %s', key)
client = object.__new__(cls)
# FIXME
client.__singleton_init__(io_loop=io_loop, socket_path=socket_path)
loop_instance_map[key] = client
else:
log.debug('Re-using IPCClient for %s', key)
return client

def __singleton_init__(self, socket_path, io_loop=None):
def __init__(self, socket_path, io_loop=None):
'''
Create a new IPC client

Expand All @@ -280,10 +263,6 @@ def __singleton_init__(self, socket_path, io_loop=None):
encoding = 'utf-8'
self.unpacker = msgpack.Unpacker(encoding=encoding)

def __init__(self, socket_path, io_loop=None):
# Handled by singleton __new__
pass

def connected(self):
return self.stream is not None and not self.stream.closed()

Expand Down Expand Up @@ -354,7 +333,16 @@ def _connect(self, timeout=None):
yield tornado.gen.sleep(1)

def __del__(self):
self.close()
try:
self.close()
except socket.error as exc:
if exc.errno != errno.EBADF:
# If its not a bad file descriptor error, raise
raise
except TypeError:
# This is raised when Python's GC has collected objects which
# would be needed when calling self.close()
pass

def close(self):
'''
Expand All @@ -368,16 +356,6 @@ def close(self):
if self.stream is not None and not self.stream.closed():
self.stream.close()

# Remove the entry from the instance map so
# that a closed entry may not be reused.
# This forces this operation even if the reference
# count of the entry has not yet gone to zero.
if self.io_loop in IPCClient.instance_map:
loop_instance_map = IPCClient.instance_map[self.io_loop]
key = six.text_type(self.socket_path)
if key in loop_instance_map:
del loop_instance_map[key]


class IPCMessageClient(IPCClient):
'''
Expand Down Expand Up @@ -591,8 +569,8 @@ class IPCMessageSubscriberService(IPCClient):

To use this refer to IPCMessageSubscriber documentation.
'''
def __singleton_init__(self, socket_path, io_loop=None):
super(IPCMessageSubscriberService, self).__singleton_init__(
def __init__(self, socket_path, io_loop=None):
super(IPCMessageSubscriberService, self).__init__(
socket_path, io_loop=io_loop)
self.saved_data = []
self._read_in_progress = Lock()
Expand Down Expand Up @@ -648,6 +626,7 @@ def _read(self, timeout, callback=None):
break
except Exception as exc:
log.error('Exception occurred in Subscriber while handling stream: %s', exc)
exc = IPCExceptionProxy(sys.exc_info())
self._feed_subscribers([exc])
break

Expand Down Expand Up @@ -755,13 +734,19 @@ def read_async(self, callback, timeout=None):
raise tornado.gen.Return(None)
if data is None:
break
elif isinstance(data, Exception):
raise data
elif isinstance(data, IPCExceptionProxy):
self.reraise(data.orig_info)
elif callback:
self.service.io_loop.spawn_callback(callback, data)
else:
raise tornado.gen.Return(data)

def reraise(self, exc_info):
if six.PY2:
raise exc_info[0], exc_info[1], exc_info[2] # pylint: disable=W1699
else:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

def read_sync(self, timeout=None):
'''
Read a message from an IPC socket
Expand Down