Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecated test case (NeuralFactory) + important bugfix! (v3) #298

Merged
merged 5 commits into from
Jan 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions nemo/core/neural_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
]

import random
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional
Expand All @@ -21,6 +20,7 @@
from ..utils import ExpManager
from .callbacks import ActionCallback, EvaluatorCallback
from .neural_types import *
from nemo.utils.decorators import deprecated


class DeploymentFormat(Enum):
Expand Down Expand Up @@ -417,6 +417,7 @@ def __name_import(name):
mod = getattr(mod, comp)
return mod

@deprecated(version=0.11)
def __get_pytorch_module(self, name, params, collection, pretrained):
params["factory"] = self
if collection == "toys" or collection == "tutorials" or collection == "other":
Expand Down Expand Up @@ -493,6 +494,7 @@ def __get_pytorch_module(self, name, params, collection, pretrained):
instance = constructor(**params)
return instance

@deprecated(version=0.11)
def get_module(self, name, params, collection, pretrained=False):
"""
Creates NeuralModule instance
Expand Down Expand Up @@ -665,6 +667,7 @@ def clear_cache(self):
"""Helper function to clean inference cache."""
self._trainer.clear_cache()

@deprecated(version="future")
def _get_trainer(self, tb_writer=None):
if self._backend == Backend.PyTorch:
constructor = NeuralModuleFactory.__name_import("nemo.backends.pytorch.PtActions")
Expand All @@ -678,14 +681,12 @@ def _get_trainer(self, tb_writer=None):
else:
raise ValueError("Only PyTorch backend is currently supported.")

@deprecated(
version="future",
explanation="Please use .train(...), .eval(...), .infer(...) and "
f".create_optimizer(...) of the NeuralModuleFactory instance directly.",
)
def get_trainer(self, tb_writer=None):
nemo.logging.warning(
f"This function is deprecated and will be removed"
f"in future versions of NeMo."
f"Please use .train(...), .eval(...), .infer(...) and "
f".create_optimizer(...) directly methods from "
f"NeuralModuleFactory instance."
)
if self._trainer:
nemo.logging.warning(
"The trainer instance was created during initialization of "
Expand Down Expand Up @@ -742,9 +743,9 @@ def placement(self):
def optim_level(self):
return self._optim_level

@deprecated(version=0.11, explanation="Please use ``nemo.logging instead``")
@property
def logger(self):
warnings.warn("This will be deprecated in future releases. Please use " "nemo.logging instead")
return nemo.logging

@property
Expand Down
15 changes: 15 additions & 0 deletions nemo/utils/decorators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (C) NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .deprecated import deprecated
17 changes: 9 additions & 8 deletions nemo/utils/decorators/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
'deprecated',
]

import nemo


class deprecated(object):
""" Decorator class used for indicating that a function is
deprecated and going to be removed.
Tracks down which functions printed the warning and
will print it only once per function.
""" Decorator class used for indicating that a function is deprecated and going to be removed.
Tracks down which functions printed the warning and will print it only once per function.
"""

# Static variable - list of names of functions that we already printed
Expand All @@ -33,8 +34,7 @@ def __init__(self, version=None, explanation=None):

Args:
version: Version in which the function will be removed (optional)
explanation: Additional explanation (optional), e.g. use method
``blabla instead``.
explanation: Additional explanation (optional), e.g. use method ``blabla instead``.

"""
self.version = version
Expand All @@ -61,7 +61,8 @@ def wrapper(*args, **kwargs):

# Optionally, add version and alternative.
if self.version is not None:
msg = msg + " It is going to be removed in version {}.".format(self.version)
msg = msg + " It is going to be removed in "
msg = msg + "the {} version.".format(self.version)

if self.explanation is not None:
msg = msg + " " + self.explanation
Expand All @@ -70,6 +71,6 @@ def wrapper(*args, **kwargs):
nemo.logging.warning(msg)

# Call the function.
func(*args, **kwargs)
return func(*args, **kwargs)

return wrapper
9 changes: 3 additions & 6 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from unittest.mock import patch

from .common_setup import NeMoUnitTest
from nemo.utils.decorators.deprecated import deprecated
from nemo.utils.decorators import deprecated


class DeprecatedTestCase(NeMoUnitTest):
Expand Down Expand Up @@ -89,8 +89,7 @@ def say_whoopie():
# Check error output.
self.assertEqual(
std_err.getvalue().strip(),
'Function ``say_whoopie`` is deprecated. It is going \
to be removed in version 0.1.',
"Function ``say_whoopie`` is deprecated. It is going to be removed in the 0.1 version.",
)

def test_say_kowabunga_deprecated_explanation(self):
Expand All @@ -111,7 +110,5 @@ def say_kowabunga():

# Check error output.
self.assertEqual(
std_err.getvalue().strip(),
'Function ``say_kowabunga`` is deprecated. Please \
use ``print_ihaa`` instead.',
std_err.getvalue().strip(), 'Function ``say_kowabunga`` is deprecated. Please use ``print_ihaa`` instead.'
)