diff --git a/CHANGELOG.md b/CHANGELOG.md
index e5ce8a2..f610dde 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## 0.7.12
+
+- Fixed `num_training_steps` for lightning 1.7.
+
+- Renamed all `add_*_args` static methods to the standard form `add_argparse_args`.
+
+- Deprecated strategies based on `DataParallel`, as done in `pytorch-lightning`, and added the MPS accelerator.
+
+- Fixed imports of classes deprecated in lightning 1.7.
+
+
 ## 0.7.10
 
 - Moved `pre_trained_dir` hyperparameter from `Defaults` to `TransformersModelCheckpointCallback`.
 
@@ -67,7 +78,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `SwappingLanguageModeling` technique and tests.
 
-- Added `add_adapter_specific_args` method to `SuperAdapter` to allow adding parameters to the CLI.
+- Added `add_argparse_args` method to `SuperAdapter` to allow adding parameters to the CLI.
 
 - Fixed typo with which `AdapterDataModule` was not receiving `collate_fn` argument.
 
diff --git a/README.md b/README.md
index cc862ee..9a8eb94 100644
--- a/README.md
+++ b/README.md
@@ -108,14 +108,14 @@ if __name__ == '__main__':
     parser.add_argument('--name', type=str, required=True, help='Name of the experiment, well be used to correctly retrieve checkpoints and logs')
 
     # I/O folders
-    DefaultConfig.add_defaults_args(parser)
+    DefaultConfig.add_argparse_args(parser)
 
     # add model specific cli arguments
-    TransformerModel.add_model_specific_args(parser)
-    YourDataModule.add_datamodule_specific_args(parser)
+    TransformerModel.add_argparse_args(parser)
+    YourDataModule.add_argparse_args(parser)
 
     # add callback / logger specific cli arguments
-    callbacks.TransformersModelCheckpointCallback.add_callback_specific_args(parser)
+    callbacks.TransformersModelCheckpointCallback.add_argparse_args(parser)
 
     # add all the available trainer options to argparse
     # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
diff --git a/tests/datamodules/test_datamodule_dp.py b/tests/datamodules/test_datamodule_mps.py
similarity index 52%
rename from tests/datamodules/test_datamodule_dp.py
rename to tests/datamodules/test_datamodule_mps.py
index cd7bb98..8c3f3c4 100644
--- a/tests/datamodules/test_datamodule_dp.py
+++ b/tests/datamodules/test_datamodule_mps.py
@@ -10,21 +10,15 @@
 @pytest.mark.parametrize("batch_size", [1, 3])
 @pytest.mark.parametrize("accumulate_grad_batches", [1, 11])
 @pytest.mark.parametrize("iterable", [False, True])
-@pytest.mark.parametrize("devices", [1, 2, 4, 8])
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping GPU tests because this machine has not GPUs")
-def test_datamodule_gpu_dp(num_workers, batch_size, accumulate_grad_batches, devices, iterable):
-
-    # cannot do GPU training without enough devices
-    if torch.cuda.device_count() < devices:
-        pytest.skip()
+@pytest.mark.skipif(not torch.has_mps, reason="Skipping MPS tests because this machine has no MPS device")
+def test_datamodule_mps(num_workers, batch_size, accumulate_grad_batches, iterable):
 
     do_test_datamodule(
         num_workers,
         batch_size,
         accumulate_grad_batches,
         iterable,
-        strategy='dp',
-        accelerator='gpu',
-        devices=devices,
+        accelerator='mps',
+        devices=1,
         num_sanity_val_steps=0,
     )
diff --git a/tests/helpers.py b/tests/helpers.py
index 8fed355..495138b 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -29,6 +29,7 @@
     padding='max_length',
     max_length=128,
     pin_memory=False,
+    prefetch_factor=2,
 )
 
 
diff --git a/tests/models/test_models_cpu.py b/tests/models/test_models_cpu.py
index 0e0b9e7..b6c2346 100644
--- a/tests/models/test_models_cpu.py
+++ b/tests/models/test_models_cpu.py
@@ -14,3 +14,12 @@ def test_fix_max_steps_cpu(max_epochs, accumulate_grad_batches, batch_size):
         batch_size,
         accelerator="cpu",
     )
