Skip to content

Commit

Permalink
Session: check id against module being saved
Browse files Browse the repository at this point in the history
  • Loading branch information
leogama committed May 12, 2022
1 parent 5b77a06 commit 26c7b72
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions dill/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import dill
from dill import Pickler, Unpickler
from ._dill import ModuleType, _import_module, _is_builtin_module
from ._dill import ModuleType, _import_module, _is_builtin_module, _main_module
from .utils import AttrDict, CheckerSet, TransSet
from .settings import settings

Expand Down Expand Up @@ -117,7 +117,7 @@ def _exclude_objs(main, exclude_extra, filters_extra, settings):
for item in exclude_extra:
for category, klass in categories.items():
if isinstance(item, klass):
exclude[category].add(item)
exclude[category].add(item, main=main)
break
else:
raise ValueError("bad value type for 'exclude' parameter: %r" % item)
Expand Down Expand Up @@ -220,21 +220,21 @@ def load_session(filename: Union[os.PathLike, io.BytesIO] = '/tmp/session.pkl',
# Settings #
##############

def _as_id(item):
def _as_id(item, *, main=_main_module):
if isinstance(item, int):
import warnings, __main__
if not any(id(obj) == item for obj in __main__.__dict__.values()):
import warnings
if not any(id(obj) == item for obj in main.__dict__.values()):
warnings.warn("%d isn't the id of any object in __main__ namespace. "
"Did you mean 'id(%d)?'" % (item, item))
return item
return id(item)

def _as_regex(item):
def _as_regex(item, **kwargs):
if isinstance(item, re.Pattern):
return item
return re.compile(item)

def _as_type(item):
def _as_type(item, **kwargs):
if isinstance(item, str):
import types
if hasattr(types, item + 'Type'):
Expand Down
4 changes: 2 additions & 2 deletions dill/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class TransSet(set):
def __init__(self, func: Callable, *args):
self.constructor = func
super().__init__(*args)
def add(self, item):
super().add(self.constructor(item))
def add(self, item, *, **kwags):
super().add(self.constructor(item, **kwargs))
def discard(self, item):
super().discard(self.constructor(item))
def remove(self, item):
Expand Down

0 comments on commit 26c7b72

Please sign in to comment.