diff --git a/padertorch/contrib/cb/complex.py b/padertorch/contrib/cb/complex.py index 8c45eb63..475ab04f 100644 --- a/padertorch/contrib/cb/complex.py +++ b/padertorch/contrib/cb/complex.py @@ -1,10 +1,7 @@ import numpy as np import torch -import torch_complex -from torch_complex import ComplexTensor __all__ = { - 'ComplexTensor', } @@ -19,8 +16,11 @@ def is_torch(obj): >>> is_torch(ComplexTensor(np.zeros(3))) True """ - if torch.is_tensor(obj) or isinstance(obj, ComplexTensor): + if torch.is_tensor(obj): return True - else: - return False + if type(obj).__name__ == 'ComplexTensor': + from torch_complex import ComplexTensor + if isinstance(obj, ComplexTensor): + return True + return False diff --git a/padertorch/summary/tbx_utils.py b/padertorch/summary/tbx_utils.py index 09c26f33..8076d756 100644 --- a/padertorch/summary/tbx_utils.py +++ b/padertorch/summary/tbx_utils.py @@ -79,7 +79,7 @@ def mask_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. @@ -119,7 +119,7 @@ def stft_to_image( signal: Shape (frames, batch [optional], features) batch_first: if true mask shape (batch [optional], frames, features] color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. visible_dB: How many dezibel are visible in the image. @@ -203,7 +203,7 @@ def __call__(self, image, color): except KeyError: try: import matplotlib.pyplot as plt - cmap = plt.cm.get_cmap(color) + cmap = plt.get_cmap(color) self.color_to_cmap[color] = cmap except ImportError: from warnings import warn @@ -243,7 +243,7 @@ def spectrogram_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. + `matplotlib.pyplot.get_cmap` to get the color map. origin: Origin of the plot. Can be `'upper'` or `'lower'`. log: If `True`, the spectrogram is plotted in log domain and shows a 50dB range. The 50dB can be changed with the argument `visible_dB`. @@ -299,6 +299,9 @@ def audio( docs for further information on the return type. """ signal = to_numpy(signal, detach=True) + if signal.dtype.kind == 'c': + raise ValueError( + f'Complex datatype ({signal.dtype}) is not supported for audio.') signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1) @@ -454,3 +457,5 @@ def review_dict( assert operator.xor(loss is None, losses is None), (loss, losses) return review + + diff --git a/tests/test_train/test_trainer.py b/tests/test_train/test_trainer.py index fa69c9f5..a5982723 100644 --- a/tests/test_train/test_trainer.py +++ b/tests/test_train/test_trainer.py @@ -561,7 +561,11 @@ def test_released_tensors(): dt_dataset = dt_dataset[:2] class ReleaseTestHook(pt.train.hooks.Hook): - def get_all_tensors(self): + def __init__(self, global_tensors): + self.global_tensors = global_tensors + + @staticmethod + def get_all_tensors(): import gc tensors = [] for obj in gc.get_objects(): @@ -607,16 +611,22 @@ def show_referrers_type(cls, obj, depth, ignore=list()): ignore=ignore + [referrers, o, obj] ): l.append(textwrap.indent(s, ' '*4)) + else: + l.append('... cycle ...') + class c: + magenta = '\033[35m' + reset = '\033[0m' + cyan = '\033[36m' if inspect.isframe(obj): frame_info = inspect.getframeinfo(obj, context=1) if frame_info.function == 'show_referrers_type': pass else: - info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}' + info = f' {frame_info.function}, {c.magenta}{frame_info.filename}{c.reset}:{c.magenta}{frame_info.lineno}{c.reset}' l.append(f'Frame: {type(obj)} {info}') else: - l.append(str(type(obj)) + str(obj)[:80].replace('\n', ' ')) + l.append(str(type(obj)) + str(obj)[:160].replace('\n', ' ')) return l def pre_step(self, trainer: 'pt.Trainer'): @@ -645,17 +655,18 @@ def pre_step(self, trainer: 'pt.Trainer'): import textwrap print(len(all_tensors), len(parameters), len(optimizer_tensors)) - assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads), ( + def format_(name, tensors): + s = textwrap.indent("\n".join(map(str, all_tensors)), " "*8) + return f'{name}: {len(tensors)}\n{s}\n' + + assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads) + len(self.global_tensors), ( f'pre_step\n' f'{summary}\n' - f'all_tensors: {len(all_tensors)}\n' - + textwrap.indent("\n".join(map(str, all_tensors)), " "*8) + f'\n' - f'parameters: {len(parameters)}\n' - + textwrap.indent("\n".join(map(str, parameters)), " "*8) + f'\n' - f'parameters: {len(grads)}\n' - + textwrap.indent("\n".join(map(str, grads)), " "*8) + f'\n' - f'optimizer_tensors: {len(optimizer_tensors)}\n' - + textwrap.indent("\n".join(map(str, optimizer_tensors)), " "*8) + f'\n' + + format_('all_tensors', all_tensors) + + format_('parameters', parameters) + + format_('optimizer_tensors', optimizer_tensors) + + format_('grads', grads) + + format_('global_tensors', self.global_tensors) ) def post_step( @@ -665,12 +676,30 @@ def post_step( parameters = list(trainer.model.parameters()) assert len(all_tensors) > len(parameters), ('post_step', all_tensors, parameters) + + print('pre TemporaryDirectory', ReleaseTestHook.get_all_tensors()) + + try: + # Between Torch 2.1.2 and 2.3.1 someone created _nt_view_dummy, + # which is the only Tensor in torch, that is created with an import + # of torch code. + # For some unknown reason the Adam optimizer triggers this import + # with the __init__ call. + # Do it here manually to be able to find all "global" tensors. + from torch.nested._internal.nested_tensor import _nt_view_dummy + except Exception: + pass + + global_tensors = ReleaseTestHook.get_all_tensors() + with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir = Path(tmp_dir) + model = Model() + optimizer = pt.optimizer.Adam() t = pt.Trainer( - Model(), - optimizer=pt.optimizer.Adam(), + model=model, + optimizer=optimizer, storage_dir=str(tmp_dir), stop_trigger=(1, 'epoch'), summary_trigger=(1, 'epoch'), @@ -679,7 +708,7 @@ def post_step( t.register_validation_hook( validation_iterator=dt_dataset, max_checkpoints=None ) - t.register_hook(ReleaseTestHook()) # This hook will do the tests + t.register_hook(ReleaseTestHook(global_tensors)) # This hook will do the tests t.train(tr_dataset)