Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes some bugs when using dump_session() with byref=True #463

Merged
merged 18 commits into from
May 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
b98cd44
fix `dump_session(byref=True)` bug when no objetcts are imported from…
leogama Apr 22, 2022
44a9e54
fix `dump_session(byref=True)` bug when the `multiprocessing` module …
leogama Apr 22, 2022
a252e44
Save objects imported with an alias and top level modules by referenc…
leogama Apr 25, 2022
8b15b99
Deal with top level functions with `dump_session()`
leogama Apr 26, 2022
47a060d
Added tests for load_session() and dump_session()
leogama Apr 28, 2022
6156dc5
fix singleton comparison, must be by identity, not by equality
leogama Apr 28, 2022
bf419f1
split tests to different files to better test session use cases
leogama Apr 29, 2022
1aef037
Fix error Py2.7 and Py3.7 where there is a tuple in sys.modules for s…
leogama Apr 29, 2022
abdfd5c
dump_session(): extra test for code coverage
leogama Apr 29, 2022
a87496b
dump_session and load_session: some minor improvements
leogama Apr 30, 2022
e4ba1e8
dump_session(): more tests
leogama Apr 30, 2022
095b4cb
dump_session(): dump modules other than __main__ by reference
leogama Apr 30, 2022
d1450bf
dump_session(): minor code coverage investigation
leogama Apr 30, 2022
f292584
dump_session() tests: adjustments
leogama Apr 30, 2022
5b90579
dump_session() tests: fix copyright notice
leogama Apr 30, 2022
1f310a9
dump_session() tests: merge test files using subprocess to test loadi…
leogama Apr 30, 2022
3ae5f1a
tests: Revert change. Test files are independent, should run in any o…
leogama May 1, 2022
0c1b1d1
dump_sessio() tests: use an unpickleable object available in PyPy
leogama May 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 95 additions & 47 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,50 +397,87 @@ def loads(str, ignore=None, **kwds):
### End: Shorthands ###

### Pickle the Interpreter Session
SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception,
FunctionType, MethodType, BuiltinMethodType)

def _module_map():
"""get map of imported modules"""
from collections import defaultdict
modmap = defaultdict(list)
from collections import defaultdict, namedtuple
modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level'])
modmap = modmap(defaultdict(list), defaultdict(list), {})
items = 'items' if PY3 else 'iteritems'
for name, module in getattr(sys.modules, items)():
if module is None:
for modname, module in getattr(sys.modules, items)():
if not isinstance(module, ModuleType):
continue
for objname, obj in module.__dict__.items():
modmap[objname].append((obj, name))
if '.' not in modname:
modmap.top_level[id(module)] = modname
for objname, modobj in module.__dict__.items():
modmap.by_name[objname].append((modobj, modname))
modmap.by_id[id(modobj)].append((modobj, objname, modname))
return modmap

def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work
"""lookup name if module is imported"""
for modobj, modname in modmap[name]:
if modobj is obj and modname != main_module.__name__:
return modname
def _lookup_module(modmap, name, obj, main_module):
"""lookup name or id of obj if module is imported"""
for modobj, modname in modmap.by_name[name]:
if modobj is obj and sys.modules[modname] is not main_module:
return modname, name
if isinstance(obj, SESSION_IMPORTED_AS_TYPES):
for modobj, objname, modname in modmap.by_id[id(obj)]:
if sys.modules[modname] is not main_module:
return modname, objname
return None, None

def _stash_modules(main_module):
modmap = _module_map()
newmod = ModuleType(main_module.__name__)

imported = []
imported_as = []
imported_top_level = [] # keep separeted for backwards compatibility
original = {}
items = 'items' if PY3 else 'iteritems'
for name, obj in getattr(main_module.__dict__, items)():
source_module = _lookup_module(modmap, name, obj, main_module)
if obj is main_module:
original[name] = newmod # self-reference
continue

# Avoid incorrectly matching a singleton value in another package (ex.: __doc__).
if any(obj is singleton for singleton in (None, False, True)) or \
isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref
original[name] = obj
continue