+
+    do_test_fix_max_steps(
+        max_epochs,
+        accumulate_grad_batches,
+        batch_size,
+        accelerator="cpu",
+        strategy="ddp",
+        devices=10,
+    )
diff --git a/tests/models/test_models_dp.py b/tests/models/test_models_dp.py
deleted file mode 100644
index d8696bc..0000000
--- a/tests/models/test_models_dp.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import pytest
-import torch
-
-from tests.models.helpers import do_test_fix_max_steps
-
-
-@pytest.mark.parametrize("max_epochs", (1, 2))
-@pytest.mark.parametrize("accumulate_grad_batches", (1, 3))
-@pytest.mark.parametrize("batch_size" , (1, 2, 3, 8, 11))
-@pytest.mark.parametrize("devices", (1, 2, 3))
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_fix_max_steps_dp(max_epochs, accumulate_grad_batches, batch_size, devices):
-
-    if torch.cuda.device_count() < devices:
-        pytest.skip()
-
-    do_test_fix_max_steps(
-        max_epochs,
-        accumulate_grad_batches,
-        batch_size,
-        devices=devices,
-        strategy='dp',
-        accelerator='gpu',
-    )
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index 033b289..6a32e38 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -47,7 +47,7 @@ def test_optimizers(optimizer_class, batch_size):
     )
 
     parser = ArgumentParser()
-    optimizer_class.add_optimizer_specific_args(parser)
+    optimizer_class.add_argparse_args(parser)
     hyperparameters = Namespace(**vars(hyperparameters), **vars(parser.parse_args("")))
     hyperparameters.optimizer_class = optimizer_class
 
diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py
index fbef671..7e9f87a 100644
--- a/tests/test_schedulers.py
+++ b/tests/test_schedulers.py
@@ -3,7 +3,8 @@
 import pytest
 import pytorch_lightning as pl
 import torch
-from transformers import AdamW, BertTokenizer
+from torch.optim import AdamW
+from transformers import BertTokenizer
 
 from tests.helpers import DummyDataModule, DummyTransformerModel, standard_args
 from transformers_lightning.schedulers import (
@@ -87,7 +88,7 @@ def test_schedulers(scheduler_class, parameters, expected_lrs):
         **parameters,
     )
 
-    scheduler_class.add_scheduler_specific_args(ArgumentParser())
+    scheduler_class.add_argparse_args(ArgumentParser())
 
 
 class SchedulerModel(DummyTransformerModel):
diff --git a/transformers_lightning/__init__.py b/transformers_lightning/__init__.py
index 0980a41..2df7add 100644
--- a/transformers_lightning/__init__.py
+++ b/transformers_lightning/__init__.py
@@ -3,6 +3,7 @@
 import transformers_lightning.datamodules  # noqa: F401
 import transformers_lightning.datasets  # noqa: F401
 import transformers_lightning.defaults  # noqa: F401
+import transformers_lightning.info  # noqa: F401
 import transformers_lightning.language_modeling  # noqa: F401
 import transformers_lightning.loggers  # noqa: F401
 import transformers_lightning.models  # noqa: F401
diff --git a/transformers_lightning/adapters/super_adapter.py b/transformers_lightning/adapters/super_adapter.py
index c85c1fb..fa67e3d 100644
--- a/transformers_lightning/adapters/super_adapter.py
+++ b/transformers_lightning/adapters/super_adapter.py
@@ -48,5 +48,5 @@ def preprocess_line(self, line: list) -> list:
         return line
 
     @staticmethod
-    def add_adapter_specific_args(parser: ArgumentParser) -> ArgumentParser:
+    def add_argparse_args(parser: ArgumentParser) -> ArgumentParser:
         r""" Add here arguments that will be available from the command line. """
diff --git a/transformers_lightning/callbacks/README.md b/transformers_lightning/callbacks/README.md
index 7271731..c011d3a 100644
--- a/transformers_lightning/callbacks/README.md
+++ b/transformers_lightning/callbacks/README.md
@@ -13,7 +13,7 @@ This callback can be used to save a checkpoint after every `k` steps, after ever
 >>> parser = ArgumentParser()
 >>> ...
 >>> # add callback / logger specific parameters
