diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index bed6679be2e6..968c78760af9 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -23,8 +23,10 @@ import threading import copy import warnings -import re +import weakref from collections import OrderedDict, defaultdict + +import re import numpy as np from ..base import mx_real_t, MXNetError @@ -46,7 +48,7 @@ class _BlockScope(object): _current = threading.local() def __init__(self, block): - self._block = block + self._block = weakref.ref(block) if block is not None else None self._counter = {} self._old_scope = None self._name_scope = None @@ -55,7 +57,8 @@ def __init__(self, block): def create(prefix, params, hint): """Creates prefix and params for new `Block`.""" current = getattr(_BlockScope._current, "value", None) - if current is None: + block = current._block() if current is not None else None + if current is None or block is None: if prefix is None: if not hasattr(_name.NameManager._current, "value"): _name.NameManager._current.value = _name.NameManager() @@ -71,23 +74,25 @@ def create(prefix, params, hint): prefix = '%s%d_'%(hint, count) current._counter[hint] = count + 1 if params is None: - parent = current._block.params + parent = block.params params = ParameterDict(parent.prefix+prefix, parent._shared) else: params = ParameterDict(params.prefix, params) - return current._block.prefix+prefix, params + return block.prefix + prefix, params def __enter__(self): - if self._block._empty_prefix: + block = self._block() + if block is None or block._empty_prefix: return self self._old_scope = getattr(_BlockScope._current, "value", None) _BlockScope._current.value = self - self._name_scope = _name.Prefix(self._block.prefix) + self._name_scope = _name.Prefix(block.prefix) self._name_scope.__enter__() return self def __exit__(self, ptype, value, trace): - if self._block._empty_prefix: + block = self._block() + if block is None or block._empty_prefix: return self._name_scope.__exit__(ptype, value, trace) self._name_scope = None diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index a02682557954..cf6bc362eb47 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -17,6 +17,7 @@ import os import tempfile +import gc import mxnet as mx from mxnet import gluon @@ -3212,6 +3213,44 @@ def hybrid_forward(self, F, x): mx.test_utils.assert_almost_equal(grad1, grad2) +def test_no_memory_leak_in_gluon(): + # Collect all other garbage prior to this test. Otherwise the test may fail + # due to unrelated memory leaks. + gc.collect() + + gc_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + net = mx.gluon.nn.Dense(10, in_units=10) + net.initialize() + del net + gc.collect() + gc.set_debug(gc_flags) # reset gc flags + + # Check for leaked NDArrays + seen = set() + def has_array(element): + try: + if element in seen: + return False + seen.add(element) + except TypeError: # unhashable + pass + + if isinstance(element, mx.nd._internal.NDArrayBase): + return True + elif hasattr(element, '__dict__'): + return any(has_array(x) for x in vars(element)) + elif isinstance(element, dict): + return any(has_array(x) for x in element.items()) + else: + try: + return any(has_array(x) for x in element) + except (TypeError, KeyError): + return False + + assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles' + del gc.garbage[:] + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index f0e3c660181b..50ecb064b04e 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -124,8 +124,9 @@ def __init__(self, prefix): status = [False] event = threading.Event() def f(): - with block._BlockScope(dummy_block("spawned_")): - x= NameManager.current.get(None, "hello") + net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block + with block._BlockScope(net): + x = NameManager.current.get(None, "hello") event.wait() if x == "spawned_hello0": status[0] = True