source_module, objname = _lookup_module(modmap, name, obj, main_module)
if source_module:
imported.append((source_module, name))
if objname == name:
imported.append((source_module, name))
else:
imported_as.append((source_module, objname, name))
else:
original[name] = obj
if len(imported):
import types
newmod = types.ModuleType(main_module.__name__)
try:
imported_top_level.append((modmap.top_level[id(obj)], name))
except KeyError:
original[name] = obj

if len(original) < len(main_module.__dict__):
newmod.__dict__.update(original)
newmod.__dill_imported = imported
newmod.__dill_imported_as = imported_as
newmod.__dill_imported_top_level = imported_top_level
return newmod
else:
return original
return main_module

def _restore_modules(main_module):
if '__dill_imported' not in main_module.__dict__:
return
imports = main_module.__dict__.pop('__dill_imported')
for module, name in imports:
exec("from %s import %s" % (module, name), main_module.__dict__)
def _restore_modules(unpickler, main_module):
try:
for modname, name in main_module.__dict__.pop('__dill_imported'):
main_module.__dict__[name] = unpickler.find_class(modname, name)
for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'):
main_module.__dict__[name] = unpickler.find_class(modname, objname)
for modname, name in main_module.__dict__.pop('__dill_imported_top_level'):
main_module.__dict__[name] = __import__(modname)
except KeyError:
pass

#NOTE: 06/03/15 renamed main_module to main
def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds):
Expand All @@ -453,13 +490,16 @@ def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds):
else:
f = open(filename, 'wb')
try:
pickler = Pickler(f, protocol, **kwds)
pickler._original_main = main
if byref:
main = _stash_modules(main)
pickler = Pickler(f, protocol, **kwds)
pickler._main = main #FIXME: dill.settings are disabled
pickler._byref = False # disable pickling by name reference
pickler._recurse = False # disable pickling recursion for globals
pickler._session = True # is best indicator of when pickling a session
pickler._first_pass = True
pickler._main_modified = main is not pickler._original_main
pickler.dump(main)
finally:
if f is not filename: # If newly opened file
Expand All @@ -480,7 +520,7 @@ def load_session(filename='/tmp/session.pkl', main=None, **kwds):
module = unpickler.load()
unpickler._session = False
main.__dict__.update(module.__dict__)
_restore_modules(main)
_restore_modules(unpickler, main)
finally:
if f is not filename: # If newly opened file
f.close()
Expand Down Expand Up @@ -1060,9 +1100,11 @@ def _import_module(import_name, safe=False):
return None
raise

def _locate_function(obj, session=False):
if obj.__module__ in ['__main__', None]: # and session:
def _locate_function(obj, pickler=None):
if obj.__module__ in ['__main__', None] or \
pickler and pickler._session and obj.__module__ == pickler._main.__name__:
return False

found = _import_module(obj.__module__ + '.' + obj.__name__, safe=True)
return found is obj

Expand Down Expand Up @@ -1177,7 +1219,8 @@ def save_code(pickler, obj):

@register(dict)
def save_module_dict(pickler, obj):
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and not pickler._session:
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
not (pickler._session and pickler._first_pass):
log.info("D1: <dict%s" % str(obj.__repr__).split('dict')[-1]) # obj
if PY3:
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
Expand All @@ -1204,7 +1247,7 @@ def save_module_dict(pickler, obj):
log.info("D2: <dict%s" % str(obj.__repr__).split('dict')[-1]) # obj
if is_dill(pickler, child=False) and pickler._session:
# we only care about session the first pass thru
pickler._session = False
pickler._first_pass = False
StockPickler.save_dict(pickler, obj)
log.info("# D2")
return
Expand Down Expand Up @@ -1267,7 +1310,7 @@ def save_dict_items(pickler, obj):