->>> callbacks.TransformersModelCheckpointCallback.add_callback_specific_args(parser)
+>>> callbacks.TransformersModelCheckpointCallback.add_argparse_args(parser)
 >>> ...
 >>> hyperparameters = parser.parse_args()
 ```
diff --git a/transformers_lightning/callbacks/transformers_model_checkpoint.py b/transformers_lightning/callbacks/transformers_model_checkpoint.py
index 5eea2ac..bf1009c 100644
--- a/transformers_lightning/callbacks/transformers_model_checkpoint.py
+++ b/transformers_lightning/callbacks/transformers_model_checkpoint.py
@@ -2,7 +2,7 @@
 import shutil
 from argparse import ArgumentParser
 
-from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 from transformers_lightning.utils import dump_json, is_simple
@@ -157,7 +157,7 @@ def on_validation_end(self, trainer, pl_module):
         self.save_model(pl_module, epoch=trainer.current_epoch, step=trainer.global_step)
 
     @staticmethod
-    def add_callback_specific_args(parser: ArgumentParser):
+    def add_argparse_args(parser: ArgumentParser):
         r""" Add callback_specific arguments to parser. """
""" parser.add_argument( '--checkpoint_interval', diff --git a/transformers_lightning/datamodules/super_datamodule.py b/transformers_lightning/datamodules/super_datamodule.py index 650147e..47bf331 100644 --- a/transformers_lightning/datamodules/super_datamodule.py +++ b/transformers_lightning/datamodules/super_datamodule.py @@ -118,7 +118,7 @@ def predict_dataloader(self): return None @staticmethod - def add_datamodule_specific_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): parser.add_argument( '--num_workers', required=False, diff --git a/transformers_lightning/defaults/__init__.py b/transformers_lightning/defaults/__init__.py index 11f6e8c..6bc122b 100644 --- a/transformers_lightning/defaults/__init__.py +++ b/transformers_lightning/defaults/__init__.py @@ -9,7 +9,7 @@ class DefaultConfig: """ @staticmethod - def add_defaults_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): parser.add_argument( '--output_dir', type=str, diff --git a/transformers_lightning/info.py b/transformers_lightning/info.py index a2fb4f7..491ef41 100644 --- a/transformers_lightning/info.py +++ b/transformers_lightning/info.py @@ -1,7 +1,7 @@ -__version__ = '0.7.11' +__version__ = '0.7.12' __author__ = 'Luca Di Liello and Matteo Gabburo' __author_email__ = 'luca.diliello@unitn.it' __license__ = 'GNU GENERAL PUBLIC LICENSE v2' -__copyright__ = f'Copyright (c) 2020-2021, {__author__}.' +__copyright__ = f'Copyright (c) 2020-2022, {__author__}.' __homepage__ = 'https://github.com/iKernels/transformers-lightning' __docs__ = "Utilities to use Transformers with Pytorch-Lightning" diff --git a/transformers_lightning/loggers/jsonboard_logger.py b/transformers_lightning/loggers/jsonboard_logger.py index 870f020..481afa7 100644 --- a/transformers_lightning/loggers/jsonboard_logger.py +++ b/transformers_lightning/loggers/jsonboard_logger.py @@ -7,16 +7,16 @@ import numpy as np import torch -from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.logger import _convert_params, _flatten_dict from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params +from pytorch_lightning.utilities.rank_zero import rank_zero_only logger = logging.getLogger(__name__) -class JsonBoardLogger(LightningLoggerBase): +class JsonBoardLogger(Logger): r""" Log to local file system in `JsonBoard `_ format. @@ -247,7 +247,7 @@ def __setstate__(self, state: Dict[Any, Any]): self.__dict__.update(state) @staticmethod - def add_logger_specific_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): r""" Add callback_specific arguments to parser. """ parser.add_argument( '--jsonboard_dir', type=str, required=False, default='jsonboard', help="Where to save logs." 
diff --git a/transformers_lightning/models/README.md b/transformers_lightning/models/README.md
index c371c13..498427f 100644
--- a/transformers_lightning/models/README.md
+++ b/transformers_lightning/models/README.md
@@ -5,12 +5,12 @@ This package containts two high-level models that can be used to inherit some us
 ## TransfomersModel
 
-`TransformersModel` only overrides `configure_optimizers` by returning a better optimizer and the relative scheduler and finally provides a `add_model_specific_args` to automatically add the parameters of the optimizer to the global parser.
+`TransformersModel` only overrides `configure_optimizers`, returning a better optimizer and the corresponding scheduler, and provides an `add_argparse_args` method to automatically add the optimizer parameters to the global parser.
 
 Example:
 
 ```python
 >>> parser = ArgumentParser()
