Skip to content

Commit

Permalink
fix dump_module() bugs and rename parameter 'main' to 'module' (fixes u…
Browse files Browse the repository at this point in the history
…qfoundation#525)

New phrasing of the module-mismatch error messages in load_module():

```python
>>> import dill
>>> dill.dump_module()
>>> dill.load_module(module='math')
ValueError: can't update module 'math' with the saved state of module '__main__'

>>> import types
>>> main = types.ModuleType('__main__')
>>> dill.load_module(module=main)
ValueError: can't update module-type object '__main__' with the saved state of imported module '__main__'

>>> dill.dump_module(module=main)
>>> dill.load_module(module='__main__')
ValueError: can't update imported module '__main__' with the saved state of module-type object '__main__'
```
  • Loading branch information
leogama committed Jul 13, 2022
1 parent c23e049 commit 8b90308
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
/docs/build
/build
/README
/dill/info.py
/dill/__info__.py
158 changes: 90 additions & 68 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,16 +322,21 @@ def loads(str, ignore=None, **kwds):
### Pickle the Interpreter Session
import pathlib
import tempfile
from types import SimpleNamespace

SESSION_IMPORTED_AS_TYPES = (BuiltinMethodType, FunctionType, MethodType,
ModuleType, TypeType)

SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception,
FunctionType, MethodType, BuiltinMethodType)
TEMPDIR = pathlib.PurePath(tempfile.gettempdir())

def _module_map():
"""get map of imported modules"""
from collections import defaultdict, namedtuple
modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level'])
modmap = modmap(defaultdict(list), defaultdict(list), {})
from collections import defaultdict
modmap = SimpleNamespace(
by_name=defaultdict(list),
by_id=defaultdict(list),
top_level={},
)
for modname, module in sys.modules.items():
if not isinstance(module, ModuleType):
continue
Expand Down Expand Up @@ -359,36 +364,38 @@ def _stash_modules(main_module):

imported = []
imported_as = []
imported_top_level = [] # keep separeted for backwards compatibility
imported_top_level = [] # keep separated for backward compatibility
original = {}
for name, obj in main_module.__dict__.items():
if obj is main_module:
original[name] = newmod # self-reference
continue

elif obj is main_module.__dict__:
original[name] = newmod.__dict__
# Avoid incorrectly matching a singleton value in another package (ex.: __doc__).
if any(obj is singleton for singleton in (None, False, True)) or \
isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref
elif any(obj is singleton for singleton in (None, False, True)) \
or isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref
original[name] = obj
continue

source_module, objname = _lookup_module(modmap, name, obj, main_module)
if source_module:
if objname == name:
imported.append((source_module, name))
else:
imported_as.append((source_module, objname, name))
else:
try:
imported_top_level.append((modmap.top_level[id(obj)], name))
except KeyError:
original[name] = obj
source_module, objname = _lookup_module(modmap, name, obj, main_module)
if source_module:
if objname == name:
imported.append((source_module, name))
else:
imported_as.append((source_module, objname, name))
else:
try:
imported_top_level.append((modmap.top_level[id(obj)], name))
except KeyError:
original[name] = obj