@register(ClassType)
def save_classobj(pickler, obj): #FIXME: enable pickler._byref
if obj.__module__ == '__main__': #XXX: use _main_module.__name__ everywhere?
if not _locate_function(obj, pickler):
log.info("C1: %s" % obj)
pickler.save_reduce(ClassType, (obj.__name__, obj.__bases__,
obj.__dict__), obj=obj)
Expand Down Expand Up @@ -1676,6 +1719,16 @@ def save_weakproxy(pickler, obj):
log.info("# %s" % _t)
return

def _is_builtin_module(module):
if not hasattr(module, "__file__"): return True
# If a module file name starts with prefix, it should be a builtin
# module, so should always be pickled as a reference.
names = ["base_prefix", "base_exec_prefix", "exec_prefix", "prefix", "real_prefix"]
return any(os.path.realpath(module.__file__).startswith(os.path.realpath(getattr(sys, name)))
for name in names if hasattr(sys, name)) or \
module.__file__.endswith(EXTENSION_SUFFIXES) or \
'site-packages' in module.__file__

@register(ModuleType)
def save_module(pickler, obj):
if False: #_use_diff:
Expand All @@ -1696,19 +1749,9 @@ def save_module(pickler, obj):
pickler.save_reduce(_import_module, (obj.__name__,), obj=obj)
log.info("# M1")
else:
# if a module file name starts with prefix, it should be a builtin
# module, so should be pickled as a reference
if hasattr(obj, "__file__"):
names = ["base_prefix", "base_exec_prefix", "exec_prefix",
"prefix", "real_prefix"]
builtin_mod = any(os.path.realpath(obj.__file__).startswith(os.path.realpath(getattr(sys, name)))
for name in names if hasattr(sys, name))
builtin_mod = (builtin_mod or obj.__file__.endswith(EXTENSION_SUFFIXES) or
'site-packages' in obj.__file__)
else:
builtin_mod = True
if obj.__name__ not in ("builtins", "dill", "dill._dill") \
and not builtin_mod or is_dill(pickler, child=True) and obj is pickler._main:
builtin_mod = _is_builtin_module(obj)
if obj.__name__ not in ("builtins", "dill", "dill._dill") and not builtin_mod or \
is_dill(pickler, child=True) and obj is pickler._main:
log.info("M1: %s" % obj)
_main_dict = obj.__dict__.copy() #XXX: better no copy? option to copy?
[_main_dict.pop(item, None) for item in singletontypes
Expand Down Expand Up @@ -1766,7 +1809,7 @@ def save_type(pickler, obj, postproc_list=None):
obj_name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
_byref = getattr(pickler, '_byref', None)
obj_recursive = id(obj) in getattr(pickler, '_postproc', ())
incorrectly_named = not _locate_function(obj)
incorrectly_named = not _locate_function(obj, pickler)
if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over
if issubclass(type(obj), type):
# thanks to Tom Stepleton pointing out pickler._session unneeded
Expand Down Expand Up @@ -1844,11 +1887,12 @@ def save_classmethod(pickler, obj):

@register(FunctionType)
def save_function(pickler, obj):
if not _locate_function(obj): #, pickler._session):
if not _locate_function(obj, pickler):
log.info("F1: %s" % obj)
_recurse = getattr(pickler, '_recurse', None)
_byref = getattr(pickler, '_byref', None)
_postproc = getattr(pickler, '_postproc', None)
_main_modified = getattr(pickler, '_main_modified', None)
postproc_list = []
if _recurse:
# recurse to get all globals referred to by obj
Expand All @@ -1863,8 +1907,12 @@ def save_function(pickler, obj):
else:
globs_copy = obj.__globals__ if PY3 else obj.func_globals

# If the globals is the __dict__ from the module being save as a
# session, substitute it by the dictionary being actually saved.
if _main_modified and globs_copy is pickler._original_main.__dict__:
globs = globs_copy = pickler._main.__dict__
# If the globals is a module __dict__, do not save it in the pickle.
if globs_copy is not None and obj.__module__ is not None and \
elif globs_copy is not None and obj.__module__ is not None and \
getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy:
globs = globs_copy
else:
Expand Down
1 change: 0 additions & 1 deletion tests/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@
if not p:
print('.', end='')
print('')

Loading