->>> TransformerModel.add_model_specific_args(parser)
+>>> TransformerModel.add_argparse_args(parser)
 >>> save_transformers_callback = callbacks.TransformersModelCheckpointCallback(hyperparameters)
 >>> hyperparameters = parser.parse_args()
 ```
diff --git a/transformers_lightning/models/transformers_model.py b/transformers_lightning/models/transformers_model.py
index 788d7ca..c08279a 100644
--- a/transformers_lightning/models/transformers_model.py
+++ b/transformers_lightning/models/transformers_model.py
@@ -2,7 +2,6 @@
 from argparse import ArgumentParser, Namespace
 
 from pytorch_lightning import LightningModule
-from pytorch_lightning.strategies import DataParallelStrategy, DDP2Strategy
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from transformers.configuration_utils import PretrainedConfig
@@ -19,9 +18,7 @@
 
 
 class TransformersModel(LightningModule):
-    r"""
-    `TransformersModel` add a ready-to-be-used optimizer and scheduler functions.
-    """
+    r""" `TransformersModel` adds ready-to-be-used optimizer and scheduler functions. """
 
     model: PreTrainedModel
     tokenizer: PreTrainedTokenizerBase
@@ -31,12 +28,10 @@
     def __init__(self, hyperparameters):
         super().__init__()
         self.hyperparameters = hyperparameters
-        self.save_hyperparameters()
+        self.save_hyperparameters(hyperparameters)
 
     def forward(self, *args, **kwargs):
-        r"""
-        Simply call the `model` attribute with the given args and kwargs
-        """
+        r""" Simply call the `model` attribute with the given args and kwargs """
         return self.model(*args, **kwargs)
 
     def get_optimizer(self) -> SuperOptimizer:
@@ -50,7 +45,7 @@ def get_scheduler(self, optimizer) -> SuperScheduler:
         return sched_class(self.hyperparameters, optimizer)
 
     def num_training_steps(self) -> int:
-        r""" Total training steps inferred from datasets length, nodes and devices. """
+        r""" Total training steps inferred from datasets length, number of nodes and devices. """
""" if self.trainer.max_steps is not None and self.trainer.max_steps >= 0: return self.trainer.max_steps @@ -62,11 +57,7 @@ def num_training_steps(self) -> int: train_samples = len(self.trainer.datamodule.train_dataset) # number of training devices - if isinstance(self.trainer.strategy, (DataParallelStrategy, DDP2Strategy)): - total_devices = self.trainer.num_nodes - else: - total_devices = self.trainer.num_devices * self.trainer.num_nodes - + total_devices = self.trainer.num_devices * self.trainer.num_nodes rank_zero_warn(f"Number of training devices is {total_devices}") # the number of training samples may be modified in distributed training @@ -108,14 +99,14 @@ def configure_optimizers(self): 'optimizer': optimizer, 'lr_scheduler': { - 'scheduler': scheduler, # The LR schduler - 'interval': self.hyperparameters.scheduler_interval, # The unit of the scheduler's step size - 'frequency': self.hyperparameters.scheduler_frequency, # The frequency of the scheduler + 'scheduler': scheduler, # The LR schduler + 'interval': self.hyperparameters.scheduler_interval, # The unit of the scheduler's step size + 'frequency': self.hyperparameters.scheduler_frequency, # The frequency of the scheduler } } @staticmethod - def add_model_specific_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): parser.add_argument('--optimizer_class', type=str, default='AdamWOptimizer', choices=all_optimizers.keys()) parser.add_argument( '--scheduler_class', type=str, default='LinearSchedulerWithWarmup', choices=all_schedulers.keys() @@ -131,5 +122,5 @@ def add_model_specific_args(parser: ArgumentParser): sched_class = all_schedulers[tmp_params.scheduler_class] # add optimizer and scheduler specific args - optim_class.add_optimizer_specific_args(parser) - sched_class.add_scheduler_specific_args(parser) + optim_class.add_argparse_args(parser) + sched_class.add_argparse_args(parser) diff --git a/transformers_lightning/optimizers/adamw.py b/transformers_lightning/optimizers/adamw.py index 4850ee9..8e1c8b2 100644 --- a/transformers_lightning/optimizers/adamw.py +++ b/transformers_lightning/optimizers/adamw.py @@ -22,8 +22,8 @@ def __init__(self, hyperparameters: Namespace, named_parameters: Generator): betas=hyperparameters.adam_betas ) - def add_optimizer_specific_args(parser: ArgumentParser): - super(AdamWOptimizer, AdamWOptimizer).add_optimizer_specific_args(parser) + def add_argparse_args(parser: ArgumentParser): + super(AdamWOptimizer, AdamWOptimizer).add_argparse_args(parser) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--weight_decay', type=float, default=0.0) parser.add_argument('--adam_epsilon', type=float, default=1e-8) diff --git a/transformers_lightning/optimizers/adamw_electra.py b/transformers_lightning/optimizers/adamw_electra.py index beca21f..45d7835 100644 --- a/transformers_lightning/optimizers/adamw_electra.py +++ b/transformers_lightning/optimizers/adamw_electra.py @@ -103,8 +103,8 @@ def __init__(self, hyperparameters: Namespace, named_parameters: Generator): amsgrad=hyperparameters.amsgrad ) - def add_optimizer_specific_args(parser: ArgumentParser): - super(ElectraAdamWOptimizer, ElectraAdamWOptimizer).add_optimizer_specific_args(parser) + def add_argparse_args(parser: ArgumentParser): + super(ElectraAdamWOptimizer, ElectraAdamWOptimizer).add_argparse_args(parser) parser.add_argument('--learning_rate', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--adam_epsilon', 
         parser.add_argument('--adam_epsilon', type=float, default=1e-6)
diff --git a/transformers_lightning/optimizers/super_optimizer.py b/transformers_lightning/optimizers/super_optimizer.py
index 0ee89d8..9b1a42d 100644
--- a/transformers_lightning/optimizers/super_optimizer.py
+++ b/transformers_lightning/optimizers/super_optimizer.py
@@ -12,5 +12,5 @@ def __init__(self, hyperparameters: Namespace, *args, **kwargs) -> None:
         self.hyperparameters = hyperparameters
 
     @staticmethod
-    def add_optimizer_specific_args(parser: ArgumentParser):
+    def add_argparse_args(parser: ArgumentParser):
         r""" Add here the hyperparameters used by your optimizer. """
diff --git a/transformers_lightning/schedulers/constant_scheduler_with_warmup.py b/transformers_lightning/schedulers/constant_scheduler_with_warmup.py
index fdb6f13..baee0d3 100644
--- a/transformers_lightning/schedulers/constant_scheduler_with_warmup.py
+++ b/transformers_lightning/schedulers/constant_scheduler_with_warmup.py
@@ -1,7 +1,7 @@
 from argparse import ArgumentParser, Namespace
 
 import torch
-from pytorch_lightning.utilities.warnings import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 from transformers_lightning.schedulers.super_scheduler import SuperScheduler
 
@@ -50,7 +50,7 @@ def get_lr(self):
         return [base_lr * self.lr_lambda(self.last_epoch) for base_lr in self.base_lrs]
 
     @staticmethod
-    def add_scheduler_specific_args(parser: ArgumentParser):
+    def add_argparse_args(parser: ArgumentParser):
         r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """
-        super(ConstantSchedulerWithWarmup, ConstantSchedulerWithWarmup).add_scheduler_specific_args(parser)
+        super(ConstantSchedulerWithWarmup, ConstantSchedulerWithWarmup).add_argparse_args(parser)
         parser.add_argument('--num_warmup_steps', type=int, default=0)
diff --git a/transformers_lightning/schedulers/cosine_scheduler_with_warmup.py b/transformers_lightning/schedulers/cosine_scheduler_with_warmup.py
index 1256bdd..61dc146 100644
--- a/transformers_lightning/schedulers/cosine_scheduler_with_warmup.py
+++ b/transformers_lightning/schedulers/cosine_scheduler_with_warmup.py
@@ -55,8 +55,8 @@ def get_lr(self):
         return [base_lr * self.lr_lambda(self.last_epoch) for base_lr in self.base_lrs]
 
     @staticmethod
-    def add_scheduler_specific_args(parser: ArgumentParser):
+    def add_argparse_args(parser: ArgumentParser):
         r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """
-        super(CosineSchedulerWithWarmup, CosineSchedulerWithWarmup).add_scheduler_specific_args(parser)
+        super(CosineSchedulerWithWarmup, CosineSchedulerWithWarmup).add_argparse_args(parser)
         parser.add_argument('--num_warmup_steps', type=int, default=0)
         parser.add_argument('--num_cycles', type=float, default=1.0)
diff --git a/transformers_lightning/schedulers/cosine_scheduler_with_warmup_and_hard_restart.py b/transformers_lightning/schedulers/cosine_scheduler_with_warmup_and_hard_restart.py
index aa6ff5b..e48454f 100644
--- a/transformers_lightning/schedulers/cosine_scheduler_with_warmup_and_hard_restart.py
+++ b/transformers_lightning/schedulers/cosine_scheduler_with_warmup_and_hard_restart.py
@@ -57,10 +57,10 @@ def get_lr(self):
         return [base_lr * self.lr_lambda(self.last_epoch) for base_lr in self.base_lrs]
 
     @staticmethod
-    def add_scheduler_specific_args(parser: ArgumentParser):
+    def add_argparse_args(parser: ArgumentParser):
         r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """
""" super( CosineSchedulerWithWarmupAndHardRestart, CosineSchedulerWithWarmupAndHardRestart - ).add_scheduler_specific_args(parser) + ).add_argparse_args(parser) parser.add_argument('--num_warmup_steps', type=int, default=0) parser.add_argument('--num_cycles', type=float, default=1.0) diff --git a/transformers_lightning/schedulers/layerwise_decay_scheduler.py b/transformers_lightning/schedulers/layerwise_decay_scheduler.py index 3e21ebc..defa949 100644 --- a/transformers_lightning/schedulers/layerwise_decay_scheduler.py +++ b/transformers_lightning/schedulers/layerwise_decay_scheduler.py @@ -122,11 +122,11 @@ def get_lr(self): return lrs @staticmethod - def add_scheduler_specific_args(parser: ArgumentError): + def add_argparse_args(parser: ArgumentError): r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """ super( PolynomialLayerwiseDecaySchedulerWithWarmup, PolynomialLayerwiseDecaySchedulerWithWarmup - ).add_scheduler_specific_args(parser) + ).add_argparse_args(parser) parser.add_argument('--num_warmup_steps', type=int, default=0) parser.add_argument('--end_learning_rate', type=float, default=0.0001) parser.add_argument('--lr_decay_power', type=float, default=1.0) diff --git a/transformers_lightning/schedulers/linear_scheduler_with_warmup.py b/transformers_lightning/schedulers/linear_scheduler_with_warmup.py index 1c6d409..0f8e775 100644 --- a/transformers_lightning/schedulers/linear_scheduler_with_warmup.py +++ b/transformers_lightning/schedulers/linear_scheduler_with_warmup.py @@ -57,7 +57,7 @@ def get_lr(self): return [base_lr * self.lr_lambda(self.last_epoch) for base_lr in self.base_lrs] @staticmethod - def add_scheduler_specific_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """ - super(LinearSchedulerWithWarmup, LinearSchedulerWithWarmup).add_scheduler_specific_args(parser) + super(LinearSchedulerWithWarmup, LinearSchedulerWithWarmup).add_argparse_args(parser) parser.add_argument('--num_warmup_steps', type=int, default=0) diff --git a/transformers_lightning/schedulers/super_scheduler.py b/transformers_lightning/schedulers/super_scheduler.py index f1b8a2f..d236fc6 100644 --- a/transformers_lightning/schedulers/super_scheduler.py +++ b/transformers_lightning/schedulers/super_scheduler.py @@ -21,7 +21,7 @@ def num_training_steps(self): raise ValueError(f'scheduler {self.__class__.__name__} needs `max_steps` to be defined') @staticmethod - def add_scheduler_specific_args(parser: ArgumentParser): + def add_argparse_args(parser: ArgumentParser): r""" Add here the hyperparameters specific of the scheduler like the number of warmup steps. """ parser.add_argument('--scheduler_last_epoch', type=int, default=-1) parser.add_argument('--scheduler_verbose', action='store_true')