diff --git a/tests/test_issues.py b/tests/test_issues.py index e7547b9b..19d7a839 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -404,3 +404,13 @@ def test_issue434(self): uproot.astable( uproot.asdtype(fromdtype, todtype)), skipbytes=6)) assert 486480 == hits['tdc'][0][0] + + def test_issue438_accessing_memory_mapped_objects_outside_of_context_raises(self): + with uproot.open("tests/samples/issue434.root") as f: + a = f['KM3NET_EVENT']['KM3NET_EVENT']['KM3NETDAQ::JDAQPreamble'].array() + b = f['KM3NET_EVENT']['KM3NET_EVENT']['KM3NETDAQ::JDAQPreamble'].lazyarray() + assert 4 == len(a[0]) + with pytest.raises(IOError): + len(b[0]) + + diff --git a/uproot/rootio.py b/uproot/rootio.py index f5168fc3..ae1cb41e 100644 --- a/uproot/rootio.py +++ b/uproot/rootio.py @@ -362,6 +362,9 @@ def get(self, name, cycle=None): else: raise _KeyError("not found: {0} with cycle {1}\n in file: {2}".format(repr(name), cycle, self._context.sourcepath)) + def close(self): + self._context.source.close() + def __contains__(self, name): try: self.get(name) @@ -374,7 +377,7 @@ def __enter__(self, *args, **kwds): return self def __exit__(self, *args, **kwds): - pass + self.close() class _KeyError(KeyError): def __str__(self): diff --git a/uproot/source/memmap.py b/uproot/source/memmap.py index 55c21d89..379f1248 100644 --- a/uproot/source/memmap.py +++ b/uproot/source/memmap.py @@ -19,12 +19,19 @@ class MemmapSource(uproot.source.source.Source): def __init__(self, path): self.path = os.path.expanduser(path) self._source = numpy.memmap(self.path, dtype=numpy.uint8, mode="r") + self.closed = False + + @property + def source(self): + if self.closed: + raise IOError("The file handler has already been closed.") + return self._source def parent(self): return self def size(self): - return len(self._source) + return len(self.source) def threadlocal(self): return self @@ -33,17 +40,18 @@ def dismiss(self): pass def close(self): - self._source._mmap.close() + self.source._mmap.close() + self.closed = True def data(self, start, stop, dtype=None): # assert start >= 0 # assert stop >= 0 # assert stop >= start - if stop > len(self._source): - raise IndexError("indexes {0}:{1} are beyond the end of data source {2}".format(len(self._source), stop, repr(self.path))) + if stop > len(self.source): + raise IndexError("indexes {0}:{1} are beyond the end of data source {2}".format(len(self.source), stop, repr(self.path))) if dtype is None: - return self._source[start:stop] + return self.source[start:stop] else: - return self._source[start:stop].view(dtype) + return self.source[start:stop].view(dtype)