From f54b537f15b6e0e6fb9912fd14d692031ef91fc8 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 18 May 2020 23:52:46 +0000 Subject: [PATCH] Fix leak in gluon.Trainer --- conftest.py | 49 +++++++++++++++++++++++++++++ python/mxnet/gluon/parameter.py | 23 +++++++++----- tests/python/unittest/test_gluon.py | 37 ++-------------------- 3 files changed, 67 insertions(+), 42 deletions(-) diff --git a/conftest.py b/conftest.py index caabaf9d74b9..2db6a0155b99 100644 --- a/conftest.py +++ b/conftest.py @@ -24,9 +24,11 @@ """ import logging +import gc import os import random +import mxnet as mx import pytest @@ -229,3 +231,50 @@ def doctest(doctest_namespace): logging.warning('Unable to import numpy/mxnet. Skipping conftest.') import doctest doctest.ELLIPSIS_MARKER = '-etc-' + + +@pytest.fixture(scope='session') +def mxnet_module(): + import mxnet + return mxnet + + +@pytest.fixture() +# @pytest.fixture(autouse=True) # Fix all the bugs and mark this autouse=True +def check_leak_ndarray(mxnet_module): + # Collect garbage prior to running the next test + gc.collect() + # Enable gc debug mode to check if the test leaks any arrays + gc_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + + # Run the test + yield + + # Check for leaked NDArrays + gc.collect() + gc.set_debug(gc_flags) # reset gc flags + + seen = set() + def has_array(element): + try: + if element in seen: + return False + seen.add(element) + except TypeError: # unhashable + pass + + if isinstance(element, mxnet_module.nd._internal.NDArrayBase): + return True + elif hasattr(element, '__dict__'): + return any(has_array(x) for x in vars(element)) + elif isinstance(element, dict): + return any(has_array(x) for x in element.items()) + else: + try: + return any(has_array(x) for x in element) + except (TypeError, KeyError): + return False + + assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles' + del gc.garbage[:] diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 06b615005158..763a488a41b5 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -25,6 +25,7 @@ from collections import OrderedDict, defaultdict import warnings +import weakref import numpy as np from ..base import mx_real_t, MXNetError @@ -201,12 +202,12 @@ def shape(self, new_shape): def _set_trainer(self, trainer): """ Set the trainer this parameter is associated with. """ # trainer cannot be replaced for sparse params - if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer: + if self._stype != 'default' and self._trainer and trainer and self._trainer() is not trainer: raise RuntimeError( "Failed to set the trainer for Parameter '%s' because it was already set. " \ "More than one trainers for a %s Parameter is not supported." \ %(self.name, self._stype)) - self._trainer = trainer + self._trainer = weakref.ref(trainer) def _check_and_get(self, arr_list, ctx): if arr_list is not None: @@ -245,13 +246,14 @@ def _get_row_sparse(self, arr_list, ctx, row_id): # get row sparse params based on row ids if not isinstance(row_id, ndarray.NDArray): raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id))) - if not self._trainer: + trainer = self._trainer() if self._trainer else None + if not trainer: raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \ "Trainer is created with it."%self.name) results = self._check_and_get(arr_list, ctx) # fetch row sparse params from the trainer - self._trainer._row_sparse_pull(self, results, row_id) + trainer._row_sparse_pull(self, results, row_id) return results def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'): @@ -397,7 +399,11 @@ def _reduce(self): # fetch all rows for 'row_sparse' param all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx) - self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True) + trainer = self._trainer() if self._trainer else None + if not trainer: + raise RuntimeError("Cannot reduce row_sparse data for Parameter '%s' when no " \ + "Trainer is created with it."%self.name) + trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True) return data def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), @@ -503,9 +509,10 @@ def set_data(self, data): return # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync - if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore: - if self not in self._trainer._params_to_init: - self._trainer._reset_kvstore() + trainer = self._trainer() if self._trainer else None + if trainer and trainer._kv_initialized and trainer._update_on_kvstore: + if self not in trainer._params_to_init: + trainer._reset_kvstore() for arr in self._check_and_get(self._data, list): arr[:] = data diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 98773b238348..da13affd6955 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -99,6 +99,7 @@ def test_parameter_invalid_access(): assertRaises(RuntimeError, p1.list_row_sparse_data, row_id) @with_seed() +@pytest.mark.usefixtures("check_leak_ndarray") def test_parameter_dict(): ctx = mx.cpu(1) params0 = gluon.ParameterDict('net_') @@ -3226,40 +3227,8 @@ def hybrid_forward(self, F, x): mx.test_utils.assert_almost_equal(grad1, grad2) -def test_no_memory_leak_in_gluon(): - # Collect all other garbage prior to this test. Otherwise the test may fail - # due to unrelated memory leaks. - gc.collect() - gc_flags = gc.get_debug() - gc.set_debug(gc.DEBUG_SAVEALL) +@pytest.mark.usefixtures("check_leak_ndarray") +def test_no_memory_leak_in_gluon(): net = mx.gluon.nn.Dense(10, in_units=10) net.initialize() - del net - gc.collect() - gc.set_debug(gc_flags) # reset gc flags - - # Check for leaked NDArrays - seen = set() - def has_array(element): - try: - if element in seen: - return False - seen.add(element) - except TypeError: # unhashable - pass - - if isinstance(element, mx.nd._internal.NDArrayBase): - return True - elif hasattr(element, '__dict__'): - return any(has_array(x) for x in vars(element)) - elif isinstance(element, dict): - return any(has_array(x) for x in element.items()) - else: - try: - return any(has_array(x) for x in element) - except (TypeError, KeyError): - return False - - assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles' - del gc.garbage[:]