if len(original) < len(main_module.__dict__):
newmod.__dict__.update(original)
newmod.__dill_imported = imported
newmod.__dill_imported_as = imported_as
newmod.__dill_imported_top_level = imported_top_level
if getattr(newmod, '__loader__', None) is None and _is_imported_module(main_module):
# Trick _is_imported_module() to force saving as an imported module.
newmod.__loader__ = True # will be discarded by save_module()
return newmod
else:
return main_module
Expand All @@ -407,7 +414,7 @@ def _restore_modules(unpickler, main_module):
#NOTE: 06/03/15 renamed main_module to main
def dump_module(
filename = str(TEMPDIR/'session.pkl'),
main: Optional[Union[ModuleType, str]] = None,
module: Union[ModuleType, str] = None,
refimported: bool = False,
**kwds
) -> None:
Expand All @@ -420,7 +427,8 @@ def dump_module(
Parameters:
filename: a path-like object or a writable stream.
main: a module object or the name of an importable module.
module: a module object or the name of an importable module. If `None`
(the default), :py:mod:`__main__` is saved.
refimported: if `True`, all objects imported into the module's
namespace are saved by reference. *Note:* this is similar but
independent from ``dill.settings[`byref`]``, as ``refimported``
Expand All @@ -432,17 +440,18 @@ def dump_module(
:py:exc:`PicklingError`: if pickling fails.
Examples:
- Save current interpreter session state:
>>> import dill
>>> squared = lambda x:x*x
>>> squared = lambda x: x*x
>>> dill.dump_module() # save state of __main__ to /tmp/session.pkl
- Save the state of an imported/importable module:
>>> import dill
>>> import pox
>>> pox.plus_one = lambda x:x+1
>>> pox.plus_one = lambda x: x+1
>>> dill.dump_module('pox_session.pkl', module=pox)
- Save the state of a non-importable, module-type object:
Expand All @@ -468,24 +477,28 @@ def dump_module(
>>> [foo.sin(x) for x in foo.values]
[0.8414709848078965, 0.9092974268256817, 0.1411200080598672]
*Changed in version 0.3.6:* the function ``dump_session()`` was renamed to
``dump_module()``.
*Changed in version 0.3.6:* the parameter ``byref`` was renamed to
``refimported``.
*Changed in version 0.3.6:* Function ``dump_session()`` was renamed to
``dump_module()``. Parameters ``main`` and ``byref`` were renamed to
``module`` and ``refimported``, respectively.
"""
if 'byref' in kwds:
warnings.warn(
"The argument 'byref' has been renamed 'refimported'"
" to distinguish it from dill.settings['byref'].",
PendingDeprecationWarning
)
if refimported:
raise TypeError("both 'refimported' and 'byref' were used")
refimported = kwds.pop('byref')
for old_par, par in [('main', 'module'), ('byref', 'refimported')]:
if old_par in kwds:
message = "The argument %r has been renamed %r" % (old_par, par)
if old_par == 'byref':
message += " to distinguish it from dill.settings['byref']"
warnings.warn(message + ".", PendingDeprecationWarning)
if locals()[par]: # the defaults are None and False
raise TypeError("both %r and %r arguments were used" % (par, old_par))
refimported = kwds.pop('byref', refimported)
module = kwds.pop('main', module)

from .settings import settings
protocol = settings['protocol']
if main is None: main = _main_module
main = module
if main is None:
main = _main_module
elif isinstance(main, str):
main = _import_module(main)
if hasattr(filename, 'write'):
file = filename
else:
Expand All @@ -510,7 +523,7 @@ def dump_module(
# Backward compatibility.
def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, **kwds):
warnings.warn("dump_session() has been renamed dump_module()", PendingDeprecationWarning)
dump_module(filename, main, refimported=byref, **kwds)
dump_module(filename, module=main, refimported=byref, **kwds)
dump_session.__doc__ = dump_module.__doc__

class _PeekableReader:
Expand Down Expand Up @@ -574,7 +587,7 @@ def _identify_module(file, main=None):

def load_module(
filename = str(TEMPDIR/'session.pkl'),
main: Union[ModuleType, str] = None,
module: Union[ModuleType, str] = None,
**kwds
) -> Optional[ModuleType]:
"""Update :py:mod:`__main__` or another module with the state from the
Expand All @@ -592,7 +605,7 @@ def load_module(
Parameters:
filename: a path-like object or a readable stream.
main: a module object or the name of an importable module.
module: a module object or the name of an importable module.
**kwds: extra keyword arguments passed to :py:class:`Unpickler()`.
Raises:
Expand All @@ -609,11 +622,11 @@ def load_module(
- Save the state of some modules:
>>> import dill
>>> squared = lambda x:x*x
>>> squared = lambda x: x*x
>>> dill.dump_module() # save state of __main__ to /tmp/session.pkl
>>>
>>> import pox # an imported module
>>> pox.plus_one = lambda x:x+1
>>> pox.plus_one = lambda x: x+1
>>> dill.dump_module('pox_session.pkl', module=pox)
>>>
>>> from types import ModuleType
Expand Down Expand Up @@ -659,19 +672,27 @@ def load_module(
>>> from types import ModuleType
>>> foo = ModuleType('foo')
>>> foo.values = ['a','b']
>>> foo.sin = lambda x:x*x
>>> foo.sin = lambda x: x*x
>>> dill.load_module('foo_session.pkl', module=foo)
>>> [foo.sin(x) for x in foo.values]
[0.8414709848078965, 0.9092974268256817, 0.1411200080598672]
*Changed in version 0.3.6:* the function ``load_session()`` was renamed to
``load_module()``.
*Changed in version 0.3.6:* Function ``load_session()`` was renamed to
``load_module()``. Parameter ``main`` was renamed to ``module``.
See also:
:py:func:`load_module_asdict` to load the contents of module saved
with :py:func:`dump_module` into a dictionary.
"""
main_arg = main
if 'main' in kwds:
warnings.warn(
"The argument 'main' has been renamed 'module'.",
PendingDeprecationWarning
)
if module is not None:
raise TypeError("both 'module' and 'main' arguments were used")
module = kwds.pop('main')
main = module
if hasattr(filename, 'read'):
file = filename
else:
Expand All @@ -681,9 +702,9 @@ def load_module(
#FIXME: dill.settings are disabled
unpickler = Unpickler(file, **kwds)
unpickler._session = True
pickle_main = _identify_module(file, main)

# Resolve unpickler._main
pickle_main = _identify_module(file, main)
if main is None and pickle_main is not None:
main = pickle_main
if isinstance(main, str):
Expand All @@ -705,44 +726,44 @@ def load_module(
is_runtime_mod = pickle_main.startswith('__runtime__.')
if is_runtime_mod:
pickle_main = pickle_main.partition('.')[-1]
error_msg = "can't update{} module{} %r with the saved state of{} module{} %r"
if is_runtime_mod and is_main_imported:
raise ValueError(
"can't restore non-imported module %r into an imported one"
% pickle_main
error_msg.format(" imported", "", "", "-type object")
% (main.__name__, pickle_main)
)
if not is_runtime_mod and not is_main_imported:
raise ValueError(
"can't restore imported module %r into a non-imported one"
% pickle_main
)
if main.__name__ != pickle_main:
raise ValueError(
"can't restore module %r into module %r"
error_msg.format("", "-type object", " imported", "")
% (pickle_main, main.__name__)
)
if main.__name__ != pickle_main:
raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main))

# This is for find_class() to be able to locate it.
if not is_main_imported:
runtime_main = '__runtime__.%s' % main.__name__
sys.modules[runtime_main] = main

module = unpickler.load()
loaded = unpickler.load()
finally:
if not hasattr(filename, 'read'): # if newly opened file
file.close()
try:
del sys.modules[runtime_main]
except (KeyError, NameError):
pass
assert module is main
_restore_modules(unpickler, module)
if not (module is _main_module or module is main_arg):
return module
assert loaded is main
_restore_modules(unpickler, main)
if main is _main_module or main is module:
return None
else:
return main

# Backward compatibility.
def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds):
warnings.warn("load_session() has been renamed load_module().", PendingDeprecationWarning)
load_module(filename, main, **kwds)
load_module(filename, module=main, **kwds)
load_session.__doc__ = load_module.__doc__

def load_module_asdict(
Expand Down Expand Up @@ -774,6 +795,7 @@ def load_module_asdict(
Note:
If ``update`` is True, the saved module may be imported then updated.
If imported, the loaded module remains unchanged as in the general case.
Example:
>>> import dill
Expand All @@ -796,8 +818,8 @@ def load_module_asdict(
>>> new_var in main # would be True if the option 'update' was set
False
"""
if 'main' in kwds:
raise TypeError("'main' is an invalid keyword argument for load_module_asdict()")
if 'module' in kwds:
raise TypeError("'module' is an invalid keyword argument for load_module_asdict()")
if hasattr(filename, 'read'):
file = filename
else:
Expand All @@ -815,7 +837,6 @@ def load_module_asdict(
main.__builtins__ = __builtin__
sys.modules[main_name] = main
load_module(file, **kwds)
main.__session__ = str(filename)
finally:
if not hasattr(filename, 'read'): # if newly opened file
file.close()
Expand All @@ -826,6 +847,7 @@ def load_module_asdict(
sys.modules[main_name] = old_main
except NameError: # failed before setting old_main
pass
main.__session__ = str(filename)
return main.__dict__

### End: Pickle the Interpreter
Expand Down
Loading

0 comments on commit 8b90308

Please sign in to comment.