Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix leak in gluon.Trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed May 19, 2020
1 parent 7f5df07 commit f54b537
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 42 deletions.
49 changes: 49 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
"""

import logging
import gc
import os
import random

import mxnet as mx
import pytest


Expand Down Expand Up @@ -229,3 +231,50 @@ def doctest(doctest_namespace):
logging.warning('Unable to import numpy/mxnet. Skipping conftest.')
import doctest
doctest.ELLIPSIS_MARKER = '-etc-'


@pytest.fixture(scope='session')
def mxnet_module():
    """Session-scoped fixture that returns the imported ``mxnet`` package.

    The import statement is executed inside the fixture body, and the
    resulting module object is handed to whichever test or fixture
    requests ``mxnet_module``.
    """
    import mxnet as mxnet_package
    return mxnet_package


@pytest.fixture()
# @pytest.fixture(autouse=True)  # Fix all the bugs and mark this autouse=True
def check_leak_ndarray(mxnet_module):
    """Fail the test if it leaves NDArrays behind in reference cycles.

    Before the test, garbage is collected and ``gc.DEBUG_SAVEALL`` is enabled
    so that every unreachable object found by the collector is kept in
    ``gc.garbage`` instead of being freed. After the test, a collection is
    forced, the previous debug flags are restored, and ``gc.garbage`` is
    searched for instances of ``mxnet_module.nd._internal.NDArrayBase``.

    Parameters
    ----------
    mxnet_module : module
        The ``mxnet`` package, as provided by the ``mxnet_module`` fixture.

    Raises
    ------
    AssertionError
        If any element of ``gc.garbage`` is, or contains, an NDArray.
    """
    # Collect garbage prior to running the next test
    gc.collect()
    # Enable gc debug mode to check if the test leaks any arrays
    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)

    # Run the test
    yield

    # Check for leaked NDArrays
    gc.collect()
    gc.set_debug(gc_flags)  # reset gc flags

    # Hashable elements are de-duplicated by hash/equality.
    seen = set()
    # Unhashable elements (lists, dicts, tuples holding them, ...) are
    # de-duplicated by identity. The element itself is stored as the value so
    # that it stays alive and its id() cannot be reused by a temporary object
    # created later during the traversal.
    seen_unhashable = {}

    def has_array(element):
        """Return True if ``element`` is an NDArray or contains one."""
        try:
            if element in seen:
                return False
            seen.add(element)
        except TypeError:  # unhashable
            # Without this check a container that (directly or indirectly)
            # contains itself would be revisited forever and end in a
            # RecursionError.
            if id(element) in seen_unhashable:
                return False
            seen_unhashable[id(element)] = element

        if isinstance(element, mxnet_module.nd._internal.NDArrayBase):
            return True
        elif hasattr(element, '__dict__'):
            # NOTE(review): iterating vars() yields the attribute *names*,
            # not the attribute values; kept as in the original.
            return any(has_array(x) for x in vars(element))
        elif isinstance(element, dict):
            return any(has_array(x) for x in element.items())
        else:
            try:
                return any(has_array(x) for x in element)
            except (TypeError, KeyError):
                # Not iterable (or iteration is not supported this way).
                return False

    try:
        assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
    finally:
        # Always empty gc.garbage, even when the assertion fails; otherwise
        # the same leaked objects would be reported again by every later test
        # that uses this fixture.
        del gc.garbage[:]
23 changes: 15 additions & 8 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from collections import OrderedDict, defaultdict
import warnings
import weakref
import numpy as np

from ..base import mx_real_t, MXNetError
Expand Down Expand Up @@ -201,12 +202,12 @@ def shape(self, new_shape):
def _set_trainer(self, trainer):
    """ Set the trainer this parameter is associated with.

    The trainer is stored as a weak reference in ``self._trainer`` so that
    the parameter does not keep the trainer alive. Passing ``None`` clears
    the association.

    Parameters
    ----------
    trainer : object or None
        The trainer to associate with this parameter, or ``None``.

    Raises
    ------
    RuntimeError
        If this parameter's ``_stype`` is not ``'default'``, it is already
        associated with a trainer that is still alive, and a different
        trainer is given.
    """
    # Dereference the stored weak reference. A reference whose target has
    # already been garbage collected counts as "no trainer", consistent with
    # how the other methods of this class read self._trainer.
    current = self._trainer() if self._trainer else None

    # trainer cannot be replaced for sparse params
    if self._stype != 'default' and current is not None and trainer and current is not trainer:
        raise RuntimeError(
            "Failed to set the trainer for Parameter '%s' because it was already set. " \
            "More than one trainers for a %s Parameter is not supported." \
            %(self.name, self._stype))

    # weakref.ref(None) raises TypeError, so store None directly when the
    # trainer is being cleared.
    self._trainer = weakref.ref(trainer) if trainer is not None else None

def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
Expand Down Expand Up @@ -245,13 +246,14 @@ def _get_row_sparse(self, arr_list, ctx, row_id):
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
if not self._trainer:
trainer = self._trainer() if self._trainer else None
if not trainer:
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
results = self._check_and_get(arr_list, ctx)

# fetch row sparse params from the trainer
self._trainer._row_sparse_pull(self, results, row_id)
trainer._row_sparse_pull(self, results, row_id)
return results

def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'):
Expand Down Expand Up @@ -397,7 +399,11 @@ def _reduce(self):
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
trainer = self._trainer() if self._trainer else None
if not trainer:
raise RuntimeError("Cannot reduce row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
Expand Down Expand Up @@ -503,9 +509,10 @@ def set_data(self, data):
return

# if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
if self not in self._trainer._params_to_init:
self._trainer._reset_kvstore()
trainer = self._trainer() if self._trainer else None
if trainer and trainer._kv_initialized and trainer._update_on_kvstore:
if self not in trainer._params_to_init:
trainer._reset_kvstore()

for arr in self._check_and_get(self._data, list):
arr[:] = data
Expand Down
37 changes: 3 additions & 34 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_parameter_invalid_access():
assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)

@with_seed()
@pytest.mark.usefixtures("check_leak_ndarray")
def test_parameter_dict():
ctx = mx.cpu(1)
params0 = gluon.ParameterDict('net_')
Expand Down Expand Up @@ -3226,40 +3227,8 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)

@pytest.mark.usefixtures("check_leak_ndarray")
def test_no_memory_leak_in_gluon():
    """Creating and initializing a Dense block must not leak NDArrays.

    The gc bookkeeping (collecting before the test, enabling
    ``gc.DEBUG_SAVEALL``, scanning ``gc.garbage`` afterwards) is done by the
    ``check_leak_ndarray`` fixture, so this test only has to create the
    network and drop it again.
    """
    net = mx.gluon.nn.Dense(10, in_units=10)
    net.initialize()
    # Drop the only reference so that anything still alive afterwards is
    # being kept alive by a reference cycle.
    del net

0 comments on commit f54b537

Please sign in to comment.