Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix memory leaks in Gluon #18328

Merged
merged 2 commits into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import threading
import copy
import warnings
import re
import weakref
from collections import OrderedDict, defaultdict

import re
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str
Expand All @@ -48,7 +50,7 @@ class _BlockScope(object):
_current = threading.local()

def __init__(self, block):
self._block = block
self._block = weakref.ref(block) if block is not None else None
self._counter = {}
self._old_scope = None
self._name_scope = None
Expand All @@ -60,7 +62,8 @@ def create(prefix, params, hint):
The profiler scope is to support the GPU memory profiler.
"""
current = getattr(_BlockScope._current, "value", None)
if current is None:
block = current._block() if current is not None else None
if current is None or block is None:
if prefix is None:
if not hasattr(_name.NameManager._current, "value"):
_name.NameManager._current.value = _name.NameManager()
Expand All @@ -79,29 +82,31 @@ def create(prefix, params, hint):
prefix = '%s%d_'%(hint, count)
current._counter[hint] = count + 1
if params is None:
parent = current._block.params
parent = block.params
params = ParameterDict(parent.prefix+prefix, parent._shared)
else:
params = ParameterDict(params.prefix, params)
# replace the trailing underscore with colon
profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \
else prefix) + ":"
return current._block.prefix + prefix, params, \
current._block._profiler_scope_name + profiler_scope_name
return block.prefix + prefix, params, \
block._profiler_scope_name + profiler_scope_name

def __enter__(self):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope = _name.Prefix(block.prefix)
self._name_scope.__enter__()
self._profiler_scope = _profiler.Scope(self._block._profiler_scope_name)
self._profiler_scope = _profiler.Scope(block._profiler_scope_name)
self._profiler_scope.__enter__()
return self

def __exit__(self, ptype, value, trace):
if self._block._empty_prefix:
block = self._block()
if block is None or block._empty_prefix:
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
Expand Down
38 changes: 38 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import os
import gc

import mxnet as mx
from mxnet import gluon
Expand Down Expand Up @@ -3229,3 +3230,40 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

def test_no_memory_leak_in_gluon():
    """Check that a deleted Gluon block leaves no NDArray stuck in a reference cycle.

    A ``Dense`` block is created, initialized and deleted while the garbage
    collector runs with ``gc.DEBUG_SAVEALL``, so every object that could only
    be freed by the cycle collector is kept in ``gc.garbage``. The test then
    walks everything reachable from ``gc.garbage`` and fails if any object is
    an ``mx.nd._internal.NDArrayBase`` instance.
    """
    import types  # local import: only needed for the ModuleType check below

    # Collect all other garbage prior to this test. Otherwise the test may fail
    # due to unrelated memory leaks.
    gc.collect()

    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        net = mx.gluon.nn.Dense(10, in_units=10)
        net.initialize()
        del net
        gc.collect()
    finally:
        # Reset the gc flags even if building the block raised, so that
        # DEBUG_SAVEALL does not stay enabled for subsequent tests.
        gc.set_debug(gc_flags)

    # Objects already inspected, keyed by id(). Tracking by identity works for
    # unhashable containers too (so cyclic lists/dicts terminate) and avoids
    # calling the objects' own __hash__/__eq__. The objects are stored as
    # values to keep them alive, so an id cannot be reused by a temporary.
    visited = {}

    def has_array(root):
        """Return True if an NDArray is reachable from ``root``."""
        # Explicit stack instead of recursion so that deep object graphs
        # cannot exceed the interpreter's recursion limit.
        pending = [root]
        while pending:
            element = pending.pop()
            if id(element) in visited:
                continue
            visited[id(element)] = element

            if isinstance(element, mx.nd._internal.NDArrayBase):
                return True
            # Strings only contain strings. Classes and modules hold global
            # state rather than per-instance state, so descending into them
            # could report arrays that are legitimately alive.
            if isinstance(element, (str, bytes, type, types.ModuleType)):
                continue

            if hasattr(element, '__dict__'):
                # Inspect attribute values; iterating vars(element) directly
                # would only yield the attribute names.
                pending.extend(vars(element).values())
            if isinstance(element, dict):
                pending.extend(element.keys())
                pending.extend(element.values())
            else:
                try:
                    pending.extend(element)
                except (TypeError, KeyError):
                    # Not iterable (or iteration is index based and failed).
                    pass
        return False

    # Check for leaked NDArrays
    try:
        leaked = any(has_array(x) for x in gc.garbage)
    finally:
        # Always drop the saved garbage, also when the scan itself fails, so
        # it cannot influence later tests.
        del gc.garbage[:]
    assert not leaked, 'Found leaked NDArrays due to reference cycles'
5 changes: 3 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def __init__(self, prefix):
status = [False]
event = threading.Event()
def f():
with block._BlockScope(dummy_block("spawned_")):
x= NameManager.current.get(None, "hello")
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
with block._BlockScope(net):
x = NameManager.current.get(None, "hello")
event.wait()
if x == "spawned_hello0":
status[0] = True
Expand Down