diff --git a/dill/__init__.py b/dill/__init__.py index 6f71bbe5..6a012270 100644 --- a/dill/__init__.py +++ b/dill/__init__.py @@ -11,10 +11,10 @@ from .__info__ import __version__, __author__, __doc__, __license__ except: # pragma: no cover import os - import sys + import sys parent = os.path.dirname(os.path.abspath(os.path.dirname(__file__))) sys.path.append(parent) - # get distribution meta info + # get distribution meta info from version import (__version__, __author__, get_license_text, get_readme_as_rst) __license__ = get_license_text(os.path.join(parent, 'LICENSE')) @@ -24,25 +24,23 @@ from ._dill import ( - Pickler, Unpickler, - check, copy, dump, dumps, load, loads, pickle, pickles, register, - DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, CONTENTS_FMODE, FILE_FMODE, HANDLE_FMODE, + dump, dumps, load, loads, copy, + Pickler, Unpickler, register, pickle, pickles, check, + DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, HANDLE_FMODE, CONTENTS_FMODE, FILE_FMODE, PickleError, PickleWarning, PicklingError, PicklingWarning, UnpicklingError, UnpicklingWarning, ) from .session import ( - dump_module, load_module, load_module_asdict, + dump_module, load_module, load_module_asdict, is_pickled_module, dump_session, load_session # backward compatibility ) -from . import detect, logger, session, source, temp +from . import detect, logging, session, source, temp # get global settings from .settings import settings # make sure "trace" is turned off -logger.trace(False) - -from importlib import reload +logging.trace(False) objects = {} # local import of dill._objects @@ -68,6 +66,7 @@ def load_types(pickleable=True, unpickleable=True): Returns: None """ + from importlib import reload # local import of dill.objects from . 
import _objects if pickleable: diff --git a/dill/_dill.py b/dill/_dill.py index 0130e709..3ad7ead0 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -8,6 +8,13 @@ """ dill: a utility for serialization of python objects +The main API of the package are the functions :func:`dump` and +:func:`dumps` for serialization ("pickling"), and :func:`load` +and :func:`loads` for deserialization ("unpickling"). The +functions :func:`~dill.session.dump_module` and +:func:`~dill.session.load_module` can be used to save and restore +the intepreter session. + Based on code written by Oren Tirosh and Armin Ronacher. Extended to a (near) full set of the builtin types (in types module), and coded to the pickle interface, by . @@ -15,10 +22,13 @@ Test against "all" python types (Std. Lib. CH 1-15 @ 2.7) by mmckerns. Test against CH16+ Std. Lib. ... TBD. """ + +from __future__ import annotations + __all__ = [ - 'Pickler','Unpickler', - 'check','copy','dump','dumps','load','loads','pickle','pickles','register', - 'DEFAULT_PROTOCOL','HIGHEST_PROTOCOL','CONTENTS_FMODE','FILE_FMODE','HANDLE_FMODE', + 'dump','dumps','load','loads','copy', + 'Pickler','Unpickler','register','pickle','pickles','check', + 'DEFAULT_PROTOCOL','HIGHEST_PROTOCOL','HANDLE_FMODE','CONTENTS_FMODE','FILE_FMODE', 'PickleError','PickleWarning','PicklingError','PicklingWarning','UnpicklingError', 'UnpicklingWarning', ] @@ -26,8 +36,10 @@ __module__ = 'dill' import warnings -from .logger import adapter as logger -from .logger import trace as _trace +from dill import logging +from .logging import adapter as logger +from .logging import trace as _trace +_logger = logging.getLogger(__name__) import os import sys @@ -39,6 +51,7 @@ #XXX: get types from .objtypes ? 
import builtins as __builtin__ from pickle import _Pickler as StockPickler, Unpickler as StockUnpickler +from pickle import DICT, GLOBAL, MARK, POP, SETITEM from _thread import LockType from _thread import RLock as RLockType #from io import IOBase @@ -58,13 +71,14 @@ import marshal import gc # import zlib +import dataclasses +import weakref from weakref import ReferenceType, ProxyType, CallableProxyType from collections import OrderedDict -from functools import partial +from functools import partial, wraps from operator import itemgetter, attrgetter GENERATOR_FAIL = False import importlib.machinery -EXTENSION_SUFFIXES = tuple(importlib.machinery.EXTENSION_SUFFIXES) try: import ctypes HAS_CTYPES = True @@ -158,22 +172,19 @@ def get_file_type(*args, **kwargs): from socket import socket as SocketType #FIXME: additionally calls ForkingPickler.register several times from multiprocessing.reduction import _reduce_socket as reduce_socket -try: +try: #pragma: no cover IS_IPYTHON = __IPYTHON__ # is True - ExitType = None # IPython.core.autocall.ExitAutocall - singletontypes = ['exit', 'quit', 'get_ipython'] + ExitType = None # IPython.core.autocall.ExitAutocall + IPYTHON_SINGLETONS = ('exit', 'quit', 'get_ipython') except NameError: IS_IPYTHON = False try: ExitType = type(exit) # apparently 'exit' can be removed except NameError: ExitType = None - singletontypes = [] + IPYTHON_SINGLETONS = () import inspect -import dataclasses import typing -from pickle import GLOBAL - ### Shims for different versions of Python and dill class Sentinel(object): @@ -212,6 +223,9 @@ def __reduce_ex__(self, protocol): #: Pickles the entire file (handle and contents), preserving mode and position. FILE_FMODE = 2 +# Exceptions commonly raised by unpickleable objects in the Standard Library. 
+UNPICKLEABLE_ERRORS = (PicklingError, TypeError, ValueError, NotImplementedError) + ### Shorthands (modified from python2.5/lib/pickle.py) def copy(obj, *args, **kwds): """ @@ -229,10 +243,9 @@ def dump(obj, file, protocol=None, byref=None, fmode=None, recurse=None, **kwds) See :func:`dumps` for keyword arguments. """ from .settings import settings - protocol = settings['protocol'] if protocol is None else int(protocol) - _kwds = kwds.copy() - _kwds.update(dict(byref=byref, fmode=fmode, recurse=recurse)) - Pickler(file, protocol, **_kwds).dump(obj) + protocol = int(_getopt(settings, 'protocol', protocol)) + kwds.update(byref=byref, fmode=fmode, recurse=recurse) + Pickler(file, protocol, **kwds).dump(obj) return def dumps(obj, protocol=None, byref=None, fmode=None, recurse=None, **kwds):#, strictio=None): @@ -317,35 +330,72 @@ class PicklingWarning(PickleWarning, PicklingError): class UnpicklingWarning(PickleWarning, UnpicklingError): pass +def _getopt(settings, key, arg=None, *, kwds=None): + """Get option from named argument 'arg' or 'kwds', falling back to settings. + + Examples: + + # With an explicitly named argument: + protocol = int(_getopt(settings, 'protocol', protocol)) + + # With a named argument in **kwds: + self._byref = _getopt(settings, 'byref', kwds=kwds) + """ + # Sanity check, it's a bug in calling code if False. + assert kwds is None or arg is None + if kwds is not None: + arg = kwds.pop(key, None) + if arg is not None: + return arg + else: + return settings[key] + ### Extend the Picklers class Pickler(StockPickler): """python's Pickler extended to interpreter sessions""" - dispatch = MetaCatchingDict(StockPickler.dispatch.copy()) - _session = False + dispatch: typing.Dict[type, typing.Callable[[Pickler, typing.Any], None]] \ + = MetaCatchingDict(StockPickler.dispatch.copy()) + """The dispatch table, a dictionary of serializing functions used + by Pickler to save objects of specific types. 
Use :func:`pickle` + or :func:`register` to associate types to custom functions. + + :meta hide-value: + """ from .settings import settings + # Flags set by dump_module() is dill.session: + _refimported = False + _refonfail = False + _session = False + _first_pass = False def __init__(self, file, *args, **kwds): settings = Pickler.settings - _byref = kwds.pop('byref', None) - #_strictio = kwds.pop('strictio', None) - _fmode = kwds.pop('fmode', None) - _recurse = kwds.pop('recurse', None) - StockPickler.__init__(self, file, *args, **kwds) self._main = _main_module self._diff_cache = {} - self._byref = settings['byref'] if _byref is None else _byref - self._strictio = False #_strictio - self._fmode = settings['fmode'] if _fmode is None else _fmode - self._recurse = settings['recurse'] if _recurse is None else _recurse + self._byref = _getopt(settings, 'byref', kwds=kwds) + self._fmode = _getopt(settings, 'fmode', kwds=kwds) + self._recurse = _getopt(settings, 'recurse', kwds=kwds) + self._strictio = False #_getopt(settings, 'strictio', kwds=kwds) self._postproc = OrderedDict() - self._file = file + self._file_tell = getattr(file, 'tell', None) # for logger and refonfail + StockPickler.__init__(self, file, *args, **kwds) def save(self, obj, save_persistent_id=True): - # register if the object is a numpy ufunc - # thanks to Paul Kienzle for pointing out ufuncs didn't pickle + # This method overrides StockPickler.save() and is called for every + # object pickled. When 'refonfail' is True, it tries to save the object + # by reference if pickling it fails with a common pickling error, as + # defined by the constant UNPICKLEABLE_ERRORS. If that also fails, then + # the exception is raised and, if this method was called indirectly from + # another Pickler.save() call, the parent objects will try to be saved + # by reference recursively, until it succeeds or the exception + # propagates beyond the topmost save() call. 
+ + # numpy hack obj_type = type(obj) if NumpyArrayType and not (obj_type is type or obj_type in Pickler.dispatch): - if NumpyUfuncType and numpyufunc(obj_type): + # register if the object is a numpy ufunc + # thanks to Paul Kienzle for pointing out ufuncs didn't pickle + if numpyufunc(obj_type): @register(obj_type) def save_numpy_ufunc(pickler, obj): logger.trace(pickler, "Nu: %s", obj) @@ -359,7 +409,7 @@ def save_numpy_ufunc(pickler, obj): # def uload(name): return getattr(numpy, name) # copy_reg.pickle(NumpyUfuncType, udump, uload) # register if the object is a numpy dtype - if NumpyDType and numpydtype(obj_type): + if numpydtype(obj_type): @register(obj_type) def save_numpy_dtype(pickler, obj): logger.trace(pickler, "Dt: %s", obj) @@ -372,27 +422,81 @@ def save_numpy_dtype(pickler, obj): # def udump(f): return uload, (f.type,) # copy_reg.pickle(NumpyDTypeType, udump, uload) # register if the object is a subclassed numpy array instance - if NumpyArrayType and ndarraysubclassinstance(obj_type): + if ndarraysubclassinstance(obj_type): @register(obj_type) def save_numpy_array(pickler, obj): - logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype) + logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype, obj=obj) npdict = getattr(obj, '__dict__', None) f, args, state = obj.__reduce__() pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj) logger.trace(pickler, "# Nu") return - # end hack - if GENERATOR_FAIL and type(obj) == GeneratorType: + # end numpy hack + + if GENERATOR_FAIL and obj_type is GeneratorType: msg = "Can't pickle %s: attribute lookup builtins.generator failed" % GeneratorType raise PicklingError(msg) - StockPickler.save(self, obj, save_persistent_id) + if not self._refonfail: + StockPickler.save(self, obj, save_persistent_id) + return + + ## Save with 'refonfail' ## + + # Disable framing. This must be set right after the + # framer.init_framing() call at StockPickler.dump()). 
+ self.framer.current_frame = None + # Store initial state. + position = self._file_tell() + memo_size = len(self.memo) + try: + StockPickler.save(self, obj, save_persistent_id) + except UNPICKLEABLE_ERRORS as error_stack: + trace_message = ( + "# X: fallback to save as global: <%s object at %#012x>" + % (type(obj).__name__, id(obj)) + ) + # Roll back the stream. Note: truncate(position) doesn't always work. + self._file_seek(position) + self._file_truncate() + # Roll back memo. + for _ in range(len(self.memo) - memo_size): + self.memo.popitem() # LIFO order is guaranteed since 3.7 + # Handle session main. + if self._session and obj is self._main: + if self._main is _main_module or not _is_imported_module(self._main): + raise + # Save an empty dict as state to distinguish from modules saved with dump(). + self.save_reduce(_import_module, (obj.__name__,), obj=obj, state={}) + logger.trace(self, trace_message, obj=obj) + warnings.warn( + "module %r saved by reference due to the unpickleable " + "variable %r. No changes to the module were saved. " + "Unpickleable variables can be ignored with filters." + % (self._main.__name__, error_stack.name), + PicklingWarning, + stacklevel=5, + ) + # Try to save object by reference. + elif hasattr(obj, '__name__') or hasattr(obj, '__qualname__'): + try: + self.save_global(obj) + logger.trace(self, trace_message, obj=obj) + return True # for _saved_byref, ignored otherwise + except PicklingError as error: + # Roll back trace state. + logger.roll_back(self, obj) + raise error from error_stack + else: + # Roll back trace state. 
+ logger.roll_back(self, obj) + raise + return save.__doc__ = StockPickler.save.__doc__ def dump(self, obj): #NOTE: if settings change, need to update attributes logger.trace_setup(self) StockPickler.dump(self, obj) - dump.__doc__ = StockPickler.dump.__doc__ class Unpickler(StockUnpickler): @@ -410,10 +514,9 @@ def find_class(self, module, name): def __init__(self, *args, **kwds): settings = Pickler.settings - _ignore = kwds.pop('ignore', None) - StockUnpickler.__init__(self, *args, **kwds) self._main = _main_module - self._ignore = settings['ignore'] if _ignore is None else _ignore + self._ignore = _getopt(settings, 'ignore', kwds=kwds) + StockUnpickler.__init__(self, *args, **kwds) def load(self): #NOTE: if settings change, need to update attributes obj = StockUnpickler.load(self) @@ -436,12 +539,12 @@ def dispatch_table(): pickle_dispatch_copy = StockPickler.dispatch.copy() def pickle(t, func): - """expose dispatch table for user-created extensions""" + """expose :attr:`~Pickler.dispatch` table for user-created extensions""" Pickler.dispatch[t] = func return def register(t): - """register type to Pickler's dispatch table """ + """decorator to register types to Pickler's :attr:`~Pickler.dispatch` table""" def proxy(func): Pickler.dispatch[t] = func return func @@ -460,7 +563,7 @@ def use_diff(on=True): Reduces size of pickles by only including object which have changed. Decreases pickle size but increases CPU time needed. - Also helps avoid some unpicklable objects. + Also helps avoid some unpickleable objects. MUST be called at start of script, otherwise changes will not be recorded. 
""" global _use_diff, diff @@ -1088,7 +1191,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO else: pickler.save_reduce(*reduction) # pop None created by calling preprocessing step off stack - pickler.write(bytes('0', 'UTF-8')) + pickler.write(POP) #@register(CodeType) #def save_code(pickler, obj): @@ -1157,30 +1260,156 @@ def save_code(pickler, obj): logger.trace(pickler, "# Co") return +def _module_map(main_module): + """get map of imported modules""" + from collections import defaultdict + from types import SimpleNamespace + modmap = SimpleNamespace( + by_name = defaultdict(list), + by_id = defaultdict(list), + top_level = {}, # top-level modules + module = main_module.__name__, + package = _module_package(main_module), + ) + for modname, module in sys.modules.items(): + if (modname in ('__main__', '__mp_main__') or module is main_module + or not isinstance(module, ModuleType)): + continue + if '.' not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((objname, modname)) + return modmap + +def _lookup_module(modmap, name, obj, lookup_by_id=True) -> typing.Tuple[str, str, bool]: + """Lookup name or id of obj if module is imported. + + Lookup for objects identical to 'obj' at modules in 'modmpap'. If multiple + copies are found in different modules, return the one from the module with + higher probability of being available at unpickling time, according to the + hierarchy: + + 1. Standard Library modules + 2. modules of the same top-level package as the module being saved (if it's part of a package) + 3. installed modules in general + 4. non-installed modules + + Returns: + A 3-tuple containing the module's name, the object's name in the module, + and a boolean flag, which is `True` if the module falls under categories + (1) to (3) from the hierarchy, or `False` if it's in category (4). 
+ """ + not_found = None, None, None + # Don't look for objects likely related to the module itself. + obj_module = getattr(obj, '__module__', type(obj).__module__) + if obj_module == modmap.module: + return not_found + obj_package = _module_package(_import_module(obj_module, safe=True)) + + for map, by_id in [(modmap.by_name, False), (modmap.by_id, True)]: + if by_id and not lookup_by_id: + break + _2nd_choice = _3rd_choice = _4th_choice = None + key = id(obj) if by_id else name + for other, modname in map[key]: + if by_id or other is obj: + other_name = other if by_id else name + other_module = sys.modules[modname] + other_package = _module_package(other_module) + # Don't return a reference to a module of another package + # if the object is likely from the same top-level package. + if (modmap.package and obj_package == modmap.package + and other_package != modmap.package): + continue + # Prefer modules imported earlier (the first found). + if _is_stdlib_module(other_module): + return modname, other_name, True + elif modmap.package and modmap.package == other_package: + if _2nd_choice: continue + _2nd_choice = modname, other_name, True + elif not _2nd_choice: + # Don't call _is_builtin_module() unnecessarily. 
+ if _is_builtin_module(other_module): + if _3rd_choice: continue + _3rd_choice = modname, other_name, True + else: + if _4th_choice: continue + _4th_choice = modname, other_name, False # unsafe + found = _2nd_choice or _3rd_choice or _4th_choice + if found: + return found + return not_found + +def _global_string(modname, name): + return GLOBAL + bytes('%s\n%s\n' % (modname, name), 'UTF-8') + +def _save_module_dict(pickler, main_dict): + """Save a module's dictionary, saving unpickleable variables by referece.""" + main = getattr(pickler, '_original_main', pickler._main) + modmap = getattr(pickler, '_modmap', None) # cached from _stash_modules() + is_builtin = _is_builtin_module(main) + pickler.write(MARK + DICT) # don't need to memoize + for name, value in main_dict.items(): + _logger.debug("Pickling %r (%s)", name, type(value).__name__) + pickler.save(name) + try: + if pickler.save(value): + global_name = getattr(value, '__qualname__', value.__name__) + pickler._saved_byref.append((name, value.__module__, global_name)) + except UNPICKLEABLE_ERRORS as error_stack: + if modmap is None: + modmap = _module_map(main) + modname, objname, installed = _lookup_module(modmap, name, value) + if modname and (installed or not is_builtin): + pickler.write(_global_string(modname, objname)) + pickler._saved_byref.append((name, modname, objname)) + elif is_builtin: + pickler.write(_global_string(main.__name__, name)) + pickler._saved_byref.append((name, main.__name__, name)) + else: + error = PicklingError("can't save variable %r as global" % name) + error.name = name + raise error from error_stack + pickler.memoize(value) + pickler.write(SETITEM) + def _repr_dict(obj): - """make a short string representation of a dictionary""" + """Make a short string representation of a dictionary.""" return "<%s object at %#012x>" % (type(obj).__name__, id(obj)) @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ - not 
(pickler._session and pickler._first_pass): - logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + is_pickler_dill = is_dill(pickler, child=False) + if (is_pickler_dill + and obj is pickler._main.__dict__ + and not (pickler._session and pickler._first_pass)): + logger.trace(pickler, "D1: %s", _repr_dict(obj), obj=obj) + pickler.write(GLOBAL + b'__builtin__\n__main__\n') logger.trace(pickler, "# D1") - elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): - logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + elif not is_pickler_dill and obj is _main_module.__dict__: #prama: no cover + logger.trace(pickler, "D3: %s", _repr_dict(obj), obj=obj) + pickler.write(GLOBAL + b'__main__\n__dict__\n') #XXX: works in general? logger.trace(pickler, "# D3") - elif '__name__' in obj and obj != _main_module.__dict__ \ - and type(obj['__name__']) is str \ - and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): - logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj - pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8')) + elif (is_pickler_dill + and pickler._session + and pickler._refonfail + and obj is pickler._main_dict_copy): + logger.trace(pickler, "D5: %s", _repr_dict(obj), obj=obj) + # we only care about session the first pass thru + pickler.first_pass = False + _save_module_dict(pickler, obj) + logger.trace(pickler, "# D5") + elif ('__name__' in obj + and obj is not _main_module.__dict__ + and type(obj['__name__']) is str + and obj is getattr(_import_module(obj['__name__'], safe=True), '__dict__', None)): + logger.trace(pickler, "D4: %s", _repr_dict(obj), obj=obj) + pickler.write(_global_string(obj['__name__'], '__dict__')) logger.trace(pickler, "# D4") else: - logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj - if is_dill(pickler, child=False) and pickler._session: 
+ logger.trace(pickler, "D2: %s", _repr_dict(obj), obj=obj) + if is_pickler_dill: # we only care about session the first pass thru pickler._first_pass = False StockPickler.save_dict(pickler, obj) @@ -1470,7 +1699,7 @@ def save_cell(pickler, obj): # The result of this function call will be None pickler.save_reduce(_shims._delattr, (obj, 'cell_contents')) # pop None created by calling _delattr off stack - pickler.write(bytes('0', 'UTF-8')) + pickler.write(POP) logger.trace(pickler, "# Ce3") return if is_dill(pickler, child=True): @@ -1498,7 +1727,7 @@ def save_cell(pickler, obj): if MAPPING_PROXY_TRICK: @register(DictProxyType) def save_dictproxy(pickler, obj): - logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj + logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj) mapping = obj | _dictproxy_helper_instance pickler.save_reduce(DictProxyType, (mapping,), obj=obj) logger.trace(pickler, "# Mp") @@ -1506,7 +1735,7 @@ def save_dictproxy(pickler, obj): else: @register(DictProxyType) def save_dictproxy(pickler, obj): - logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj + logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj) pickler.save_reduce(DictProxyType, (obj.copy(),), obj=obj) logger.trace(pickler, "# Mp") return @@ -1577,24 +1806,78 @@ def save_weakref(pickler, obj): @register(CallableProxyType) def save_weakproxy(pickler, obj): # Must do string substitution here and use %r to avoid ReferenceError. 
- logger.trace(pickler, "R2: %r" % obj) + logger.trace(pickler, "R2: %r" % obj, obj=obj) refobj = _locate_object(_proxy_helper(obj)) pickler.save_reduce(_create_weakproxy, (refobj, callable(obj)), obj=obj) logger.trace(pickler, "# R2") return +def _weak_cache(func=None, *, defaults=None): + if defaults is None: + defaults = {} + if func is None: + return partial(_weak_cache, defaults=defaults) + cache = weakref.WeakKeyDictionary() + @wraps(func) + def wrapper(referent): + try: + return defaults[referent] + except KeyError: + try: + return cache[referent] + except KeyError: + value = func(referent) + cache[referent] = value + return value + return wrapper + +@_weak_cache(defaults={None: False}) +def _is_imported_module(module): + return getattr(module, '__loader__', None) is not None or module in sys.modules.values() + +PYTHONPATH_PREFIXES = {getattr(sys, attr) for attr in ( + 'base_prefix', 'prefix', 'base_exec_prefix', 'exec_prefix', + 'real_prefix', # for old virtualenv versions + ) if hasattr(sys, attr)} +PYTHONPATH_PREFIXES = tuple(os.path.realpath(path) for path in PYTHONPATH_PREFIXES) +EXTENSION_SUFFIXES = tuple(importlib.machinery.EXTENSION_SUFFIXES) +if OLD310: + STDLIB_PREFIX = os.path.dirname(os.path.realpath(os.__file__)) + +@_weak_cache(defaults={None: True}) #XXX: shouldn't return False for None? def _is_builtin_module(module): - if not hasattr(module, "__file__"): return True + if module.__name__ in ('__main__', '__mp_main__'): + return False + mod_path = getattr(module, '__file__', None) + if not mod_path: + return _is_imported_module(module) # If a module file name starts with prefix, it should be a builtin # module, so should always be pickled as a reference. 
- names = ["base_prefix", "base_exec_prefix", "exec_prefix", "prefix", "real_prefix"] - return any(os.path.realpath(module.__file__).startswith(os.path.realpath(getattr(sys, name))) - for name in names if hasattr(sys, name)) or \ - module.__file__.endswith(EXTENSION_SUFFIXES) or \ - 'site-packages' in module.__file__ + mod_path = os.path.realpath(mod_path) + return ( + any(mod_path.startswith(prefix) for prefix in PYTHONPATH_PREFIXES) + or mod_path.endswith(EXTENSION_SUFFIXES) + or 'site-packages' in mod_path + ) -def _is_imported_module(module): - return getattr(module, '__loader__', None) is not None or module in sys.modules.values() +@_weak_cache(defaults={None: False}) +def _is_stdlib_module(module): + first_level = module.__name__.partition('.')[0] + if OLD310: + if first_level in sys.builtin_module_names: + return True + mod_path = getattr(module, '__file__', '') + if mod_path: + mod_path = os.path.realpath(mod_path) + return mod_path.startswith(STDLIB_PREFIX) + else: + return first_level in sys.stdlib_module_names + +@_weak_cache(defaults={None: None}) +def _module_package(module): + """get the top-level package of a module, if any""" + package = getattr(module, '__package__', None) + return package.partition('.')[0] if package else None @register(ModuleType) def save_module(pickler, obj): @@ -1606,7 +1889,7 @@ def save_module(pickler, obj): pass else: logger.trace(pickler, "M2: %s with diff", obj) - logger.trace(pickler, "Diff: %s", changed.keys()) + logger.info("Diff: %s", changed.keys()) pickler.save_reduce(_import_module, (obj.__name__,), obj=obj, state=changed) logger.trace(pickler, "# M2") @@ -1617,15 +1900,22 @@ def save_module(pickler, obj): logger.trace(pickler, "# M1") else: builtin_mod = _is_builtin_module(obj) - if obj.__name__ not in ("builtins", "dill", "dill._dill") and not builtin_mod or \ - is_dill(pickler, child=True) and obj is pickler._main: + is_session_main = is_dill(pickler, child=True) and obj is pickler._main + if (obj.__name__ not 
in ("builtins", "dill", "dill._dill") and not builtin_mod + or is_session_main): logger.trace(pickler, "M1: %s", obj) - _main_dict = obj.__dict__.copy() #XXX: better no copy? option to copy? - [_main_dict.pop(item, None) for item in singletontypes - + ["__builtins__", "__loader__"]] + # Hack for handling module-type objects in load_module(). mod_name = obj.__name__ if _is_imported_module(obj) else '__runtime__.%s' % obj.__name__ - pickler.save_reduce(_import_module, (mod_name,), obj=obj, - state=_main_dict) + # Second references are saved as __builtin__.__main__ in save_module_dict(). + main_dict = obj.__dict__.copy() + for item in ('__builtins__', '__loader__'): + main_dict.pop(item, None) + for item in IPYTHON_SINGLETONS: #pragma: no cover + if getattr(main_dict.get(item), '__module__', '').startswith('IPython'): + del main_dict[item] + if is_session_main: + pickler._main_dict_copy = main_dict + pickler.save_reduce(_import_module, (mod_name,), obj=obj, state=main_dict) logger.trace(pickler, "# M1") elif obj.__name__ == "dill._dill": logger.trace(pickler, "M2: %s", obj) @@ -1635,7 +1925,6 @@ def save_module(pickler, obj): logger.trace(pickler, "M2: %s", obj) pickler.save_reduce(_import_module, (obj.__name__,), obj=obj) logger.trace(pickler, "# M2") - return return @register(TypeType) @@ -1661,7 +1950,7 @@ def save_type(pickler, obj, postproc_list=None): elif obj is type(None): logger.trace(pickler, "T7: %s", obj) #XXX: pickler.save_reduce(type, (None,), obj=obj) - pickler.write(bytes('c__builtin__\nNoneType\n', 'UTF-8')) + pickler.write(GLOBAL + b'__builtin__\nNoneType\n') logger.trace(pickler, "# T7") elif obj is NotImplementedType: logger.trace(pickler, "T7: %s", obj) @@ -1702,9 +1991,18 @@ def save_type(pickler, obj, postproc_list=None): else: logger.trace(pickler, "T4: %s", obj) if incorrectly_named: - warnings.warn('Cannot locate reference to %r.' % (obj,), PicklingWarning) + warnings.warn( + "Cannot locate reference to %r." 
% (obj,), + PicklingWarning, + stacklevel=3, + ) if obj_recursive: - warnings.warn('Cannot pickle %r: %s.%s has recursive self-references that trigger a RecursionError.' % (obj, obj.__module__, obj_name), PicklingWarning) + warnings.warn( + "Cannot pickle %r: %s.%s has recursive self-references that " + "trigger a RecursionError." % (obj, obj.__module__, obj_name), + PicklingWarning, + stacklevel=3, + ) #print (obj.__dict__) #print ("%s\n%s" % (type(obj), obj.__name__)) #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) @@ -1763,8 +2061,7 @@ def save_function(pickler, obj): logger.trace(pickler, "F1: %s", obj) _recurse = getattr(pickler, '_recurse', None) _postproc = getattr(pickler, '_postproc', None) - _main_modified = getattr(pickler, '_main_modified', None) - _original_main = getattr(pickler, '_original_main', __builtin__)#'None' + _original_main = getattr(pickler, '_original_main', None) postproc_list = [] if _recurse: # recurse to get all globals referred to by obj @@ -1781,8 +2078,8 @@ def save_function(pickler, obj): # If the globals is the __dict__ from the module being saved as a # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: - globs_copy = getattr(pickler, '_main', _original_main).__dict__ + if _original_main is not None and globs_copy is _original_main.__dict__: + globs_copy = pickler._main.__dict__ globs = globs_copy # If the globals is a module __dict__, do not save it in the pickle. 
elif globs_copy is not None and obj.__module__ is not None and \ @@ -1840,7 +2137,7 @@ def save_function(pickler, obj): # Change the value of the cell pickler.save_reduce(*possible_postproc) # pop None created by calling preprocessing step off stack - pickler.write(bytes('0', 'UTF-8')) + pickler.write(POP) logger.trace(pickler, "# F1") else: @@ -1949,7 +2246,7 @@ def pickles(obj,exact=False,safe=False,**kwds): """ if safe: exceptions = (Exception,) # RuntimeError, ValueError else: - exceptions = (TypeError, AssertionError, NotImplementedError, PicklingError, UnpicklingError) + exceptions = UNPICKLEABLE_ERRORS + (AssertionError, UnpicklingError) try: pik = copy(obj, **kwds) #FIXME: should check types match first, then check content if "exact" diff --git a/dill/_utils.py b/dill/_utils.py new file mode 100644 index 00000000..0aaf65a5 --- /dev/null +++ b/dill/_utils.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python +# +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +Auxiliary classes and functions used in more than one module, defined here to +avoid circular import problems. +""" + +from __future__ import annotations + +__all__ = [ + 'Filter', 'FilterFunction', 'FilterRules', 'FilterSet', 'NamedObject', + 'Rule', 'RuleType', 'size_filter', 'EXCLUDE', 'INCLUDE', +] + +import contextlib +import io +import math +import re +import warnings +from dataclasses import dataclass, field, fields +from collections import abc +from contextlib import suppress +from enum import Enum +from functools import partialmethod +from itertools import chain, filterfalse + +from dill import _dill # _dill is not completely loaded + +# Type hints. 
+from typing import ( + Any, Callable, Dict, Iterable, Iterator, + Optional, Pattern, Set, Tuple, Union, +) + +def _format_bytes_size(size: Union[int, float]) -> Tuple[int, str]: + """Return bytes size text representation in human-redable form.""" + unit = "B" + power_of_2 = math.trunc(size).bit_length() - 1 + magnitude = min(power_of_2 - power_of_2 % 10, 80) # 2**80 == 1 YiB + if magnitude: + # Rounding trick: 1535 (1024 + 511) -> 1K; 1536 -> 2K + size = ((size >> magnitude-1) + 1) >> 1 + unit = "%siB" % "KMGTPEZY"[(magnitude // 10) - 1] + return size, unit + + +## File-related utilities ## + +class _PeekableReader(contextlib.AbstractContextManager): + """lightweight readable stream wrapper that implements peek()""" + def __init__(self, stream, closing=True): + self.stream = stream + self.closing = closing + def __exit__(self, *exc_info): + if self.closing: + self.stream.close() + def read(self, n): + return self.stream.read(n) + def readline(self): + return self.stream.readline() + def tell(self): + return self.stream.tell() + def close(self): + return self.stream.close() + def peek(self, n): + stream = self.stream + try: + if hasattr(stream, 'flush'): + stream.flush() + position = stream.tell() + stream.seek(position) # assert seek() works before reading + chunk = stream.read(n) + stream.seek(position) + return chunk + except (AttributeError, OSError): + raise NotImplementedError("stream is not peekable: %r", stream) from None + +class _SeekableWriter(io.BytesIO, contextlib.AbstractContextManager): + """works as an unlimited buffer, writes to file on close""" + def __init__(self, stream, closing=True, *args, **kwds): + super().__init__(*args, **kwds) + self.stream = stream + self.closing = closing + def __exit__(self, *exc_info): + self.close() + def close(self): + self.stream.write(self.getvalue()) + with suppress(AttributeError): + self.stream.flush() + super().close() + if self.closing: + self.stream.close() + +def _open(file, mode, *, peekable=False, 
seekable=False): + """return a context manager with an opened file-like object""" + readonly = ('r' in mode and '+' not in mode) + if not readonly and peekable: + raise ValueError("the 'peekable' option is invalid for writable files") + if readonly and seekable: + raise ValueError("the 'seekable' option is invalid for read-only files") + should_close = not hasattr(file, 'read' if readonly else 'write') + if should_close: + file = open(file, mode) + # Wrap stream in a helper class if necessary. + if peekable and not hasattr(file, 'peek'): + # Try our best to return it as an object with a peek() method. + if hasattr(file, 'seekable'): + file_seekable = file.seekable() + elif hasattr(file, 'seek') and hasattr(file, 'tell'): + try: + file.seek(file.tell()) + file_seekable = True + except Exception: + file_seekable = False + else: + file_seekable = False + if file_seekable: + file = _PeekableReader(file, closing=should_close) + else: + try: + file = io.BufferedReader(file) + except Exception: + # It won't be peekable, but will fail gracefully in _identify_module(). 
+ file = _PeekableReader(file, closing=should_close) + elif seekable and ( + not hasattr(file, 'seek') + or not hasattr(file, 'truncate') + or (hasattr(file, 'seekable') and not file.seekable()) + ): + file = _SeekableWriter(file, closing=should_close) + if should_close or isinstance(file, (_PeekableReader, _SeekableWriter)): + return file + else: + return contextlib.nullcontext(file) + + +## Namespace filtering ## + +RuleType = Enum('RuleType', 'EXCLUDE INCLUDE', module=__name__) +EXCLUDE, INCLUDE = RuleType.EXCLUDE, RuleType.INCLUDE + +class NamedObject: + """Simple container for a variable's name and value used by filter functions.""" + __slots__ = 'name', 'value' + name: str + value: Any + def __init__(self, name_value: Tuple[str, Any]): + self.name, self.value = name_value + def __eq__(self, other: Any) -> bool: + """ + Prevent simple bugs from writing `lambda obj: obj == 'literal'` instead + of `lambda obj: obj.value == 'literal' in a filter definition.` + """ + if type(other) is not NamedObject: + raise TypeError("'==' not supported between instances of 'NamedObject' and %r" % + type(other).__name__) + return self.value is other.value and self.name == other.name + def __repr__(self): + return "NamedObject(%r, %r)" % (self.name, self.value) + +FilterFunction = Callable[[NamedObject], bool] +Filter = Union[str, Pattern[str], int, type, FilterFunction] +Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]] + +def _iter(obj): + """return iterator of object if it's not a string""" + if isinstance(obj, (str, bytes)): + return None + try: + return iter(obj) + except TypeError: + return None + +@dataclass +class FilterSet(abc.MutableSet): + """A superset of exclusion/inclusion filter sets.""" + names: Set[str] = field(default_factory=set) + regexes: Set[Pattern[str]] = field(default_factory=set) + ids: Set[int] = field(default_factory=set) + types: Set[type] = field(default_factory=set) + funcs: Set[FilterFunction] = field(default_factory=set) + + # Initialized 
later. + _fields = None + _rtypemap = None + + def _match_type(self, filter: Filter) -> Tuple[filter, str]: + """identify the filter's type and convert it to standard internal format""" + filter_type = type(filter) + if filter_type is str: + if filter.isidentifier(): + field = 'names' + elif filter.startswith('type:'): + filter = self.get_type(filter.partition(':')[-1].strip()) + field = 'types' + else: + filter = re.compile(filter) + field = 'regexes' + elif filter_type is re.Pattern and type(filter.pattern) is str: + field = 'regexes' + elif filter_type is int: + field = 'ids' + elif isinstance(filter, type): + field = 'types' + elif callable(filter): + field = 'funcs' + else: + raise ValueError("invalid filter: %r" % filter) + return filter, getattr(self, field) + + # Mandatory MutableSet methods. + @classmethod + def _from_iterable(cls, it: Iterable[Filter]) -> FilterSet: + obj = cls() + obj |= it + return obj + def __bool__(self) -> bool: + return any(getattr(self, field) for field in self._fields) + def __len__(self) -> int: + return sum(len(getattr(self, field)) for field in self._fields) + def __contains__(self, filter: Filter) -> bool: + filter, filter_set = self._match_type(filter) + return filter in filter_set + def __iter__(self) -> Iterator[Filter]: + return chain.from_iterable(getattr(self, field) for field in self._fields) + def add(self, filter: Filter) -> None: + filter, filter_set = self._match_type(filter) + filter_set.add(filter) + def discard(self, filter: Filter) -> None: + filter, filter_set = self._match_type(filter) + filter_set.discard(filter) + + # Overwrite generic methods (optimization). 
+ def remove(self, filter: Filter) -> None: + filter, filter_set = self._match_type(filter) + filter_set.remove(filter) + def clear(self) -> None: + for field in self._fields: + getattr(self, field).clear() + def __or__(self, other: Iterable[Filter]) -> FilterSet: + obj = self.copy() + obj |= other + return obj + __ror__ = __or__ + def __ior__(self, other: Iterable[Filter]) -> FilterSet: + if not isinstance(other, Iterable): + return NotImplemented + if isinstance(other, FilterSet): + for field in self._fields: + getattr(self, field).update(getattr(other, field)) + else: + for filter in other: + self.add(filter) + return self + + # Extra methods. + def update(self, filters: Iterable[Filter]) -> None: + self |= filters + def copy(self) -> FilterSet: + return FilterSet(*(getattr(self, field).copy() for field in self._fields)) + + # Convert type name to type. + TYPENAME_REGEX = re.compile(r'\w+(?=Type$)|\w+$', re.IGNORECASE) + @classmethod + def _get_typekey(cls, typename: str) -> str: + return cls.TYPENAME_REGEX.match(typename).group().lower() + @classmethod + def get_type(cls, typename: str) -> type: + """retrieve a type registered in ``dill``'s "reverse typemap"'""" + if cls._rtypemap is None: + cls._rtypemap = {cls._get_typekey(k): v for k, v in _dill._reverse_typemap.items()} + return cls._rtypemap[cls._get_typekey(typename)] + +FilterSet._fields = tuple(field.name for field in fields(FilterSet)) + +class _FilterSetDescriptor: + """descriptor for FilterSet members of FilterRules""" + def __set_name__(self, owner, name): + self.name = name + self._name = '_' + name + def __set__(self, obj, value): + # This is the important method. 
+ if isinstance(value, FilterSet): + setattr(obj, self._name, value) + else: + setattr(obj, self._name, FilterSet._from_iterable(value)) + def __get__(self, obj, objtype=None): + try: + return getattr(obj, self._name) + except AttributeError: + raise AttributeError(self.name) from None + def __delete__(self, obj): + try: + delattr(obj, self._name) + except AttributeError: + raise AttributeError(self.name) from None + +class FilterRules: + """Exclusion and inclusion rules for filtering a namespace. + + Namespace filtering rules can be of two types, :const:`EXCLUDE` and + :const:`INCLUDE` rules, and of five "flavors": + + - `name`: match a variable name exactly; + - `regex`: match a variable name by regular expression; + - `id`: match a variable value by id; + - `type`: match a variable value by type (using ``isinstance``); + - `func`: callable filter, match a variable name and/or value by an + arbitrary logic. + + A `name` filter is specified by a simple string, e.g. ``'some_var'``. If its + value is not a valid Python identifier, except for the special `type` case + below, it is treated as a regular expression instead. + + A `regex` filter is specified either by a string containing a regular + expression, e.g. ``r'\w+_\d+'``, or by a :py:class:`re.Pattern` object. + + An `id` filter is specified by an ``int`` that corresponds to the id of an + object. For example, to exclude a specific object ``obj`` that may be + assigned to multiple variables, just use ``id(obj)`` as an `id` filter. + + A `type` filter is specified by a type-object, e.g. ``list`` or + ``type(some_var)``, or by a string with the format ``"type:<typename>"``, + where ``<typename>`` is a type name (case insensitive) known by ``dill`` , + e.g. ``"type:function"`` or ``"type: FunctionType"``. These include all + the types defined in the module :py:mod:`types` and many more. 
+ + Finally, a `func` filter can be any callable that accepts a single argument and + returns a boolean value, being it `True` if the object should be excluded + (or included, depending on how the filter is used) or `False` if it should + *not* be excluded (or included). + + The single argument passed to `func` filters is an instance of + :py:class:`NamedObject`, an object with two attributes: ``name`` is the + variable's name in the namespace and ``value`` is the object that it refers + to. Below are some examples of `func` filters. + + A strict type filter, exclude ``int`` but not ``bool`` (an ``int`` subclass): + + >>> int_filter = lambda obj: type(obj) == int + + Exclude objects that were renamed after definition: + + >>> renamed_filter = lambda obj: obj.name != getattr(obj.value, '__name__', obj.name) + + Filters may be added interactively after creating an empty ``FilterRules`` + object: + + >>> from dill.session import FilterRules + >>> filters = FilterRules() + >>> filters.exclude.add('some_var') + >>> filters.exclude.add(r'__\w+') + >>> filters.include.add(r'__\w+__') # keep __dunder__ variables + + Or may be created all at once at initialization with "filter rule literals": + + >>> from dill.session import FilterRules, EXCLUDE, INCLUDE + >>> filters = FilterRules([ + ... (EXCLUDE, ['some_var', r'__\+']), + ... (INCLUDE, r'__\w+__'), + ... ]) + + The order that the exclusion and inclusion filters are added is irrelevant + because **exclusion filters are always applied first**. Therefore, + generally the rules work as a blocklist, with inclusion rules acting as + exceptions to the exclusion rules. However, **if there are only inclusion + filters, the rules work as an allowlist** instead, and only the variables + matched by the inclusion filters are kept. 
+ """ + __slots__ = '_exclude', '_include', '__weakref__' + exclude = _FilterSetDescriptor() + include = _FilterSetDescriptor() + + def __init__(self, rules: Union[Iterable[Rule], FilterRules] = None): + self._exclude = FilterSet() + self._include = FilterSet() + if rules is not None: + self.update(rules) + + def __repr__(self) -> str: + """Compact representation of FilterSet.""" + COLUMNS = 78 + desc = [" COLUMNS: + set_desc = ["FilterSet("] + re.findall(r'\w+={.+?}', set_desc) + set_desc = ",\n ".join(set_desc) + "\n )" + set_desc = "%s=%s" % (attr, set_desc) + if attr == 'exclude' and hasattr(self, 'include'): + set_desc += ',' + desc.append(set_desc) + if len(desc) == 1: + desc += ["NOT SET"] + sep = "\n " if sum(len(x) for x in desc) > COLUMNS else " " + return sep.join(desc) + ">" + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FilterRules): + return NotImplemented + MISSING = object() + self_exclude = getattr(self, 'exclude', MISSING) + self_include = getattr(self, 'include', MISSING) + other_exclude = getattr(other, 'exclude', MISSING) + other_include = getattr(other, 'include', MISSING) + return self_exclude == other_exclude and self_include == other_include + + # Proxy add(), discard(), remove() and clear() to FilterSets. 
+ def __proxy__(self, + method: str, + filter: Filter, + *, + rule_type: RuleType = RuleType.EXCLUDE, + ) -> None: + """Call 'method' over FilterSet specified by 'rule_type'.""" + if not isinstance(rule_type, RuleType): + raise ValueError("invalid rule type: %r (must be one of %r)" % (rule_type, list(RuleType))) + filter_set = getattr(self, rule_type.name.lower()) + getattr(filter_set, method)(filter) + add = partialmethod(__proxy__, 'add') + discard = partialmethod(__proxy__, 'discard') + remove = partialmethod(__proxy__, 'remove') + def clear(self) -> None: + self.exclude.clear() + self.include.clear() + + def update(self, rules: Union[Iterable[Rule], FilterRules]) -> None: + """Update both FilterSets from a list of (RuleType, Filter) rules.""" + if isinstance(rules, FilterRules): + for field in FilterSet._fields: + getattr(self.exclude, field).update(getattr(rules.exclude, field)) + getattr(self.include, field).update(getattr(rules.include, field)) + else: + for rule in rules: + # Validate rules. + if not isinstance(rule, tuple) or len(rule) != 2: + raise ValueError("invalid rule format: %r" % rule) + for rule_type, filter in rules: + filters = _iter(filter) + if filters is not None: + for f in filters: + self.add(f, rule_type=rule_type) + else: + self.add(filter, rule_type=rule_type) + + def _apply_filters(self, + filter_set: FilterSet, + objects: Iterable[NamedObject] + ) -> Iterator[NamedObject]: + filters = [] + types_list = tuple(filter_set.types) + # Apply broader/cheaper filters first. 
+ if types_list: + filters.append(lambda obj: isinstance(obj.value, types_list)) + if filter_set.ids: + filters.append(lambda obj: id(obj.value) in filter_set.ids) + if filter_set.names: + filters.append(lambda obj: obj.name in filter_set.names) + if filter_set.regexes: + filters.append(lambda obj: any(regex.fullmatch(obj.name) for regex in filter_set.regexes)) + filters.extend(filter_set.funcs) + for filter in filters: + objects = filterfalse(filter, objects) + return objects + + def apply_filters(self, namespace: Dict[str, Any]) -> Dict[str, Any]: + """Apply filters to dictionary with names as keys.""" + if not namespace or not (self.exclude or self.include): + return namespace + # Protect agains dict changes during the call. + namespace_copy = namespace.copy() + all_objs = [NamedObject(item) for item in namespace_copy.items()] + + if self.exclude: + include_names = {obj.name for obj in self._apply_filters(self.exclude, all_objs)} + exclude_objs = [obj for obj in all_objs if obj.name not in include_names] + else: + # Treat this rule set as an allowlist. + exclude_objs = all_objs + if self.include and exclude_objs: + exclude_objs = list(self._apply_filters(self.include, exclude_objs)) + + if not exclude_objs: + return namespace + if len(exclude_objs) == len(namespace): + warnings.warn( + "the exclusion/inclusion rules applied have excluded all %d items" % len(all_objs), + _dill.PicklingWarning, + stacklevel=2 + ) + return {} + for obj in exclude_objs: + del namespace_copy[obj.name] + return namespace_copy + + +## Filter factories ## + +import collections +import collections.abc +import random +from statistics import mean +from sys import getsizeof +from types import ModuleType + +class size_filter: + """Create a filter function with a limit for estimated object size. + + Parameters: + limit: maximum size allowed in bytes. May be an absolute number of bytes + as an ``int`` or ``float``, or a string representing a size in bytes, + e.g. 
``1000``, ``10e3``, ``"1000"``, ``"1k"`` and ``"1 KiB"`` are all + valid and roughly equivalent (the last one represents 1024 bytes). + recursive: if `False`, the function won't recurse into the object's + attributes and items to estimate its size. + + Returns: + A callable filter to be used with :py:func:`dump_module`. + + Note: + Doesn't work on PyPy. See ``help(sys.getsizeof)``. + """ + # Cover "true" collections from 'builtins', 'collections' and 'collections.abc'. + COLLECTION_TYPES = ( + list, + tuple, + collections.deque, + collections.UserList, + collections.abc.Mapping, # dict, OrderedDict, UserDict, etc. + collections.abc.Set, # set, frozenset + ) + MINIMUM_SIZE = getsizeof(None, 16) + MISSING_SLOT = object() + + def __init__(self, + limit: Union[int, float, str], + recursive: bool = True, + ) -> FilterFunction: + if _dill.IS_PYPY: + raise NotImplementedError("size_filter() is not implemented for PyPy") + self.limit = limit + if type(limit) != int: + try: + self.limit = float(limit) + except (TypeError, ValueError): + limit_match = re.fullmatch(r'(\d+)\s*(B|[KMGT]i?B?)', limit, re.IGNORECASE) + if limit_match: + coeff, unit = limit_match.groups() + coeff, unit = int(coeff), unit.lower() + if unit == 'b': + self.limit = coeff + else: + base = 1024 if unit[1:2] == 'i' else 1000 + exponent = 'kmgt'.index(unit[0]) + 1 + self.limit = coeff * base**exponent + else: + # Will raise error for Inf and NaN. + self.limit = math.trunc(self.limit) + if type(self.limit) != int: + # Everything failed. 
+ raise ValueError("invalid 'limit' value: %r" % limit) + elif self.limit < 0: + raise ValueError("'limit' can't be negative %r" % limit) + self.recursive = recursive + + def __call__(self, obj: NamedObject) -> bool: + if self.recursive: + size = self.estimate_size(obj.value) + else: + try: + size = getsizeof(obj.value) + except ReferenceError: + size = self.MINIMUM_SIZE + return size > self.limit + + def __repr__(self): + return "size_filter(limit=%r, recursive=%r)" % ( + "%d %s" % _format_bytes_size(self.limit), + self.recursive, + ) + + @classmethod + def estimate_size(cls, obj: Any, memo: Optional[set] = None) -> int: + if memo is None: + memo = set() + obj_id = id(obj) + if obj_id in memo: + # Object size already counted. + return 0 + memo.add(obj_id) + size = cls.MINIMUM_SIZE + try: + if isinstance(obj, ModuleType) and _dill._is_builtin_module(obj): + # Always saved by reference. + return cls.MINIMUM_SIZE + size = getsizeof(obj) + if hasattr(obj, '__dict__'): + size += cls.estimate_size(obj.__dict__, memo) + if hasattr(obj, '__slots__'): + slots = (getattr(obj, x, cls.MISSING_SLOT) for x in obj.__slots__ if x != '__dict__') + size += sum(cls.estimate_size(x, memo) for x in slots if x is not cls.MISSING_SLOT) + if ( + isinstance(obj, str) # common case shortcut + or not isinstance(obj, collections.abc.Collection) # general, single test + or not isinstance(obj, cls.COLLECTION_TYPES) # specific, multiple tests + ): + return size + if isinstance(obj, collections.ChainMap): # collections.Mapping subtype + size += sum(cls.estimate_size(mapping, memo) for mapping in obj.maps) + elif len(obj) < 1000: + if isinstance(obj, collections.abc.Mapping): + size += sum(cls.estimate_size(k, memo) + cls.estimate_size(v, memo) + for k, v in obj.items()) + else: + size += sum(cls.estimate_size(item, memo) for item in obj) + else: + # Use random sample for large collections. 
+ sample = set(random.sample(range(len(obj)), k=100)) + if isinstance(obj, collections.abc.Mapping): + samples_sizes = (cls.estimate_size(k, memo) + cls.estimate_size(v, memo) + for i, (k, v) in enumerate(obj.items()) if i in sample) + else: + samples_sizes = (cls.estimate_size(item, memo) + for i, item in enumerate(obj) if i in sample) + size += len(obj) * mean(samples_sizes) + except Exception: + pass + return size diff --git a/dill/detect.py b/dill/detect.py index b6a6cb76..e6149d15 100644 --- a/dill/detect.py +++ b/dill/detect.py @@ -13,7 +13,7 @@ from inspect import ismethod, isfunction, istraceback, isframe, iscode from .pointers import parent, reference, at, parents, children -from .logger import trace +from .logging import trace __all__ = ['baditems','badobjects','badtypes','code','errors','freevars', 'getmodule','globalvars','nestedcode','nestedglobals','outermost', diff --git a/dill/logger.py b/dill/logging.py similarity index 65% rename from dill/logger.py rename to dill/logging.py index be557a5e..92386e0c 100644 --- a/dill/logger.py +++ b/dill/logging.py @@ -11,37 +11,45 @@ The 'logger' object is dill's top-level logger. The 'adapter' object wraps the logger and implements a 'trace()' method that -generates a detailed tree-style trace for the pickling call at log level INFO. +generates a detailed tree-style trace for the pickling call at log level +:const:`dill.logging.TRACE`, which has an intermediary value between +:const:`logging.INFO` and :const:`logging.DEBUG`. The 'trace()' function sets and resets dill's logger log level, enabling and disabling the pickling trace. The trace shows a tree structure depicting the depth of each object serialized *with dill save functions*, but not the ones that use save functions from -'pickle._Pickler.dispatch'. If the information is available, it also displays 
If the information is available, it also displays the size in bytes that the object contributed to the pickle stream (including its child objects). Sample trace output: - >>> import dill, dill.tests - >>> dill.detect.trace(True) - >>> dill.dump_session(main=dill.tests) - ┬ M1: - ├┬ F2: + >>> import dill + >>> import keyword + >>> with dill.detect.trace(): + ... dill.dump_module(module=keyword) + ┬ M1: + ├┬ F2: │└ # F2 [32 B] - ├┬ D2: + ├┬ D5: │├┬ T4: ││└ # T4 [35 B] - │├┬ D2: + │├┬ D2: ││├┬ T4: │││└ # T4 [50 B] - ││├┬ D2: - │││└ # D2 [84 B] - ││└ # D2 [413 B] - │└ # D2 [763 B] - └ # M1 [813 B] + ││├┬ D2: + │││└ # D2 [47 B] + ││└ # D2 [280 B] + │└ # D5 [1 KiB] + └ # M1 [1 KiB] """ -__all__ = ['adapter', 'logger', 'trace'] +from __future__ import annotations + +__all__ = [ + 'adapter', 'logger', 'trace', 'getLogger', + 'CRITICAL', 'ERROR', 'WARNING', 'INFO', 'TRACE', 'DEBUG', 'NOTSET', +] import codecs import contextlib @@ -49,10 +57,21 @@ import logging import math import os +from contextlib import suppress +from logging import getLogger, CRITICAL, ERROR, WARNING, INFO, DEBUG, NOTSET from functools import partial -from typing import TextIO, Union +from typing import Optional, TextIO, Union import dill +from ._utils import _format_bytes_size + +# Intermediary logging level for tracing. +TRACE = (INFO + DEBUG) // 2 + +_nameOrBoolToLevel = logging._nameToLevel.copy() +_nameOrBoolToLevel['TRACE'] = TRACE +_nameOrBoolToLevel[False] = WARNING +_nameOrBoolToLevel[True] = TRACE # Tree drawing characters: Unicode to ASCII map. ASCII_MAP = str.maketrans({"│": "|", "├": "|", "┬": "+", "└": "`"}) @@ -105,13 +124,24 @@ class TraceAdapter(logging.LoggerAdapter): creates extra values to be added in the LogRecord from it, then calls 'info()'. - Usage of logger with 'trace()' method: + Examples: - >>> from dill.logger import adapter as logger #NOTE: not dill.logger.logger - >>> ... - >>> def save_atype(pickler, obj): - >>> logger.trace(pickler, "Message with %s and %r etc. 
placeholders", 'text', obj) - >>> ... + In the first call to `trace()`, before pickling an object, it must be passed + to `trace()` as the last positional argument or as the keyword argument + `obj`. Note how, in the second example, the object is not passed as a + positional argument, and therefore won't be substituted in the message: + + >>> from dill.logger import adapter as logger #NOTE: not dill.logger.logger + >>> ... + >>> def save_atype(pickler, obj): + >>> logger.trace(pickler, "X: Message with %s and %r placeholders", 'text', obj) + >>> ... + >>> logger.trace(pickler, "# X") + >>> def save_weakproxy(pickler, obj) + >>> trace_message = "W: This works even with a broken weakproxy: %r" % obj + >>> logger.trace(pickler, trace_message, obj=obj) + >>> ... + >>> logger.trace(pickler, "# W") """ def __init__(self, logger): self.logger = logger @@ -128,44 +158,57 @@ def trace_setup(self, pickler): # Called by Pickler.dump(). if not dill._dill.is_dill(pickler, child=False): return - if self.isEnabledFor(logging.INFO): - pickler._trace_depth = 1 + elif self.isEnabledFor(TRACE): + pickler._trace_stack = [] pickler._size_stack = [] else: - pickler._trace_depth = None - def trace(self, pickler, msg, *args, **kwargs): - if not hasattr(pickler, '_trace_depth'): + pickler._trace_stack = None + def trace(self, pickler, msg, *args, obj=None, **kwargs): + if not hasattr(pickler, '_trace_stack'): logger.info(msg, *args, **kwargs) return - if pickler._trace_depth is None: + elif pickler._trace_stack is None: return extra = kwargs.get('extra', {}) pushed_obj = msg.startswith('#') + if not pushed_obj: + if obj is None and (not args or type(args[-1]) is str): + raise TypeError( + "the pickled object must be passed as the last positional " + "argument (being substituted in the message) or as the " + "'obj' keyword argument." 
+ ) + if obj is None: + obj = args[-1] + pickler._trace_stack.append(id(obj)) size = None - try: + with suppress(AttributeError, TypeError): # Streams are not required to be tellable. - size = pickler._file.tell() + size = pickler._file_tell() frame = pickler.framer.current_frame try: size += frame.tell() except AttributeError: # PyPy may use a BytesBuilder as frame size += len(frame) - except (AttributeError, TypeError): - pass if size is not None: if not pushed_obj: pickler._size_stack.append(size) + if len(pickler._size_stack) == 3: # module > dict > variable + with suppress(AttributeError, KeyError): + extra['varname'] = pickler._id_to_name.pop(id(obj)) else: size -= pickler._size_stack.pop() extra['size'] = size - if pushed_obj: - pickler._trace_depth -= 1 - extra['depth'] = pickler._trace_depth + extra['depth'] = len(pickler._trace_stack) kwargs['extra'] = extra self.info(msg, *args, **kwargs) - if not pushed_obj: - pickler._trace_depth += 1 + if pushed_obj: + pickler._trace_stack.pop() + def roll_back(self, pickler, obj): + if pickler._trace_stack and id(obj) == pickler._trace_stack[-1]: + pickler._trace_stack.pop() + pickler._size_stack.pop() class TraceFormatter(logging.Formatter): """ @@ -200,24 +243,26 @@ def format(self, record): if not self.is_utf8: prefix = prefix.translate(ASCII_MAP) + "-" fields['prefix'] = prefix + " " - if hasattr(record, 'size'): - # Show object size in human-redable form. 
- power = int(math.log(record.size, 2)) // 10 - size = record.size >> power*10 - fields['suffix'] = " [%d %sB]" % (size, "KMGTP"[power] + "i" if power else "") + if hasattr(record, 'varname'): + fields['suffix'] = " as %r" % record.varname + elif hasattr(record, 'size'): + fields['suffix'] = " [%d %s]" % _format_bytes_size(record.size) vars(record).update(fields) return super().format(record) -logger = logging.getLogger('dill') +logger = getLogger('dill') logger.propagate = False adapter = TraceAdapter(logger) stderr_handler = logging._StderrHandler() adapter.addHandler(stderr_handler) -def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> None: +def trace( + arg: Union[bool, str, TextIO, os.PathLike] = None, *, mode: str = 'a' + ) -> Optional[TraceManager]: """print a trace through the stack when pickling; useful for debugging - With a single boolean argument, enable or disable the tracing. + With a single boolean argument, enable or disable the tracing. Or, with a + logging level name (not ``int``), set the logging level of the dill logger. Example usage: @@ -227,10 +272,10 @@ def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') Alternatively, ``trace()`` can be used as a context manager. With no arguments, it just takes care of restoring the tracing state on exit. - Either a file handle, or a file name and (optionally) a file mode may be - specitfied to redirect the tracing output in the ``with`` block context. A - log function is yielded by the manager so the user can write extra - information to the file. + Either a file handle, or a file name and a file mode (optional) may be + specified to redirect the tracing output in the ``with`` block. A ``log()`` + function is yielded by the manager so the user can write extra information + to the file. 
Example usage: @@ -249,13 +294,18 @@ def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') >>> log("> squared = %r", squared) >>> dumps(squared) - Arguments: - arg: a boolean value, or an optional file-like or path-like object for the context manager - mode: mode string for ``open()`` if a file name is passed as the first argument + Parameters: + arg: a boolean value, the name of a logging level (including "TRACE") + or an optional file-like or path-like object for the context manager + mode: mode string for ``open()`` if a file name is passed as the first + argument """ - if not isinstance(arg, bool): + level = _nameOrBoolToLevel.get(arg) if isinstance(arg, (bool, str)) else None + if level is not None: + logger.setLevel(level) + return + else: return TraceManager(file=arg, mode=mode) - logger.setLevel(logging.INFO if arg else logging.WARNING) class TraceManager(contextlib.AbstractContextManager): """context manager version of trace(); can redirect the trace to a file""" @@ -274,7 +324,7 @@ def __enter__(self): adapter.removeHandler(stderr_handler) adapter.addHandler(self.handler) self.old_level = adapter.getEffectiveLevel() - adapter.setLevel(logging.INFO) + adapter.setLevel(TRACE) return adapter.info def __exit__(self, *exc_info): adapter.setLevel(self.old_level) diff --git a/dill/session.py b/dill/session.py index 6acdd432..9fc4ef56 100644 --- a/dill/session.py +++ b/dill/session.py @@ -7,91 +7,111 @@ # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE """ -Pickle and restore the intepreter session. +Pickle and restore the intepreter session or a module's state. + +The functions :func:`dump_module`, :func:`load_module` and +:func:`load_module_asdict` are capable of saving and restoring, as long as +objects are pickleable, the complete state of a module. For imported modules +that are pickled, `dill` requires them to be importable at unpickling. 
+ +Options like ``dill.settings['byref']`` and ``dill.settings['recurse']`` don't +affect the behavior of :func:`dump_module`. However, if a module has variables +referring to objects from other modules that would prevent it from pickling or +drastically increase its disk size, using the option ``refimported`` forces them +to be saved by reference instead of by value. + +Also with :func:`dump_module`, namespace filters may be used to restrict the +list of pickled variables to a subset of those in the module, based on their +names and values. + +In turn, :func:`load_module_asdict` allows one to load the variables from +different saved states of the same module into dictionaries. + +Using :func:`dill.detect.trace` enables the complete pickling trace of a +module. Alternatively, ``dill.detect.trace('INFO')`` enables only the messages +about variables excluded by filtering or variables saved by reference (by +effect of the `refimported` or the `refonfail` option) in the pickled module's +namespace. + +Note: + Contrary to using :func:`dill.dump` and :func:`dill.load` to save and load + a module object, :func:`dill.dump_module` always tries to pickle the module + by value (including built-in modules). Modules saved with :func:`dill.dump` + can't be loaded with :func:`dill.load_module`. 
""" +from __future__ import annotations + __all__ = [ - 'dump_module', 'load_module', 'load_module_asdict', + 'dump_module', 'load_module', 'load_module_asdict', 'is_pickled_module', + 'ModuleFilters', 'NamedObject', 'FilterRules', 'FilterSet', 'size_filter', + 'ipython_filter', 'dump_session', 'load_session' # backward compatibility ] +import pprint import re import sys import warnings -from dill import _dill, Pickler, Unpickler +from dill import _dill, logging +from dill import Pickler, Unpickler, UnpicklingError from ._dill import ( BuiltinMethodType, FunctionType, MethodType, ModuleType, TypeType, - _import_module, _is_builtin_module, _is_imported_module, _main_module, - _reverse_typemap, __builtin__, + _getopt, _import_module, _is_builtin_module, _is_imported_module, + _lookup_module, _main_module, _module_map, _reverse_typemap, __builtin__, ) +from ._utils import FilterRules, FilterSet, _open, size_filter, EXCLUDE, INCLUDE + +logger = logging.getLogger(__name__) # Type hints. -from typing import Optional, Union +from typing import Any, Dict, Iterable, Optional, Union +from ._utils import Filter, FilterFunction, NamedObject, Rule, RuleType import pathlib import tempfile TEMPDIR = pathlib.PurePath(tempfile.gettempdir()) -def _module_map(): - """get map of imported modules""" - from collections import defaultdict - from types import SimpleNamespace - modmap = SimpleNamespace( - by_name=defaultdict(list), - by_id=defaultdict(list), - top_level={}, - ) - for modname, module in sys.modules.items(): - if modname in ('__main__', '__mp_main__') or not isinstance(module, ModuleType): - continue - if '.' not in modname: - modmap.top_level[id(module)] = modname - for objname, modobj in module.__dict__.items(): - modmap.by_name[objname].append((modobj, modname)) - modmap.by_id[id(modobj)].append((modobj, objname, modname)) - return modmap - +# Unique objects (with no duplicates) that may be imported with "import as". 
IMPORTED_AS_TYPES = (ModuleType, TypeType, FunctionType, MethodType, BuiltinMethodType) if 'PyCapsuleType' in _reverse_typemap: IMPORTED_AS_TYPES += (_reverse_typemap['PyCapsuleType'],) -IMPORTED_AS_MODULES = ('ctypes', 'typing', 'subprocess', 'threading', - r'concurrent\.futures(\.\w+)?', r'multiprocessing(\.\w+)?') -IMPORTED_AS_MODULES = tuple(re.compile(x) for x in IMPORTED_AS_MODULES) - -def _lookup_module(modmap, name, obj, main_module): - """lookup name or id of obj if module is imported""" - for modobj, modname in modmap.by_name[name]: - if modobj is obj and sys.modules[modname] is not main_module: - return modname, name - __module__ = getattr(obj, '__module__', None) - if isinstance(obj, IMPORTED_AS_TYPES) or (__module__ is not None - and any(regex.fullmatch(__module__) for regex in IMPORTED_AS_MODULES)): - for modobj, objname, modname in modmap.by_id[id(obj)]: - if sys.modules[modname] is not main_module: - return modname, objname - return None, None - -def _stash_modules(main_module): - modmap = _module_map() - newmod = ModuleType(main_module.__name__) +# For unique objects of various types that have a '__module__' attribute. +IMPORTED_AS_MODULES = [re.compile(x) for x in ( + 'ctypes', 'typing', 'subprocess', 'threading', + r'concurrent\.futures(\.\w+)?', r'multiprocessing(\.\w+)?' 
+)] + +BUILTIN_CONSTANTS = (None, False, True, NotImplemented) + +def _stash_modules(main_module, original_main): + """pop imported variables to be saved by reference in the __dill_imported* attributes""" + modmap = _module_map(original_main) + newmod = ModuleType(main_module.__name__) + original = {} imported = [] imported_as = [] imported_top_level = [] # keep separated for backward compatibility - original = {} + for name, obj in main_module.__dict__.items(): - if obj is main_module: - original[name] = newmod # self-reference - elif obj is main_module.__dict__: - original[name] = newmod.__dict__ - # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). - elif any(obj is singleton for singleton in (None, False, True)) \ - or isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + # Avoid incorrectly matching a singleton value in another package (e.g. __doc__ == None). + if (any(obj is constant for constant in BUILTIN_CONSTANTS) # must compare by identity + or type(obj) is str and obj == '' # internalized, for cases like: __package__ == '' + or type(obj) is int and -128 <= obj <= 256 # possibly cached by compiler/interpreter + or isinstance(obj, ModuleType) and _is_builtin_module(obj) # always saved by ref + or obj is main_module or obj is main_module.__dict__): original[name] = obj else: - source_module, objname = _lookup_module(modmap, name, obj, main_module) + modname = getattr(obj, '__module__', None) + lookup_by_id = ( + isinstance(obj, IMPORTED_AS_TYPES) + or modname is not None + and any(regex.fullmatch(modname) for regex in IMPORTED_AS_MODULES) + ) + source_module, objname, _ = _lookup_module(modmap, name, obj, lookup_by_id) if source_module is not None: if objname == name: imported.append((source_module, name)) @@ -108,51 +128,135 @@ def _stash_modules(main_module): newmod.__dill_imported = imported newmod.__dill_imported_as = imported_as newmod.__dill_imported_top_level = imported_top_level - if 
getattr(newmod, '__loader__', None) is None and _is_imported_module(main_module): - # Trick _is_imported_module() to force saving as an imported module. - newmod.__loader__ = True # will be discarded by save_module() - return newmod + _discard_added_variables(newmod, main_module.__dict__) + + if logger.isEnabledFor(logging.INFO): + refimported = [(name, "%s.%s" % (mod, name)) for mod, name in imported] + refimported += [(name, "%s.%s" % (mod, objname)) for mod, objname, name in imported_as] + refimported += [(name, mod) for mod, name in imported_top_level] + message = "[dump_module] Variables saved by reference (refimported):\n" + logger.info(message + _format_log_dict(dict(refimported))) + logger.debug("main namespace after _stash_modules(): %s", dir(newmod)) + + return newmod, modmap else: - return main_module + return main_module, modmap def _restore_modules(unpickler, main_module): - try: - for modname, name in main_module.__dict__.pop('__dill_imported'): - main_module.__dict__[name] = unpickler.find_class(modname, name) - for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): - main_module.__dict__[name] = unpickler.find_class(modname, objname) - for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): - main_module.__dict__[name] = __import__(modname) - except KeyError: - pass - -#NOTE: 06/03/15 renamed main_module to main + for modname, name in main_module.__dict__.pop('__dill_imported', ()): + main_module.__dict__[name] = unpickler.find_class(modname, name) + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as', ()): + main_module.__dict__[name] = unpickler.find_class(modname, objname) + for modname, name in main_module.__dict__.pop('__dill_imported_top_level', ()): + main_module.__dict__[name] = _import_module(modname) + +def _format_log_dict(dict): + return pprint.pformat(dict, compact=True, sort_dicts=True).replace("'", "") + +def _filter_vars(main_module, exclude, include, base_rules): + 
"""apply exclude/include filters from arguments *and* settings""" + rules = FilterRules() + mod_rules = base_rules.get_rules(main_module.__name__) + rules.exclude |= mod_rules.get_filters(EXCLUDE) + rules.include |= mod_rules.get_filters(INCLUDE) + if exclude is not None: + rules.update([(EXCLUDE, exclude)]) + if include is not None: + rules.update([(INCLUDE, include)]) + + namespace = rules.apply_filters(main_module.__dict__) + if namespace is main_module.__dict__: + return main_module + + if logger.isEnabledFor(logging.INFO): + excluded = {name: type(value).__name__ + for name, value in sorted(main_module.__dict__.items()) if name not in namespace} + message = "[dump_module] Variables excluded by filtering:\n" + logger.info(message + _format_log_dict(excluded)) + + newmod = ModuleType(main_module.__name__) + newmod.__dict__.update(namespace) + _discard_added_variables(newmod, namespace) + logger.debug("main namespace after _filter_vars(): %s", dir(newmod)) + return newmod + +def _discard_added_variables(main, original_namespace): + # Some empty attributes like __doc__ may have been added by ModuleType(). + added_names = set(main.__dict__) + added_names.discard('__name__') # required + added_names.difference_update(original_namespace) + added_names.difference_update('__dill_imported%s' % s for s in ('', '_as', '_top_level')) + for name in added_names: + delattr(main, name) + +def _fix_module_namespace(main, original_main): + # Self-references. + for name, obj in main.__dict__.items(): + if obj is original_main: + setattr(main, name, main) + elif obj is original_main.__dict__: + setattr(main, name, main.__dict__) + # Trick _is_imported_module(), forcing main to be saved as an imported module. 
+ if getattr(main, '__loader__', None) is None and _is_imported_module(original_main): + main.__loader__ = True # will be discarded by _dill.save_module() + def dump_module( filename = str(TEMPDIR/'session.pkl'), module: Optional[Union[ModuleType, str]] = None, - refimported: bool = False, + *, + refimported: Optional[bool] = None, + refonfail: Optional[bool] = None, + exclude: Optional[Union[Filter, Iterable[Filter]]] = None, + include: Optional[Union[Filter, Iterable[Filter]]] = None, + base_rules: Optional[ModuleFilters] = None, **kwds ) -> None: - """Pickle the current state of :py:mod:`__main__` or another module to a file. + """Pickle the current state of :mod:`__main__` or another module to a file. - Save the contents of :py:mod:`__main__` (e.g. from an interactive + Save the contents of :mod:`__main__` (e.g. from an interactive interpreter session), an imported module, or a module-type object (e.g. - built with :py:class:`~types.ModuleType`), to a file. The pickled - module can then be restored with the function :py:func:`load_module`. + built with :class:`~types.ModuleType`), to a file. The pickled + module can then be restored with the function :func:`load_module`. + + Only a subset of the module's variables may be saved if exclusion/inclusion + filters are specified. Filters are applied to every pair of variable's name + and value to determine if they should be saved or not. They can be set in + ``dill.session.settings['filters']`` or passed directly to the ``exclude`` + and ``include`` parameters. + + See :class:`FilterRules` and :class:`ModuleFilters` for details. See + also the bundled "filter factories": :class:`size_filter` and + :func:`ipython_filter`. Parameters: filename: a path-like object or a writable stream. module: a module object or the name of an importable module. If `None` - (the default), :py:mod:`__main__` is saved. + (the default), :mod:`__main__` is saved. 
refimported: if `True`, all objects identified as having been imported into the module's namespace are saved by reference. *Note:* this is - similar but independent from ``dill.settings[`byref`]``, as + similar but independent from ``dill.settings['byref']``, as ``refimported`` refers to virtually all imported objects, while ``byref`` only affects select objects. - **kwds: extra keyword arguments passed to :py:class:`Pickler()`. + refonfail: if `True` (the default), objects that fail to pickle by value + will try to be saved by reference. If this also fails, saving their + parent objects by reference will be attempted recursively. In the + worst case scenario, the module itself may be saved by reference, + with a warning. *Note:* this has the side effect of disabling framing + for pickle protocol ≥ 4. Turning this option off may improve + unpickling speed, but may cause a module to fail pickling. + exclude: one or more variable `exclusion` filters (see + :class:`FilterRules`). + include: one or more variable `inclusion` filters. + base_rules: if passed, overwrites ``settings['filters']``. + **kwds: extra keyword arguments passed to :class:`Pickler()`. Raises: - :py:exc:`PicklingError`: if pickling fails. + :exc:`PicklingError`: if pickling fails. + :exc:`PicklingWarning`: if the module itself ends being saved by + reference due to unpickleable objects in its namespace. + + Default values for keyword-only arguments can be set in + `dill.session.settings`. Examples: @@ -177,7 +281,16 @@ def dump_module( >>> foo.values = [1,2,3] >>> import math >>> foo.sin = math.sin - >>> dill.dump_module('foo_session.pkl', module=foo, refimported=True) + >>> dill.dump_module('foo_session.pkl', module=foo) + + - Save the state of a module with unpickleable objects: + + >>> import dill + >>> import os + >>> os.altsep = '\\' + >>> dill.dump_module('os_session.pkl', module=os, refonfail=False) + PicklingError: ... 
+ >>> dill.dump_module('os_session.pkl', module=os, refonfail=True) # the default - Restore the state of the saved modules: @@ -191,6 +304,31 @@ def dump_module( >>> foo = dill.load_module('foo_session.pkl') >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] + >>> os = dill.load_module('os_session.pkl') + >>> print(os.altsep.join('path')) + p\\a\\t\\h + + - Use `refimported` to save imported objects by reference: + + >>> import dill + >>> from html.entities import html5 + >>> type(html5), len(html5) + (dict, 2231) + >>> import io + >>> buf = io.BytesIO() + >>> dill.dump_module(buf) # saves __main__, with html5 saved by value + >>> len(buf.getvalue()) # pickle size in bytes + 71665 + >>> buf = io.BytesIO() + >>> dill.dump_module(buf, refimported=True) # html5 saved by reference + >>> len(buf.getvalue()) + 438 + + - Save current session but exclude some variables: + + >>> import dill + >>> num, text, alist = 1, 'apple', [4, 9, 16] + >>> dill.dump_module(exclude=['text', int])) # only 'alist' is saved *Changed in version 0.3.6:* Function ``dump_session()`` was renamed to ``dump_module()``. Parameters ``main`` and ``byref`` were renamed to @@ -198,7 +336,7 @@ def dump_module( Note: Currently, ``dill.settings['byref']`` and ``dill.settings['recurse']`` - don't apply to this function.` + don't apply to this function. 
""" for old_par, par in [('main', 'module'), ('byref', 'refimported')]: if old_par in kwds: @@ -211,8 +349,14 @@ def dump_module( refimported = kwds.pop('byref', refimported) module = kwds.pop('main', module) - from .settings import settings - protocol = settings['protocol'] + from .settings import settings as dill_settings + protocol = dill_settings['protocol'] + refimported = _getopt(settings, 'refimported', refimported) + refonfail = _getopt(settings, 'refonfail', refonfail) + base_rules = _getopt(settings, 'filters', base_rules) + if not isinstance(base_rules, ModuleFilters): #pragma: no cover + base_rules = ModuleFilters(base_rules) + main = module if main is None: main = _main_module @@ -220,25 +364,38 @@ def dump_module( main = _import_module(main) if not isinstance(main, ModuleType): raise TypeError("%r is not a module" % main) - if hasattr(filename, 'write'): - file = filename - else: - file = open(filename, 'wb') - try: + original_main = main + + logger.debug("original main namespace: %s", dir(main)) + main = _filter_vars(main, exclude, include, base_rules) + if refimported: + main, modmap = _stash_modules(main, original_main) + + with _open(filename, 'wb', seekable=True) as file: pickler = Pickler(file, protocol, **kwds) - pickler._original_main = main - if refimported: - main = _stash_modules(main) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session pickler._first_pass = True - pickler._main_modified = main is not pickler._original_main + if main is not original_main: + pickler._original_main = original_main + _fix_module_namespace(main, original_main) + if refonfail: + pickler._refonfail = True # False by default + pickler._file_seek = file.seek + pickler._file_truncate = file.truncate + pickler._saved_byref = [] + if refimported: + # Cache modmap for 
refonfail. + pickler._modmap = modmap + if logger.isEnabledFor(logging.TRACE): + pickler._id_to_name = {id(v): k for k, v in main.__dict__.items()} pickler.dump(main) - finally: - if file is not filename: # if newly opened file - file.close() + if refonfail and pickler._saved_byref and logger.isEnabledFor(logging.INFO): + saved_byref = {var: "%s.%s" % (mod, obj) for var, mod, obj in pickler._saved_byref} + message = "[dump_module] Variables saved by reference (refonfail):\n" + logger.info(message + _format_log_dict(saved_byref)) return # Backward compatibility. @@ -247,98 +404,184 @@ def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, ** dump_module(filename, module=main, refimported=byref, **kwds) dump_session.__doc__ = dump_module.__doc__ -class _PeekableReader: - """lightweight stream wrapper that implements peek()""" - def __init__(self, stream): - self.stream = stream - def read(self, n): - return self.stream.read(n) - def readline(self): - return self.stream.readline() - def tell(self): - return self.stream.tell() - def close(self): - return self.stream.close() - def peek(self, n): - stream = self.stream - try: - if hasattr(stream, 'flush'): stream.flush() - position = stream.tell() - stream.seek(position) # assert seek() works before reading - chunk = stream.read(n) - stream.seek(position) - return chunk - except (AttributeError, OSError): - raise NotImplementedError("stream is not peekable: %r", stream) from None - -def _make_peekable(stream): - """return stream as an object with a peek() method""" - import io - if hasattr(stream, 'peek'): - return stream - if not (hasattr(stream, 'tell') and hasattr(stream, 'seek')): - try: - return io.BufferedReader(stream) - except Exception: - pass - return _PeekableReader(stream) - def _identify_module(file, main=None): """identify the name of the module stored in the given file-type object""" - from pickletools import genops - UNICODE = {'UNICODE', 'BINUNICODE', 'SHORT_BINUNICODE'} - 
found_import = False + import pickletools + NEUTRAL = {'PROTO', 'FRAME', 'PUT', 'BINPUT', 'MEMOIZE', 'MARK', 'STACK_GLOBAL'} try: - for opcode, arg, pos in genops(file.peek(256)): - if not found_import: - if opcode.name in ('GLOBAL', 'SHORT_BINUNICODE') and \ - arg.endswith('_import_module'): - found_import = True - else: - if opcode.name in UNICODE: - return arg - else: - raise UnpicklingError("reached STOP without finding main module") + opcodes = ((opcode.name, arg) for opcode, arg, pos in pickletools.genops(file.peek(256)) + if opcode.name not in NEUTRAL) + opcode, arg = next(opcodes) + if (opcode, arg) == ('SHORT_BINUNICODE', 'dill._dill'): + # The file uses STACK_GLOBAL instead of GLOBAL. + opcode, arg = next(opcodes) + if not (opcode in ('SHORT_BINUNICODE', 'GLOBAL') and arg.split()[-1] == '_import_module'): + raise ValueError + opcode, arg = next(opcodes) + if not opcode in ('SHORT_BINUNICODE', 'BINUNICODE', 'UNICODE'): + raise ValueError + module_name = arg + if not ( + next(opcodes)[0] in ('TUPLE1', 'TUPLE') and + next(opcodes)[0] == 'REDUCE' and + next(opcodes)[0] in ('EMPTY_DICT', 'DICT') + ): + raise ValueError + return module_name + except StopIteration: + raise UnpicklingError("reached STOP without finding module") from None except (NotImplementedError, ValueError) as error: - # ValueError occours when the end of the chunk is reached (without a STOP). + # ValueError also occours when the end of the chunk is reached (without a STOP). if isinstance(error, NotImplementedError) and main is not None: - # file is not peekable, but we have main. + # The file is not peekable, but we have the argument main. return None - raise UnpicklingError("unable to identify main module") from error + raise UnpicklingError("unable to identify module") from error + +def is_pickled_module( + filename, importable: bool = True, identify: bool = False +) -> Union[bool, str]: + """Check if a file can be loaded with :func:`load_module`. 
+ + Check if the file is a pickle file generated with :func:`dump_module`, + and thus can be loaded with :func:`load_module`. + + Parameters: + filename: a path-like object or a readable stream. + importable: expected kind of the file's saved module. Use `True` for + importable modules (the default) or `False` for module-type objects. + identify: if `True`, return the module name if the test succeeds. + + Returns: + `True` if the pickle file at ``filename`` was generated with + :func:`dump_module` **AND** the module whose state is saved in it is + of the kind specified by the ``importable`` argument. `False` otherwise. + If `identify` is set, return the name of the module instead of `True`. + + Examples: + Create three types of pickle files: + + >>> import dill + >>> import types + >>> dill.dump_module('module_session.pkl') # saves __main__ + >>> dill.dump_module('module_object.pkl', module=types.ModuleType('example')) + >>> with open('common_object.pkl', 'wb') as file: + >>> dill.dump('example', file) + + Test each file's kind: + + >>> dill.is_pickled_module('module_session.pkl') # the module is importable + True + >>> dill.is_pickled_module('module_session.pkl', importable=False) + False + >>> dill.is_pickled_module('module_object.pkl') # the module is not importable + False + >>> dill.is_pickled_module('module_object.pkl', importable=False) + True + >>> dill.is_pickled_module('module_object.pkl', importable=False, identify=True) + 'example' + >>> dill.is_pickled_module('common_object.pkl') # always return False + False + >>> dill.is_pickled_module('common_object.pkl', importable=False) + False + """ + with _open(filename, 'rb', peekable=True) as file: + try: + pickle_main = _identify_module(file) + except UnpicklingError: + return False + is_runtime_mod = pickle_main.startswith('__runtime__.') + res = importable ^ is_runtime_mod + if res and identify: + return pickle_main.partition('.')[-1] if is_runtime_mod else pickle_main + else: + return res def load_module( 
filename = str(TEMPDIR/'session.pkl'), module: Optional[Union[ModuleType, str]] = None, **kwds ) -> Optional[ModuleType]: - """Update the selected module (default is :py:mod:`__main__`) with - the state saved at ``filename``. + """Update the selected module with the state saved at ``filename``. - Restore a module to the state saved with :py:func:`dump_module`. The - saved module can be :py:mod:`__main__` (e.g. an interpreter session), + Restore a module to the state saved with :func:`dump_module`. The + saved module can be :mod:`__main__` (e.g. an interpreter session), an imported module, or a module-type object (e.g. created with - :py:class:`~types.ModuleType`). + :class:`~types.ModuleType`). - When restoring the state of a non-importable module-type object, the - current instance of this module may be passed as the argument ``main``. - Otherwise, a new instance is created with :py:class:`~types.ModuleType` + When restoring the state of a non-importable, module-type object, the + current instance of this module may be passed as the argument ``module``. + Otherwise, a new instance is created with :class:`~types.ModuleType` and returned. Parameters: filename: a path-like object or a readable stream. module: a module object or the name of an importable module; - the module name and kind (i.e. imported or non-imported) must + the module's name and kind (i.e. imported or non-imported) must match the name and kind of the module stored at ``filename``. - **kwds: extra keyword arguments passed to :py:class:`Unpickler()`. + **kwds: extra keyword arguments passed to :class:`Unpickler()`. Raises: - :py:exc:`UnpicklingError`: if unpickling fails. - :py:exc:`ValueError`: if the argument ``main`` and module saved - at ``filename`` are incompatible. + :exc:`UnpicklingError`: if unpickling fails. + :exc:`ValueError`: if the argument ``module`` and the module + saved at ``filename`` are incompatible. 
Returns: - A module object, if the saved module is not :py:mod:`__main__` or - a module instance wasn't provided with the argument ``main``. + A module object, if the saved module is not :mod:`__main__` and + a module instance wasn't provided with the argument ``module``. + + Passing an argument to ``module`` forces `dill` to verify that the module + being loaded is compatible with the argument value. Additionally, if the + argument is a module instance (instead of a module name), it supresses the + return value. Each case and behavior is exemplified below: + + 1. `module`: ``None`` --- This call loads a previously saved state of + the module ``math`` and returns it (the module object) at the end: + + >>> import dill + >>> # load module -> restore state -> return module + >>> dill.load_module('math_session.pkl') + + + 2. `module`: ``str`` --- Passing the module name does the same as above, + but also verifies that the module being loaded, restored and returned is + indeed ``math``: + + >>> import dill + >>> # load module -> check name/kind -> restore state -> return module + >>> dill.load_module('math_session.pkl', module='math') + + >>> dill.load_module('math_session.pkl', module='cmath') + ValueError: can't update module 'cmath' with the saved state of module 'math' + + 3. `module`: ``ModuleType`` --- Passing the module itself instead of its + name has the additional effect of suppressing the return value (and the + module is already loaded at this point): + + >>> import dill + >>> import math + >>> # check name/kind -> restore state -> return None + >>> dill.load_module('math_session.pkl', module=math) + + For imported modules, the return value is meant as a convenience, so that + the function call can substitute an ``import`` statement. 
Therefore these + statements: + + >>> import dill + >>> math2 = dill.load_module('math_session.pkl', module='math') + + are equivalent to these: + + >>> import dill + >>> import math as math2 + >>> dill.load_module('math_session.pkl', module=math2) + + Note that, in both cases, ``math2`` is just a reference to + ``sys.modules['math']``: + + >>> import math + >>> import sys + >>> math is math2 is sys.modules['math'] + True Examples: @@ -402,10 +645,6 @@ def load_module( *Changed in version 0.3.6:* Function ``load_session()`` was renamed to ``load_module()``. Parameter ``main`` was renamed to ``module``. - - See also: - :py:func:`load_module_asdict` to load the contents of module saved - with :py:func:`dump_module` into a dictionary. """ if 'main' in kwds: warnings.warn( @@ -415,20 +654,12 @@ def load_module( if module is not None: raise TypeError("both 'module' and 'main' arguments were used") module = kwds.pop('main') - main = module - if hasattr(filename, 'read'): - file = filename - else: - file = open(filename, 'rb') - try: - file = _make_peekable(file) - #FIXME: dill.settings are disabled - unpickler = Unpickler(file, **kwds) - unpickler._session = True - # Resolve unpickler._main + main = module + with _open(filename, 'rb', peekable=True) as file: + # Resolve main. pickle_main = _identify_module(file, main) - if main is None and pickle_main is not None: + if main is None: main = pickle_main if isinstance(main, str): if main.startswith('__runtime__.'): @@ -436,12 +667,8 @@ def load_module( main = ModuleType(main.partition('.')[-1]) else: main = _import_module(main) - if main is not None: - if not isinstance(main, ModuleType): - raise TypeError("%r is not a module" % main) - unpickler._main = main - else: - main = unpickler._main + if not isinstance(main, ModuleType): + raise TypeError("%r is not a module" % main) # Check against the pickle's main. 
is_main_imported = _is_imported_module(main) @@ -450,32 +677,33 @@ def load_module( if is_runtime_mod: pickle_main = pickle_main.partition('.')[-1] error_msg = "can't update{} module{} %r with the saved state of{} module{} %r" - if is_runtime_mod and is_main_imported: + if main.__name__ != pickle_main: + raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) + elif is_runtime_mod and is_main_imported: raise ValueError( error_msg.format(" imported", "", "", "-type object") - % (main.__name__, pickle_main) + % (main.__name__, main.__name__) ) - if not is_runtime_mod and not is_main_imported: + elif not is_runtime_mod and not is_main_imported: raise ValueError( error_msg.format("", "-type object", " imported", "") - % (pickle_main, main.__name__) + % (main.__name__, main.__name__) ) - if main.__name__ != pickle_main: - raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) - - # This is for find_class() to be able to locate it. - if not is_main_imported: - runtime_main = '__runtime__.%s' % main.__name__ - sys.modules[runtime_main] = main - loaded = unpickler.load() - finally: - if not hasattr(filename, 'read'): # if newly opened file - file.close() + # Load the module's state. + #FIXME: dill.settings are disabled + unpickler = Unpickler(file, **kwds) + unpickler._session = True try: - del sys.modules[runtime_main] - except (KeyError, NameError): - pass + if not is_main_imported: + # This is for find_class() to be able to locate it. 
+ runtime_main = '__runtime__.%s' % main.__name__ + sys.modules[runtime_main] = main + loaded = unpickler.load() + finally: + if not is_main_imported: + del sys.modules[runtime_main] + assert loaded is main _restore_modules(unpickler, main) if main is _main_module or main is module: @@ -491,9 +719,8 @@ def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds): def load_module_asdict( filename = str(TEMPDIR/'session.pkl'), - update: bool = False, **kwds -) -> dict: +) -> Dict[str, Any]: """ Load the contents of a saved module into a dictionary. @@ -501,27 +728,22 @@ def load_module_asdict( lambda filename: vars(dill.load_module(filename)).copy() - however, does not alter the original module. Also, the path of - the loaded module is stored in the ``__session__`` attribute. + however, it does not alter the original module. Also, the path of + the loaded file is stored with the key ``'__session__'``. Parameters: filename: a path-like object or a readable stream - update: if `True`, initialize the dictionary with the current state - of the module prior to loading the state stored at filename. - **kwds: extra keyword arguments passed to :py:class:`Unpickler()` + **kwds: extra keyword arguments passed to :class:`Unpickler()` Raises: - :py:exc:`UnpicklingError`: if unpickling fails + :exc:`UnpicklingError`: if unpickling fails Returns: A copy of the restored module's dictionary. Note: - If ``update`` is True, the corresponding module may first be imported - into the current namespace before the saved state is loaded from - filename to the dictionary. Note that any module that is imported into - the current namespace as a side-effect of using ``update`` will not be - modified by loading the saved module in filename to a dictionary. + Even if not changed, the module refered in the file is always loaded + before its saved state is restored from `filename` to the dictionary. 
Example: >>> import dill @@ -541,47 +763,302 @@ def load_module_asdict( False >>> main['anum'] == anum # changed after the session was saved False - >>> new_var in main # would be True if the option 'update' was set - False + >>> new_var in main # it was initialized with the current state of __main__ + True """ if 'module' in kwds: raise TypeError("'module' is an invalid keyword argument for load_module_asdict()") - if hasattr(filename, 'read'): - file = filename + + with _open(filename, 'rb', peekable=True) as file: + main_qualname = _identify_module(file) + main = _import_module(main_qualname) + main_copy = ModuleType(main_qualname) + main_copy.__dict__.clear() + main_copy.__dict__.update(main.__dict__) + + parent_name, _, main_name = main_qualname.rpartition('.') + if parent_name: + parent = sys.modules[parent_name] + try: + sys.modules[main_qualname] = main_copy + if parent_name and getattr(parent, main_name, None) is main: + setattr(parent, main_name, main_copy) + load_module(file, **kwds) + finally: + sys.modules[main_qualname] = main + if parent_name and getattr(parent, main_name, None) is main_copy: + setattr(parent, main_name, main) + + if isinstance(getattr(filename, 'name', None), str): + main_copy.__session__ = filename.name else: - file = open(filename, 'rb') - try: - file = _make_peekable(file) - main_name = _identify_module(file) - old_main = sys.modules.get(main_name) - main = ModuleType(main_name) - if update: - if old_main is None: - old_main = _import_module(main_name) - main.__dict__.update(old_main.__dict__) + main_copy.__session__ = str(filename) + return main_copy.__dict__ + +class ModuleFilters(FilterRules): + """Stores default filtering rules for modules. + + :class:`FilterRules` subclass with a tree-like structure that may hold + exclusion/inclusion filters for specific modules and submodules. See the + base class documentation to learn more about how to create and use filters. 
+ + This is the type of ``dill.session.settings['filters']``: + + >>> import dill + >>> filters = dill.session.settings['filters'] + >>> filters + + + Exclusion and inclusion filters for global variables can be added using the + ``add()`` methods of the ``exclude`` and ``include`` attributes, or of the + ``ModuleFilters`` object itself. In the latter case, the filter is added to + its ``exclude`` :class:`FilterSet` by default: + + >>> filters.add('some_var') # exclude a variable named 'some_var' + >>> filters.exclude.add('_.*') # exclude any variable with a name prefixed by '_' + >>> filters.include.add('_keep_this') # an exception to the rule above + >>> filters + + + Similarly, a filter can be discarded with the ``discard()`` method: + + >>> filters.discard('some_var') + >>> filters.exclude.discard('_.*') + >>> filters + + + Note how, after the last operation, ``filters.exclude`` was left empty but + ``filters.include`` still contains a name filter. In cases like this, i.e. + when ``len(filters.exclude) == 0 and len(filters.include) > 0.``, the + filters are treated as an "allowlist", which means that **only** the + variables that match the ``include`` filters will be pickled. In this + example, only the variable ``_keep_this``, if it existed, would be saved. + + To create filters specific for a module and its submodules, use the + following syntax to add a child node to the default ``ModuleFilters``: + + >>> import dill + >>> from dill.session import EXCLUDE, INCLUDE + >>> filters = dill.session.settings['filters'] + >>> # set empty rules for module 'foo': + >>> # (these will override any existing default rules) + >>> filters['foo'] = [] + >>> filters['foo'] + + >>> # add a name (exclusion) filter: + >>> # (this filter will also apply to any submodule of 'foo') + >>> filters['foo'].add('ignore_this') + >>> filters['foo'] + + + Create a filter for a submodule: + + >>> filters['bar.baz'] = [ + ... (EXCLUDE, r'\w+\d+'), + ... 
(INCLUDE, ['ERROR403', 'ERROR404']) + ... ] + >>> # set specific rules for the submodule 'bar.baz': + >>> filters['bar.baz'] + + >>> # note that the default rules still apply to the module 'bar' + >>> filters['bar'] + + + Module-specific filter rules may be accessed using different syntaxes: + + >>> filters['bar.baz'] is filters['bar']['baz'] + True + >>> filters.bar.baz is filters['bar']['baz'] + True + + Note, however, that using the attribute syntax to directly set rules for + a submodule will fail if its parent module doesn't have an entry yet: + + >>> filters.parent.child = [] # filters.parent doesn't exist + AttributeError: 'ModuleFilters' object has no attribute 'parent' + >>> filters['parent.child'] = [] # use this syntax instead + >>> filters.parent.child.grandchild = [(EXCLUDE, str)] # works fine + """ + __slots__ = '_module', '_parent', '__dict__' + + def __init__(self, + rules: Union[Iterable[Rule], FilterRules, None] = None, + module: str = 'DEFAULT', + parent: ModuleFilters = None, + ): + if rules is not None: + super().__init__(rules) + # else: don't initialize FilterSets. + if parent is not None and parent._module != 'DEFAULT': + module = '%s.%s' % (parent._module, module) + # Bypass self.__setattr__() + super().__setattr__('_module', module) + super().__setattr__('_parent', parent) + + def __repr__(self) -> str: + desc = "DEFAULT" if self._module == 'DEFAULT' else "for %r" % self._module + return " bool: + if isinstance(other, ModuleFilters): + return super().__eq__(other) and self._module == other._module + elif isinstance(other, FilterRules): + return super().__eq__(other) + else: + return NotImplemented + + def __setattr__(self, name: str, value: Any) -> None: + if name in FilterRules.__slots__: + # Don't interfere with superclass attributes. + super().__setattr__(name, value) + elif name in ('exclude', 'include'): + if not (hasattr(self, 'exclude') or hasattr(self, 'include')): + # This was a placeholder node. Initialize 'other'. 
+ other = 'include' if name == 'exclude' else 'exclude' + super().__setattr__(other, ()) + super().__setattr__(name, value) + else: + # Create a child node for submodule 'name'. + mod_filters = ModuleFilters(rules=value, module=name, parent=self) + super().__setattr__(name, mod_filters) + # Proxy __setitem__ and __getitem__ to self.__dict__ through attributes. + def __setitem__(self, name: str, value: Union[Iterable[Rule], FilterRules, None]) -> None: + if '.' not in name: + setattr(self, name, value) + else: + module, _, submodules = name.partition('.') + if module not in self.__dict__: + # Create a placeholder node, like logging.PlaceHolder. + setattr(self, module, None) + mod_filters = getattr(self, module) + mod_filters[submodules] = value + def __getitem__(self, name: str) -> ModuleFilters: + module, _, submodules = name.partition('.') + mod_filters = getattr(self, module) + if not submodules: + return mod_filters else: - main.__builtins__ = __builtin__ - sys.modules[main_name] = main - load_module(file, **kwds) - finally: - if not hasattr(filename, 'read'): # if newly opened file - file.close() + return mod_filters[submodules] + + def keys(self) -> List[str]: + values = self.__dict__.values() + # Don't include placeholder nodes. + keys = [x._module for x in values if hasattr(x, 'exclude') or hasattr(x, 'include')] + for mod_filters in values: + keys += mod_filters.keys() + keys.sort() + return keys + def get_rules(self, name: str) -> ModuleFilters: + while name: + try: + return self[name] + except AttributeError: + name = name.rpartition('.')[0] + return self + def get_filters(self, rule_type: RuleType) -> FilterSet: + """Get exclude/include filters. 
If not set, fall back to parent module's or default filters.""" + if not isinstance(rule_type, RuleType): + raise ValueError("invalid rule type: %r (must be one of %r)" % (rule_type, list(RuleType))) try: - if old_main is None: - del sys.modules[main_name] - else: - sys.modules[main_name] = old_main - except NameError: # failed before setting old_main - pass - main.__session__ = str(filename) - return main.__dict__ + return getattr(self, rule_type.name.lower()) + except AttributeError: + if self._parent is None: + raise + return self._parent.get_filters(rule_type) + + +## Session settings ## + +settings = { + 'refimported': False, + 'refonfail': True, + 'filters': ModuleFilters(rules=()), +} +## Session filter factories ## + +def ipython_filter(*, keep_history: str = 'input') -> FilterFunction: + """Filter factory to exclude IPython hidden variables. + + When saving the session with :func:`dump_module` from an IPython + interpreter, hidden variables (i.e. variables listed by ``dir()`` but + not listed by the ``%who`` magic command) are saved unless they are excluded + by filters. This function generates a filter that will exclude these hidden + variables from the list of saved variables, with the optional exception of + command history variables. + + Parameters: + keep_history: whether to keep (i.e. not exclude) the input and output + history of the IPython interactive session. Accepted values: + + - `"input"`: the input history contained in the hidden variables + ``In``, ``_ih``, ``_i``, ``_i1``, ``_i2``, etc. will be saved. + - `"output"`, the output history contained in the hidden variables + ``Out``, ``_oh``, ``_``, ``_1``, ``_2``, etc. will be saved. + - `"both"`: both the input and output history will be saved. + - `"none"`: all the hidden history variables will be excluded. + + Returns: + A variable exclusion filter function to be used with :func:`dump_module`. 
+
+ Important:
+ A filter of this kind should be created just before the call to
+ :func:`dump_module` where it's used, as it doesn't update the list of
+ hidden variables after its creation for performance reasons.
+
+ Example:
+
+ >>> import dill
+ >>> from dill.session import ipython_filter
+ >>> dill.dump_module(exclude=ipython_filter(keep_history='none'))
+ """
+ HISTORY_OPTIONS = {'input', 'output', 'both', 'none'}
+ if keep_history not in HISTORY_OPTIONS: #pragma: no cover
+ raise ValueError(
+ "invalid 'keep_history' argument: %r (must be one of %r)" %
+ (keep_history, HISTORY_OPTIONS)
+ )
+ if not _dill.IS_IPYTHON: #pragma: no cover
+ # Return no-op filter if not in IPython.
+ return (lambda x: False)
+
+ from IPython import get_ipython
+ ipython_shell = get_ipython()
+
+ # Code snippet adapted from IPython.core.magics.namespace.who_ls()
+ user_ns = ipython_shell.user_ns
+ user_ns_hidden = ipython_shell.user_ns_hidden
+ NONMATCHING = object() # This can never be in user_ns
+ interactive_vars = {x for x in user_ns if user_ns[x] is not user_ns_hidden.get(x, NONMATCHING)}
+
+ # Input and output history hidden variables.
+ history_regex = []
+ if keep_history in {'input', 'both'}:
+ interactive_vars |= {'_ih', 'In', '_i', '_ii', '_iii'}
+ history_regex.append(re.compile(r'_i\d+'))
+ if keep_history in {'output', 'both'}:
+ interactive_vars |= {'_oh', 'Out', '_', '__', '___'}
+ history_regex.append(re.compile(r'_\d+'))
+
+ def not_interactive_var(obj: NamedObject) -> bool:
+ if any(regex.fullmatch(obj.name) for regex in history_regex):
+ return False
+ return obj.name not in interactive_vars
+
+ return not_interactive_var
+
+
+## Variables set in this module to avoid circular import problems ##
+
 # Internal exports for backward compatibility with dill v0.3.5.1
-# Can't be placed in dill._dill because of circular import problems. 
for name in ( - '_lookup_module', '_module_map', '_restore_modules', '_stash_modules', + '_restore_modules', '_stash_modules', 'dump_session', 'load_session' # backward compatibility functions ): setattr(_dill, name, globals()[name]) + del name diff --git a/dill/tests/test_filtering.py b/dill/tests/test_filtering.py new file mode 100644 index 00000000..3bcc0c9c --- /dev/null +++ b/dill/tests/test_filtering.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +import sys +from types import ModuleType + +import dill +from dill import _dill +from dill.session import ( + EXCLUDE, INCLUDE, FilterRules, FilterSet, RuleType, ipython_filter, size_filter, settings +) + +def test_filterset(): + import re + + name = 'test' + regex1 = re.compile(r'\w+\d+') + regex2 = r'_\w+' + id_ = id(FilterSet) + type1 = FilterSet + type2 = 'type:List' + func = lambda obj: obj.name == 'Arthur' + + empty_filters = FilterSet() + assert bool(empty_filters) is False + assert len(empty_filters) == 0 + assert len([*empty_filters]) == 0 + + # also tests add() and __ior__() for non-FilterSet other + filters = FilterSet._from_iterable([name, regex1, regex2, id_, type1, type2, func]) + assert filters.names == {name} + assert filters.regexes == {regex1, re.compile(regex2)} + assert filters.ids == {id_} + assert filters.types == {type1, list} + assert filters.funcs == {func} + + assert bool(filters) is True + assert len(filters) == 7 + assert all(x in filters for x in [name, regex1, id_, type1, func]) + + try: + filters.add(re.compile(b'an 8-bit string regex')) + except ValueError: + pass + else: + raise AssertionError("adding invalid filter should raise error") + + filters_copy = filters.copy() + for field in FilterSet._fields: + original, copy = getattr(filters, field), 
getattr(filters_copy, field) + assert copy is not original + assert copy == original + + filters.remove(re.compile(regex2)) + assert filters.regexes == {regex1} + filters.discard(list) + filters.discard(list) # should not raise error + assert filters.types == {type1} + assert [*filters] == [name, regex1, id_, type1, func] + + # also tests __ior__() for FilterSet other + filters.update(filters_copy) + assert filters.types == {type1, list} + + filters.clear() + assert len(filters) == 0 + +NS = { + 'a': 1, + 'aa': 2, + 'aaa': 3, + 'b': 42, + 'bazaar': 'cathedral', + 'has_spam': True, + 'Integer': int, +} + +def did_exclude(namespace, rules, excluded_subset): + rules = FilterRules(rules) + filtered = rules.apply_filters(namespace) + return set(namespace).difference(filtered) == excluded_subset + +def test_basic_filtering(): + filter_names = [(EXCLUDE, ['a', 'c'])] # not 'aa', etc. + assert did_exclude(NS, filter_names, excluded_subset={'a'}) + + filter_regexes = [(EXCLUDE, [r'aa+', r'bb+'])] # not 'a', 'b', 'bazaar' + assert did_exclude(NS, filter_regexes, excluded_subset={'aa', 'aaa'}) + + # Should exclude 'b' and 'd', and not 'b_id'. + NS_copy = NS.copy() + NS_copy['d'] = NS['b'] + NS_copy['b_id'] = id(NS['b']) + filter_ids = [(EXCLUDE, id(NS['b']))] + assert did_exclude(NS_copy, filter_ids, excluded_subset={'b', 'd'}) + + # Should also exclude bool 'has_spam' (int subclass). + filter_types = [(EXCLUDE, [int, frozenset])] + assert did_exclude(NS, filter_types, excluded_subset={'a', 'aa', 'aaa', 'b', 'has_spam'}) + + # Match substring (regexes use fullmatch()). + filter_funcs_name = [(EXCLUDE, lambda obj: 'aa' in obj.name)] + assert did_exclude(NS, filter_funcs_name, excluded_subset={'aa', 'aaa', 'bazaar'}) + + # Don't exclude subclasses. 
+ filter_funcs_value = [(EXCLUDE, lambda obj: type(obj.value) == int)]
+ assert did_exclude(NS, filter_funcs_value, excluded_subset={'a', 'aa', 'aaa', 'b'})
+
+def test_exclude_include():
+ # Include rules must apply after exclude rules.
+ filter_include = [(EXCLUDE, r'a+'), (INCLUDE, 'aa')] # not 'aa'
+ assert did_exclude(NS, filter_include, excluded_subset={'a', 'aaa'})
+
+ # If no exclude rules, behave as an allowlist.
+ filter_allowlist = [(INCLUDE, lambda obj: 'a' in obj.name)]
+ assert did_exclude(NS, filter_allowlist, excluded_subset={'b', 'Integer'})
+
+def test_add_type():
+ type_rules = FilterRules() # Formats accepted (actually case insensitive):
+ type_rules.exclude.add('type: function') # 1. typename
+ type_rules.exclude.add('type: Type ') # 2. Typename
+ type_rules.exclude.add('type:ModuleType') # 3. TypenameType
+ NS_copy = NS.copy()
+ NS_copy.update(F=test_basic_filtering, T=FilterRules, M=_dill)
+ assert did_exclude(NS_copy, type_rules, excluded_subset={'F', 'T', 'M', 'Integer'})
+
+def test_module_filters():
+ R"""Test filters specific for a module and fallback to parent module or default. 
+
+ The settings['filters'] single-branched tree structure in these tests:
+
+ exclude: {r'_.*[^_]'} None None
+ / / /
+ *-------------* *-------------* *-------------* *~~~~~~~~~~~~~*
+ module: | DEFAULT |-| foo* |-| foo.bar | | foo.bar.baz |
+ *-------------* *-------------* *-------------* *~~~~~~~~~~~~~*
+ \ \ \ \_____ _____/
+ include: {'_keep'} None {} (empty) V
+ missing
+ (*) 'foo' is a placeholder node
+ """
+ import io
+ foo = sys.modules['foo'] = ModuleType('foo')
+ foo.bar = sys.modules['foo.bar'] = ModuleType('foo.bar')
+ foo.bar.baz = sys.modules['foo.bar.baz'] = ModuleType('foo.bar.baz')
+ NS = {'_filter': 1, '_keep': 2}
+
+ def _dump_load_dict(module):
+ module.__dict__.update(NS)
+ buf = io.BytesIO()
+ dill.dump_module(buf, module)
+ for var in NS:
+ delattr(module, var)
+ buf.seek(0)
+ return dill.load_module_asdict(buf)
+
+ # Empty default filters
+ filters = settings['filters']
+ saved = _dump_load_dict(foo)
+ assert '_filter' in saved
+ assert '_keep' in saved
+
+ # Default filters
+ filters.exclude.add(r'_.*[^_]')
+ filters.include.add('_keep')
+ assert filters.get_rules('foo') is filters
+ saved = _dump_load_dict(foo)
+ assert '_filter' not in saved
+ assert '_keep' in saved
+
+ # Add filters to 'foo.bar' and placeholder node for 'foo'
+ filters['foo.bar'] = ()
+ del filters.foo.bar.exclude # remove empty exclude filters, fall back to default
+ assert not hasattr(filters.foo, 'exclude') and not hasattr(filters.foo, 'include')
+ assert not hasattr(filters.foo.bar, 'exclude') and hasattr(filters.foo.bar, 'include')
+
+ # foo: placeholder node falling back to default
+ assert filters.foo.get_filters(EXCLUDE) is filters.exclude
+ saved = _dump_load_dict(foo)
+ assert '_filter' not in saved
+ assert '_keep' in saved
+
+ # foo.bar: without exclude rules, with (empty) include rules
+ assert filters.foo.bar.get_filters(EXCLUDE) is filters.exclude
+ assert filters.foo.bar.get_filters(INCLUDE) is filters.foo.bar.include
+ saved = 
_dump_load_dict(foo.bar) + assert '_filter' not in saved + assert '_keep' not in saved + + # foo.bar.baz: without specific filters, falling back to foo.bar + assert filters.get_rules('foo.bar.baz') is filters.foo.bar + saved = _dump_load_dict(foo.bar.baz) + assert '_filter' not in saved + assert '_keep' not in saved + +def test_ipython_filter(): + from itertools import filterfalse + from types import SimpleNamespace + _dill.IS_IPYTHON = True # trick ipython_filter + sys.modules['IPython'] = MockIPython = ModuleType('IPython') + + # Mimic the behavior of IPython namespaces at __main__. + user_ns_actual = {'user_var': 1, 'x': 2} + user_ns_hidden = {'x': 3, '_i1': '1 / 2', '_1': 0.5, 'hidden': 4} + user_ns = user_ns_hidden.copy() # user_ns == vars(__main__) + user_ns.update(user_ns_actual) + assert user_ns['x'] == user_ns_actual['x'] # user_ns.x masks user_ns_hidden.x + MockIPython.get_ipython = lambda: SimpleNamespace(user_ns=user_ns, user_ns_hidden=user_ns_hidden) + + # Test variations of keeping or dropping the interpreter history. 
+ user_vars = set(user_ns_actual) + def namespace_matches(keep_history, should_keep_vars): + rules = FilterRules([(EXCLUDE, ipython_filter(keep_history=keep_history))]) + return set(rules.apply_filters(user_ns)) == user_vars | should_keep_vars + assert namespace_matches(keep_history='input', should_keep_vars={'_i1'}) + assert namespace_matches(keep_history='output', should_keep_vars={'_1'}) + assert namespace_matches(keep_history='both', should_keep_vars={'_i1', '_1'}) + assert namespace_matches(keep_history='none', should_keep_vars=set()) + +def test_size_filter(): + from sys import getsizeof + estimate = size_filter.estimate_size + + small = list(range(100)) + large = list(range(1000)) + reflarge = 10*[small] + small_size = getsizeof(small) + 100*getsizeof(0) + large_size = getsizeof(large) + 1000*getsizeof(0) + assert small_size <= estimate(small) < estimate(reflarge) < large_size <= estimate(large) + + NS_copy = NS.copy() # all base objects are small and should not be excluded + reflarge.append(reflarge) # recursive reference + NS_copy.update(small=small, large=large, reflarge=reflarge) + filter_size = [(EXCLUDE, size_filter(limit=5*small_size))] + assert did_exclude(NS_copy, filter_size, excluded_subset={'large'}) + +if __name__ == '__main__': + test_filterset() + test_basic_filtering() + test_exclude_include() + test_add_type() + test_module_filters() + test_ipython_filter() + if not _dill.IS_PYPY: + test_size_filter() diff --git a/dill/tests/test_logger.py b/dill/tests/test_logging.py similarity index 97% rename from dill/tests/test_logger.py rename to dill/tests/test_logging.py index b4e4881a..ed33e6c4 100644 --- a/dill/tests/test_logger.py +++ b/dill/tests/test_logging.py @@ -11,7 +11,7 @@ import dill from dill import detect -from dill.logger import stderr_handler, adapter as logger +from dill.logging import stderr_handler, adapter as logger try: from StringIO import StringIO diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 
51128916..e5341b25 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -11,8 +11,11 @@ import __main__ from contextlib import suppress from io import BytesIO +from types import ModuleType import dill +from dill import _dill +from dill.session import ipython_filter, EXCLUDE, INCLUDE session_file = os.path.join(os.path.dirname(__file__), 'session-refimported-%s.pkl') @@ -20,7 +23,7 @@ # Child process # ################### -def _error_line(error, obj, refimported): +def _error_line(obj, refimported): import traceback line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') return "while testing (with refimported=%s): %s" % (refimported, line.lstrip()) @@ -52,7 +55,7 @@ def test_modules(refimported): assert __main__.complex_log is cmath.log except AssertionError as error: - error.args = (_error_line(error, obj, refimported),) + error.args = (_error_line(obj, refimported),) raise test_modules(refimported) @@ -91,6 +94,7 @@ def weekdays(self): return [day_name[i] for i in self.iterweekdays()] cal = CalendarSubclass() selfref = __main__ +self_dict = __main__.__dict__ # Setup global namespace for session saving tests. 
class TestNamespace: @@ -120,7 +124,7 @@ def _clean_up_cache(module): def _test_objects(main, globals_copy, refimported): try: main_dict = __main__.__dict__ - global Person, person, Calendar, CalendarSubclass, cal, selfref + global Person, person, Calendar, CalendarSubclass, cal, selfref, self_dict for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): assert globals()[obj].__name__ == globals_copy[obj].__name__ @@ -141,9 +145,10 @@ def _test_objects(main, globals_copy, refimported): assert cal.weekdays() == globals_copy['cal'].weekdays() assert selfref is __main__ + assert self_dict is __main__.__dict__ except AssertionError as error: - error.args = (_error_line(error, obj, refimported),) + error.args = (_error_line(obj, refimported),) raise def test_session_main(refimported): @@ -192,13 +197,12 @@ def test_session_other(): assert module.selfref is module def test_runtime_module(): - from types import ModuleType - modname = '__runtime__' - runtime = ModuleType(modname) - runtime.x = 42 + modname = 'runtime' + runtime_mod = ModuleType(modname) + runtime_mod.x = 42 - mod = dill.session._stash_modules(runtime) - if mod is not runtime: + mod, _ = dill.session._stash_modules(runtime_mod, runtime_mod) + if mod is not runtime_mod: print("There are objects to save by referenece that shouldn't be:", mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level, file=sys.stderr) @@ -207,46 +211,23 @@ def test_runtime_module(): # without imported objects in the namespace. It's a contrived example because # even dill can't be in it. This should work after fixing #462. session_buffer = BytesIO() - dill.dump_module(session_buffer, module=runtime, refimported=True) + dill.dump_module(session_buffer, module=runtime_mod, refimported=True) session_dump = session_buffer.getvalue() # Pass a new runtime created module with the same name. 
- runtime = ModuleType(modname) # empty - return_val = dill.load_module(BytesIO(session_dump), module=runtime) + runtime_mod = ModuleType(modname) # empty + return_val = dill.load_module(BytesIO(session_dump), module=runtime_mod) assert return_val is None - assert runtime.__name__ == modname - assert runtime.x == 42 - assert runtime not in sys.modules.values() + assert runtime_mod.__name__ == modname + assert runtime_mod.x == 42 + assert runtime_mod not in sys.modules.values() # Pass nothing as main. load_module() must create it. session_buffer.seek(0) - runtime = dill.load_module(BytesIO(session_dump)) - assert runtime.__name__ == modname - assert runtime.x == 42 - assert runtime not in sys.modules.values() - -def test_refimported_imported_as(): - import collections - import concurrent.futures - import types - import typing - mod = sys.modules['__test__'] = types.ModuleType('__test__') - dill.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - mod.Dict = collections.UserDict # select by type - mod.AsyncCM = typing.AsyncContextManager # select by __module__ - mod.thread_exec = dill.executor # select by __module__ with regex - - session_buffer = BytesIO() - dill.dump_module(session_buffer, mod, refimported=True) - session_buffer.seek(0) - mod = dill.load(session_buffer) - del sys.modules['__test__'] - - assert set(mod.__dill_imported_as) == { - ('collections', 'UserDict', 'Dict'), - ('typing', 'AsyncContextManager', 'AsyncCM'), - ('dill', 'executor', 'thread_exec'), - } + runtime_mod = dill.load_module(BytesIO(session_dump)) + assert runtime_mod.__name__ == modname + assert runtime_mod.x == 42 + assert runtime_mod not in sys.modules.values() def test_load_module_asdict(): with TestNamespace(): @@ -268,13 +249,198 @@ def test_load_module_asdict(): assert main_vars['names'] == names assert main_vars['names'] is not names assert main_vars['x'] != x - assert 'y' not in main_vars + assert 'y' in main_vars assert 'empty' in main_vars + # Test a submodule. 
+ import html + from html import entities + entitydefs = entities.entitydefs + + session_buffer = BytesIO() + dill.dump_module(session_buffer, entities) + session_buffer.seek(0) + entities_vars = dill.load_module_asdict(session_buffer) + + assert entities is html.entities # restored + assert entities is sys.modules['html.entities'] # restored + assert entitydefs is entities.entitydefs # unchanged + assert entitydefs is not entities_vars['entitydefs'] # saved by value + assert entitydefs == entities_vars['entitydefs'] + +def test_lookup_module(): + assert not _dill._is_builtin_module(local_mod) and local_mod.__package__ == '' + + def lookup(mod, name, obj, lookup_by_name=True): + from dill._dill import _lookup_module, _module_map + return _lookup_module(_module_map(mod), name, obj, lookup_by_name) + + name = '__unpickleable' + obj = object() + setattr(dill, name, obj) + assert lookup(dill, name, obj) == (None, None, None) + + # 4th level: non-installed module + setattr(local_mod, name, obj) + sys.modules[local_mod.__name__] = sys.modules.pop(local_mod.__name__) # put at the end + assert lookup(dill, name, obj) == (local_mod.__name__, name, False) # not installed + try: + import pox + # 3rd level: installed third-party module + setattr(pox, name, obj) + sys.modules['pox'] = sys.modules.pop('pox') + assert lookup(dill, name, obj) == ('pox', name, True) + except ModuleNotFoundError: + pass + # 2nd level: module of same package + setattr(dill.session, name, obj) + sys.modules['dill.session'] = sys.modules.pop('dill.session') + assert lookup(dill, name, obj) == ('dill.session', name, True) + # 1st level: stdlib module + setattr(os, name, obj) + sys.modules['os'] = sys.modules.pop('os') + assert lookup(dill, name, obj) == ('os', name, True) + + # Lookup by id. 
+ name2 = name + '2' + setattr(dill, name2, obj) + assert lookup(dill, name2, obj) == ('os', name, True) + assert lookup(dill, name2, obj, lookup_by_name=False) == (None, None, None) + setattr(local_mod, name2, obj) + assert lookup(dill, name2, obj) == (local_mod.__name__, name2, False) + +def test_refimported(): + import collections + import concurrent.futures + import types + import typing + + mod = sys.modules['__test__'] = ModuleType('__test__') + mod.builtin_module_names = sys.builtin_module_names + dill.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + mod.Dict = collections.UserDict # select by type + mod.AsyncCM = typing.AsyncContextManager # select by __module__ + mod.thread_exec = dill.executor # select by __module__ with regex + mod.local_mod = local_mod + + session_buffer = BytesIO() + dill.dump_module(session_buffer, mod, refimported=True) + session_buffer.seek(0) + mod = dill.load(session_buffer) + + assert mod.__dill_imported == [('sys', 'builtin_module_names')] + assert set(mod.__dill_imported_as) == { + ('collections', 'UserDict', 'Dict'), + ('typing', 'AsyncContextManager', 'AsyncCM'), + ('dill', 'executor', 'thread_exec'), + } + assert mod.__dill_imported_top_level == [(local_mod.__name__, 'local_mod')] + + session_buffer.seek(0) + dill.load_module(session_buffer, mod) + del sys.modules['__test__'] + assert mod.builtin_module_names is sys.builtin_module_names + assert mod.Dict is collections.UserDict + assert mod.AsyncCM is typing.AsyncContextManager + assert mod.thread_exec is dill.executor + assert mod.local_mod is local_mod + +def test_unpickleable_var(): + global local_mod + import keyword as builtin_mod + from dill._dill import _global_string + refonfail_default = dill.session.settings['refonfail'] + dill.session.settings['refonfail'] = True + name = '__unpickleable' + obj = memoryview(b'') + assert _dill._is_builtin_module(builtin_mod) + assert not _dill._is_builtin_module(local_mod) + # assert not dill.pickles(obj) + try: + 
dill.dumps(obj) + except _dill.UNPICKLEABLE_ERRORS: + pass + else: + raise Exception("test object should be unpickleable") + + def dump_with_ref(mod, other_mod): + setattr(other_mod, name, obj) + buf = BytesIO() + dill.dump_module(buf, mod) + return buf.getvalue() + + # "user" modules + _local_mod = local_mod + del local_mod # remove from __main__'s namespace + try: + dump_with_ref(__main__, __main__) + except dill.PicklingError: + pass # success + else: + raise Exception("saving with a reference to the module itself should fail for '__main__'") + assert _global_string(_local_mod.__name__, name) in dump_with_ref(__main__, _local_mod) + assert _global_string('os', name) in dump_with_ref(__main__, os) + local_mod = _local_mod + del _local_mod, __main__.__unpickleable, local_mod.__unpickleable, os.__unpickleable + + # "builtin" or "installed" modules + assert _global_string(builtin_mod.__name__, name) in dump_with_ref(builtin_mod, builtin_mod) + assert _global_string(builtin_mod.__name__, name) in dump_with_ref(builtin_mod, local_mod) + assert _global_string('os', name) in dump_with_ref(builtin_mod, os) + del builtin_mod.__unpickleable, local_mod.__unpickleable, os.__unpickleable + + dill.session.settings['refonfail'] = refonfail_default + +def test_is_pickled_module(): + import tempfile + import warnings + + # Module saved with dump(). + pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump(os, pickle_file) + pickle_file.flush() + assert not dill.is_pickled_module(pickle_file.name) + assert not dill.is_pickled_module(pickle_file.name, importable=False) + pickle_file.close() + + # Importable module saved with dump_module(). 
+ pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump_module(pickle_file, local_mod) + pickle_file.flush() + assert dill.is_pickled_module(pickle_file.name) + assert not dill.is_pickled_module(pickle_file.name, importable=False) + assert dill.is_pickled_module(pickle_file.name, identify=True) == local_mod.__name__ + pickle_file.close() + + # Module-type object saved with dump_module(). + pickle_file = tempfile.NamedTemporaryFile(mode='wb') + dill.dump_module(pickle_file, ModuleType('runtime')) + pickle_file.flush() + assert not dill.is_pickled_module(pickle_file.name) + assert dill.is_pickled_module(pickle_file.name, importable=False) + assert dill.is_pickled_module(pickle_file.name, importable=False, identify=True) == 'runtime' + pickle_file.close() + + # Importable module saved by reference due to unpickleable object. + pickle_file = tempfile.NamedTemporaryFile(mode='wb') + local_mod.__unpickleable = memoryview(b'') + warnings.filterwarnings('ignore') + dill.dump_module(pickle_file, local_mod) + warnings.resetwarnings() + del local_mod.__unpickleable + pickle_file.flush() + assert dill.is_pickled_module(pickle_file.name) + assert not dill.is_pickled_module(pickle_file.name, importable=False) + pickle_file.close() + if __name__ == '__main__': - test_session_main(refimported=False) - test_session_main(refimported=True) + if os.getenv('COVERAGE') != 'true': + test_session_main(refimported=False) + test_session_main(refimported=True) test_session_other() test_runtime_module() - test_refimported_imported_as() test_load_module_asdict() + test_lookup_module() + test_refimported() + test_unpickleable_var() + test_is_pickled_module() diff --git a/dill/tests/test_stdlib_modules.py b/dill/tests/test_stdlib_modules.py new file mode 100644 index 00000000..15cb0767 --- /dev/null +++ b/dill/tests/test_stdlib_modules.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. 
# License: 3-clause BSD.  The full license text is available at:
#  - https://github.com/uqfoundation/dill/blob/master/LICENSE
"""Smoke-test dill.dump_module()/load_module() against the Standard Library.

Every importable stdlib module is pickled and unpickled in a fresh worker
process (one task per child, so a crash in one module cannot poison the
next), and the aggregate success rates are asserted against thresholds.
"""

import io
import itertools
import logging
import multiprocessing
import os
import sys
import warnings

import dill

if not dill._dill.OLD310:
    # Python >= 3.10: the interpreter knows its own stdlib top-level names;
    # extend with the documented submodules (not listed by stdlib_module_names).
    STDLIB_MODULES = list(sys.stdlib_module_names)
    STDLIB_MODULES += [
        # From https://docs.python.org/3.11/library/
        'collections.abc', 'concurrent.futures', 'curses.ascii', 'curses.panel', 'curses.textpad',
        'html.entities', 'html.parser', 'http.client', 'http.cookiejar', 'http.cookies', 'http.server',
        'importlib.metadata', 'importlib.resources', 'importlib.resources.abc', 'logging.config',
        'logging.handlers', 'multiprocessing.shared_memory', 'os.path', 'test.support',
        'test.support.bytecode_helper', 'test.support.import_helper', 'test.support.os_helper',
        'test.support.script_helper', 'test.support.socket_helper', 'test.support.threading_helper',
        'test.support.warnings_helper', 'tkinter.colorchooser', 'tkinter.dnd', 'tkinter.font',
        'tkinter.messagebox', 'tkinter.scrolledtext', 'tkinter.tix', 'tkinter.ttk', 'unittest.mock',
        'urllib.error', 'urllib.parse', 'urllib.request', 'urllib.response', 'urllib.robotparser',
        'xml.dom', 'xml.dom.minidom', 'xml.dom.pulldom', 'xml.etree.ElementTree', 'xml.parsers.expat',
        'xml.sax', 'xml.sax.handler', 'xml.sax.saxutils', 'xml.sax.xmlreader', 'xmlrpc.client',
        'xmlrpc.server',
    ]
    STDLIB_MODULES.sort()
else:
    # Python 3.9: sys.stdlib_module_names does not exist; use a static list.
    STDLIB_MODULES = [
        # From https://docs.python.org/3.9/library/
        '__future__', '_thread', 'abc', 'aifc', 'argparse', 'array', 'ast', 'asynchat', 'asyncio',
        'asyncore', 'atexit', 'audioop', 'base64', 'bdb', 'binascii', 'binhex', 'bisect', 'builtins',
        'bz2', 'calendar', 'cgi', 'cgitb', 'chunk', 'cmath', 'cmd', 'code', 'codecs', 'codeop',
        'collections', 'collections.abc', 'colorsys', 'compileall', 'concurrent', 'concurrent.futures',
        'configparser', 'contextlib', 'contextvars', 'copy', 'copyreg', 'crypt', 'csv', 'ctypes',
        'curses', 'curses.ascii', 'curses.panel', 'curses.textpad', 'dataclasses', 'datetime', 'dbm',
        'decimal', 'difflib', 'dis', 'distutils', 'doctest', 'email', 'ensurepip', 'enum', 'errno',
        'faulthandler', 'fcntl', 'filecmp', 'fileinput', 'fnmatch', 'formatter', 'fractions', 'ftplib',
        'functools', 'gc', 'getopt', 'getpass', 'gettext', 'glob', 'graphlib', 'grp', 'gzip', 'hashlib',
        'heapq', 'hmac', 'html', 'html.entities', 'html.parser', 'http', 'http.client',
        'http.cookiejar', 'http.cookies', 'http.server', 'imaplib', 'imghdr', 'imp', 'importlib',
        'importlib.metadata', 'inspect', 'io', 'ipaddress', 'itertools', 'json', 'keyword', 'linecache',
        'locale', 'logging', 'logging.config', 'logging.handlers', 'lzma', 'mailbox', 'mailcap',
        'marshal', 'math', 'mimetypes', 'mmap', 'modulefinder', 'msilib', 'msvcrt', 'multiprocessing',
        'multiprocessing.shared_memory', 'netrc', 'nis', 'nntplib', 'numbers', 'operator', 'optparse',
        'os', 'os.path', 'ossaudiodev', 'parser', 'pathlib', 'pdb', 'pickle', 'pickletools', 'pipes',
        'pkgutil', 'platform', 'plistlib', 'poplib', 'posix', 'pprint', 'pty', 'pwd', 'py_compile',
        'pyclbr', 'pydoc', 'queue', 'quopri', 'random', 're', 'readline', 'reprlib', 'resource',
        'rlcompleter', 'runpy', 'sched', 'secrets', 'select', 'selectors', 'shelve', 'shlex', 'shutil',
        # NOTE: the original list contained 'site' twice, which made that module
        # be tested twice and double-counted in the success statistics.
        'signal', 'site', 'smtpd', 'smtplib', 'sndhdr', 'socket', 'socketserver', 'spwd',
        'sqlite3', 'ssl', 'stat', 'statistics', 'string', 'stringprep', 'struct', 'subprocess', 'sunau',
        'symbol', 'symtable', 'sys', 'sysconfig', 'syslog', 'tabnanny', 'tarfile', 'telnetlib',
        'tempfile', 'termios', 'test', 'test.support', 'test.support.bytecode_helper',
        'test.support.script_helper', 'test.support.socket_helper', 'textwrap', 'threading', 'time',
        'timeit', 'tkinter', 'tkinter.colorchooser', 'tkinter.dnd', 'tkinter.font',
        'tkinter.messagebox', 'tkinter.scrolledtext', 'tkinter.tix', 'tkinter.ttk', 'token', 'tokenize',
        'trace', 'traceback', 'tracemalloc', 'tty', 'turtle', 'types', 'typing', 'unicodedata',
        'unittest', 'unittest.mock', 'urllib', 'urllib.error', 'urllib.parse', 'urllib.request',
        'urllib.response', 'urllib.robotparser', 'uu', 'uuid', 'venv', 'warnings', 'wave', 'weakref',
        'webbrowser', 'winreg', 'winsound', 'wsgiref', 'xdrlib', 'xml.dom', 'xml.dom.minidom',
        'xml.dom.pulldom', 'xml.etree.ElementTree', 'xml.parsers.expat', 'xml.sax', 'xml.sax.handler',
        'xml.sax.saxutils', 'xml.sax.xmlreader', 'xmlrpc', 'xmlrpc.client', 'xmlrpc.server', 'zipapp',
        'zipfile', 'zipimport', 'zlib', 'zoneinfo',
    ]

def _dump_load_module(module_name, refonfail):
    """Round-trip one module through dump_module()/load_module().

    Returns a ``(success_dump, success_load)`` pair of tri-state flags:
    ``None`` means "not attempted" (module unavailable on this platform, or
    dump already failed), ``True``/``False`` mean attempted and (not) succeeded.
    Prints ":" for a successful dump and "F" for a failed one as progress marks.
    """
    try:
        __import__(module_name)
    except ImportError:
        # Platform-specific module (e.g. winreg on POSIX): skip entirely.
        return None, None
    success_load = None
    buf = io.BytesIO()
    try:
        dill.dump_module(buf, module_name, refonfail=refonfail)
    except Exception:
        print("F", end="")
        success_dump = False
        return success_dump, success_load
    print(":", end="")
    success_dump = True
    buf.seek(0)
    try:
        dill.load_module(buf)
    except Exception:
        success_load = False
        return success_dump, success_load
    success_load = True
    return success_dump, success_load

def test_stdlib_modules():
    """Pickle and unpickle (almost) every stdlib module; assert success rates.

    Private modules, the ``test`` package, and the easter-egg modules are
    excluded.  Each module is processed in a fresh child process
    (``maxtasksperchild=1``) so one misbehaving module cannot corrupt the
    interpreter state used for the others.
    """
    modules = [x for x in STDLIB_MODULES if
               not x.startswith('_')
               and not x.startswith('test')
               and x not in ('antigravity', 'this')]

    print("\nTesting pickling and unpickling of Standard Library modules...")
    message = "Success rate (%s_module, refonfail=%s): %.1f%% [%d/%d]"
    with multiprocessing.Pool(maxtasksperchild=1) as pool:
        for refonfail in (False, True):
            args = zip(modules, itertools.repeat(refonfail))
            result = pool.starmap(_dump_load_module, args, chunksize=1)
            # Tri-state results: only count modules that were actually attempted.
            dump_successes = sum(dumped for dumped, loaded in result if dumped is not None)
            load_successes = sum(loaded for dumped, loaded in result if loaded is not None)
            dump_failures = sum(not dumped for dumped, loaded in result if dumped is not None)
            load_failures = sum(not loaded for dumped, loaded in result if loaded is not None)
            dump_total = dump_successes + dump_failures
            load_total = load_successes + load_failures
            dump_percent = 100 * dump_successes / dump_total
            load_percent = 100 * load_successes / load_total
            if logging.getLogger().isEnabledFor(logging.INFO): print()
            logging.info(message, "dump", refonfail, dump_percent, dump_successes, dump_total)
            logging.info(message, "load", refonfail, load_percent, load_successes, load_total)
            if refonfail:
                # With refonfail=True, name the individual offenders for triage.
                failed_dump = [mod for mod, (dumped, _) in zip(modules, result) if dumped is False]
                failed_load = [mod for mod, (_, loaded) in zip(modules, result) if loaded is False]
                if failed_dump:
                    logging.info("dump_module() FAILURES: %s", str(failed_dump).replace("'", "")[1:-1])
                if failed_load:
                    logging.info("load_module() FAILURES: %s", str(failed_load).replace("'", "")[1:-1])
            assert dump_percent > 99
            assert load_percent > 85  #FIXME: many important modules fail to unpickle
    print()

if __name__ == '__main__':
    logging.basicConfig(level=os.environ.get('PYTHONLOGLEVEL', 'WARNING'))
    warnings.simplefilter('ignore')
    test_stdlib_modules()
# The full license text is available at:
#  - https://github.com/uqfoundation/dill/blob/master/LICENSE

"""test general utilities in _utils.py (for filters, see test_filtering.py)"""

import io
import os
import sys

from dill import _utils

def test_format_bytes():
    """_format_bytes_size() splits a byte count into (amount, binary unit)."""
    fmt = _utils._format_bytes_size
    # Below one KiB the value is reported in plain bytes.
    assert fmt(1000) == (1000, 'B')
    # At and above 1024, rounding to the nearest unit applies.
    assert fmt(1024) == (1, 'KiB')
    assert fmt(1024 + 511) == (1, 'KiB')
    assert fmt(1024 + 512) == (2, 'KiB')
    assert fmt(10**9) == (954, 'MiB')

def test_open():
    """_open() wraps raw streams to provide peek()/seek() as requested."""
    raw = open(__file__, 'rb', buffering=0)
    assert not hasattr(raw, 'peek')

    content = raw.read()
    peeked_chars = content[:10]
    head, _, _ = content[:100].partition(b'\n')
    first_line = head + b'\n'
    raw.seek(0)

    # Seekable stream without peek(): expect a _PeekableReader wrapper that
    # leaves the underlying file open on exit.
    with _utils._open(raw, 'r', peekable=True) as wrapped:
        assert isinstance(wrapped, _utils._PeekableReader)
        assert wrapped.peek(10)[:10] == peeked_chars
        assert wrapped.readline() == first_line
    assert not raw.closed
    raw.close()

    rd_fd, wr_fd = os.pipe()
    pipe_r = io.FileIO(rd_fd, closefd=False)
    pipe_w = io.FileIO(wr_fd, mode='w')
    assert not hasattr(pipe_r, 'peek')
    assert not pipe_r.seekable()
    assert not pipe_w.seekable()

    # Unseekable read end: expect a plain io.BufferedReader (which has peek()).
    with _utils._open(pipe_r, 'r', peekable=True) as wrapped:
        assert isinstance(wrapped, io.BufferedReader)
        pipe_w.write(content[:100])
        assert wrapped.peek(10)[:10] == peeked_chars
        assert wrapped.readline() == first_line
    assert not pipe_r.closed

    # Unseekable write end: expect a _SeekableWriter buffer that supports
    # seek()/truncate() before flushing to the pipe.
    with _utils._open(pipe_w, 'w', seekable=True) as wrapped:
        # pipe_r is closed here for some reason...
        assert isinstance(wrapped, _utils._SeekableWriter)
        wrapped.write(content)
        wrapped.flush()
        wrapped.seek(0)
        wrapped.truncate()
        wrapped.write(b'a line of text\n')
    assert not pipe_w.closed
    pipe_r = io.FileIO(rd_fd)
    assert pipe_r.readline() == b'a line of text\n'
    pipe_r.close()
    pipe_w.close()

if __name__ == '__main__':
    test_format_bytes()
    test_open()
automodule:: dill.detect - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: ismethod, isfunction, istraceback, isframe, iscode, parent, reference, at, parents, children - -logger module -------------- +.. :exclude-members: +ismethod, isfunction, istraceback, isframe, iscode, parent, reference, at, parents, children + +logging module +-------------- -.. automodule:: dill.logger - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: +.. automodule:: dill.logging + :exclude-members: +trace objtypes module --------------- .. automodule:: dill.objtypes - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: +.. :exclude-members: + pointers module --------------- .. automodule:: dill.pointers - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: +.. :exclude-members: + session module ---------------- +-------------- .. automodule:: dill.session - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: - :exclude-members: dump_session, load_session + :exclude-members: +dump_session, load_session settings module --------------- .. automodule:: dill.settings - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: +.. :exclude-members: + source module ------------- .. automodule:: dill.source - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. :exclude-members: +.. :exclude-members: + temp module ----------- .. automodule:: dill.temp - :members: - :undoc-members: - :private-members: - :special-members: - :show-inheritance: - :imported-members: -.. 
:exclude-members: - +.. :exclude-members: +