diff --git a/dill/dill.py b/dill/dill.py index 620d492c..b40d1e11 100644 --- a/dill/dill.py +++ b/dill/dill.py @@ -7,7 +7,8 @@ and coded to the pickle interface, by mmckerns@caltech.edu """ __all__ = ['dump','dumps','load','loads','dump_session','load_session',\ - 'Pickler','Unpickler','register','copy','pickle','pickles',\ + 'dumps_session', 'loads_session', 'Pickler','Unpickler',\ + 'register','copy','pickle','pickles',\ 'HIGHEST_PROTOCOL','PicklingError'] import logging @@ -18,12 +19,18 @@ def _trace(boolean): if boolean: log.setLevel(logging.DEBUG) else: log.setLevel(logging.WARN) return - import __builtin__ -import __main__ as _main_module import sys import marshal -import ctypes + +try: + import __main__ as DEFAULT_MAIN_MODULE + import ctypes + HAS_CTYPES = True +except: + HAS_CTYPES = False + DEFAULT_MAIN_MODULE = None + # import zlib from pickle import HIGHEST_PROTOCOL, PicklingError from pickle import Pickler as StockPickler @@ -54,38 +61,46 @@ def _trace(boolean): ### Shorthands (modified from python2.5/lib/pickle.py) try: from cStringIO import StringIO + from cStringIO import StringO as StringIOClass except ImportError: from StringIO import StringIO + StringIOClass = StringIO -def copy(obj): +def copy(obj, main_module=None): """use pickling to 'copy' an object""" - return loads(dumps(obj)) + return loads(dumps(obj, main_module=main_module), main_module=main_module) -def dump(obj, file, protocol=HIGHEST_PROTOCOL): +def dump(obj, file, protocol=HIGHEST_PROTOCOL, main_module=None): """pickle an object to a file""" + if main_module is None: + main_module = DEFAULT_MAIN_MODULE + pik = Pickler(file, protocol) - pik._main_module = _main_module + pik._main_module = main_module pik.dump(obj) return -def dumps(obj, protocol=HIGHEST_PROTOCOL): +def dumps(obj, protocol=HIGHEST_PROTOCOL, main_module=None): """pickle an object to a string""" file = StringIO() - dump(obj, file, protocol) + dump(obj, file, protocol, main_module) return file.getvalue() -def load(file): +def load(file, main_module=None): """unpickle an object from a file""" + if main_module is None: + main_module = DEFAULT_MAIN_MODULE + pik = Unpickler(file) - pik._main_module = _main_module + pik._main_module = main_module obj = pik.load() #_main_module.__dict__.update(obj.__dict__) #XXX: should update globals ? return obj -def loads(str): +def loads(str, main_module=None): """unpickle an object from a string""" file = StringIO(str) - return load(file) + return load(file, main_module) # def dumpzs(obj, protocol=HIGHEST_PROTOCOL): # """pickle an object to a compressed string""" @@ -98,9 +113,15 @@ def loads(str): ### End: Shorthands ### ### Pickle the Interpreter Session -def dump_session(filename='/tmp/console.sess', main_module=_main_module): +def dump_session(filename='/tmp/console.sess', main_module=None): """pickle the current state of __main__ to a file""" - f = file(filename, 'wb') + if main_module is None: + main_module = DEFAULT_MAIN_MODULE + + if hasattr(filename, 'write'): + f = filename + else: + f = file(filename, 'wb') try: pickler = Pickler(f, 2) pickler._main_module = main_module @@ -108,23 +129,52 @@ def dump_session(filename='/tmp/console.sess', main_module=_main_module): pickler.dump(main_module) pickler._session = False finally: - f.close() + # don't close StringIO (so callee can get value) + if not isinstance(f, StringIOClass): + f.close() return -def load_session(filename='/tmp/console.sess', main_module=_main_module): +def load_session(filename='/tmp/console.sess', main_module=None): """update the __main__ module with the state from the session file""" - f = file(filename, 'rb') + if main_module is None: + main_module = DEFAULT_MAIN_MODULE + + if hasattr(filename, 'read'): + f = filename + else: + f = file(filename, 'rb') try: + # for custom modules, make sure dill can import the module + old_module = sys.modules.get(main_module.__name__) + sys.modules[main_module.__name__] = main_module + unpickler = Unpickler(f) unpickler._main_module = main_module unpickler._session = True module = unpickler.load() unpickler._session = False main_module.__dict__.update(module.__dict__) + + if old_module: + sys.modules[main_module.__name__] = old_module + else: + del sys.modules[main_module.__name__] finally: f.close() return +def dumps_session(main_module=None): + file = StringIO() + dump_session(file, main_module) + return file.getvalue() + +def loads_session(val, main_module=None): + file = StringIO() + file.write(val) + file.seek(0) + load_session(file, main_module) + return main_module + ### End: Pickle the Interpreter ### Extend the Picklers @@ -189,27 +239,28 @@ def _load_type(name): def _create_type(typeobj, *args): return typeobj(*args) -ctypes.pythonapi.PyCell_New.restype = ctypes.py_object -ctypes.pythonapi.PyCell_New.argtypes = [ctypes.py_object] -# thanks to Paul Kienzle for cleaning the ctypes CellType logic -def _create_cell(obj): - return ctypes.pythonapi.PyCell_New(obj) - -ctypes.pythonapi.PyDictProxy_New.restype = ctypes.py_object -ctypes.pythonapi.PyDictProxy_New.argtypes = [ctypes.py_object] -def _create_dictproxy(obj, *args): - dprox = ctypes.pythonapi.PyDictProxy_New(obj) - #XXX: hack to take care of pickle 'nesting' the correct dictproxy - if 'nested' in args and type(dprox['__dict__']) == DictProxyType: - return dprox['__dict__'] - return dprox - -ctypes.pythonapi.PyWeakref_GetObject.restype = ctypes.py_object -ctypes.pythonapi.PyWeakref_GetObject.argtypes = [ctypes.py_object] -def _create_weakref(obj, *args): - from weakref import ref, ReferenceError - if obj: return ref(obj) #XXX: callback? - raise ReferenceError, "Cannot pickle reference to dead object" +if HAS_CTYPES: + ctypes.pythonapi.PyCell_New.restype = ctypes.py_object + ctypes.pythonapi.PyCell_New.argtypes = [ctypes.py_object] + # thanks to Paul Kienzle for cleaning the ctypes CellType logic + def _create_cell(obj): + return ctypes.pythonapi.PyCell_New(obj) + + ctypes.pythonapi.PyDictProxy_New.restype = ctypes.py_object + ctypes.pythonapi.PyDictProxy_New.argtypes = [ctypes.py_object] + def _create_dictproxy(obj, *args): + dprox = ctypes.pythonapi.PyDictProxy_New(obj) + #XXX: hack to take care of pickle 'nesting' the correct dictproxy + if 'nested' in args and type(dprox['__dict__']) == DictProxyType: + return dprox['__dict__'] + return dprox + + ctypes.pythonapi.PyWeakref_GetObject.restype = ctypes.py_object + ctypes.pythonapi.PyWeakref_GetObject.argtypes = [ctypes.py_object] + def _create_weakref(obj, *args): + from weakref import ref, ReferenceError + if obj: return ref(obj) #XXX: callback? + raise ReferenceError, "Cannot pickle reference to dead object" def _create_weakproxy(obj, *args): from weakref import proxy, ReferenceError @@ -276,11 +327,14 @@ def save_function(pickler, obj): return @register(dict) -def save_module_dict(pickler, obj): +def save_module_dict(pickler, obj, main_module=None): + if main_module is None: + main_module = DEFAULT_MAIN_MODULE + if is_dill(pickler) and obj is pickler._main_module.__dict__: log.info("D1: %s" % "") # obj pickler.write('c__builtin__\n__main__\n') - elif not is_dill(pickler) and obj is _main_module.__dict__: + elif not is_dill(pickler) and obj is main_module.__dict__: log.info("D3: %s" % "") # obj pickler.write('c__main__\n__dict__\n') #XXX: works in general? else: @@ -336,17 +390,18 @@ def save_wrapper_descriptor(pickler, obj): obj.__repr__()), obj=obj) return -@register(CellType) -def save_cell(pickler, obj): - log.info("Ce: %s" % obj) - pickler.save_reduce(_create_cell, (obj.cell_contents,), obj=obj) - return +if HAS_CTYPES: + @register(CellType) + def save_cell(pickler, obj): + log.info("Ce: %s" % obj) + pickler.save_reduce(_create_cell, (obj.cell_contents,), obj=obj) + return -@register(DictProxyType) -def save_dictproxy(pickler, obj): - log.info("Dp: %s" % obj) - pickler.save_reduce(_create_dictproxy, (dict(obj),'nested'), obj=obj) - return + @register(DictProxyType) + def save_dictproxy(pickler, obj): + log.info("Dp: %s" % obj) + pickler.save_reduce(_create_dictproxy, (dict(obj),'nested'), obj=obj) + return @register(SliceType) def save_slice(pickler, obj): @@ -413,6 +468,7 @@ def save_weakproxy(pickler, obj): @register(ModuleType) def save_module(pickler, obj): + log.info('TEST: %s' % obj) if is_dill(pickler) and obj is pickler._main_module: log.info("M1: %s" % obj) pickler.save_reduce(__import__, (obj.__name__,), obj=obj, @@ -427,8 +483,17 @@ def save_type(pickler, obj): if obj in _typemap: log.info("T1: %s" % obj) pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj) + # we are pickling the interpreter, using a custom module + elif (is_dill(pickler) and + pickler._session and + obj.__module__ == pickler._main_module.__name__ and + type(obj) == type): + log.info("T2: %s" % obj) + _dict = _dict_from_dictproxy(obj.__dict__) + pickler.save_reduce(_create_type, (type(obj), obj.__name__, + obj.__bases__, _dict), obj=obj) elif obj.__module__ == '__main__': - if type(obj) == type: + if type(obj) == type: # we are pickling the interpreter if is_dill(pickler) and pickler._session: log.info("T2: %s" % obj) @@ -451,6 +516,7 @@ def save_type(pickler, obj): #print "%s\n%s" % (type(obj), obj.__name__) #print "%s\n%s" % (obj.__bases__, obj.__dict__) StockPickler.save_global(pickler, obj) + return # quick sanity checking @@ -459,7 +525,7 @@ def pickles(obj,exact=False): try: pik = copy(obj) if exact: - return pik == obj + return pik == obj return type(pik) == type(obj) except (TypeError, PicklingError), err: return False