From 6d3ac156ae2a1ea9c9203002af41900267d61b40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Dudziak?= Date: Wed, 21 Apr 2021 19:50:27 +0100 Subject: [PATCH] First release of the code --- .gitignore | 10 + README.md | 200 ++++++- nasbench_asr/__init__.py | 77 +++ nasbench_asr/dataset.py | 532 ++++++++++++++++++ nasbench_asr/graph_utils.py | 383 +++++++++++++ nasbench_asr/model/__init__.py | 24 + nasbench_asr/model/tf/__init__.py | 51 ++ .../model/tf/mean_variance_normalization.py | 47 ++ nasbench_asr/model/tf/model.py | 96 ++++ nasbench_asr/model/tf/ops.py | 91 +++ nasbench_asr/model/torch/__init__.py | 47 ++ nasbench_asr/model/torch/model.py | 135 +++++ nasbench_asr/model/torch/ops.py | 84 +++ nasbench_asr/quiet_tensorflow.py | 35 ++ nasbench_asr/search_space.py | 93 +++ nasbench_asr/training/__init__.py | 90 +++ nasbench_asr/training/tf/__init__.py | 143 +++++ .../training/tf/callbacks/__init__.py | 0 .../training/tf/callbacks/lrscheduler.py | 63 +++ .../training/tf/callbacks/reset_states.py | 16 + .../training/tf/callbacks/tensorboard.py | 28 + .../training/tf/datasets/audio_feature.py | 475 ++++++++++++++++ .../training/tf/datasets/audio_featurizer.py | 69 +++ .../tf/datasets/audio_sentence_timit.py | 116 ++++ .../tf/datasets/cache_shard_shuffle_batch.py | 94 ++++ .../training/tf/datasets/phoneme_encoder.py | 40 ++ .../training/tf/datasets/preprocess.py | 63 +++ .../training/tf/datasets/text_encoder.py | 66 +++ .../training/tf/datasets/timit_foldings.py | 95 ++++ nasbench_asr/training/tf/main.py | 172 ++++++ nasbench_asr/training/tf/metrics/__init__.py | 0 nasbench_asr/training/tf/metrics/ctc.py | 105 ++++ nasbench_asr/training/tf/metrics/ler.py | 34 ++ nasbench_asr/training/tf/metrics/ratio.py | 41 ++ nasbench_asr/training/tf/metrics/roll.py | 6 + nasbench_asr/training/tf/metrics/wer.py | 41 ++ nasbench_asr/training/tf/trainer.py | 521 +++++++++++++++++ nasbench_asr/training/timit_folding.txt | 61 ++ nasbench_asr/training/timit_train_stats.npz | Bin 0 -> 1178 bytes nasbench_asr/training/torch/__init__.py | 24 + nasbench_asr/training/torch/encoder.py | 84 +++ nasbench_asr/training/torch/timit.py | 131 +++++ nasbench_asr/training/torch/trainer.py | 268 +++++++++ nasbench_asr/utils.py | 175 ++++++ nasbench_asr/version.py | 36 ++ setup.py | 60 ++ train.py | 58 ++ 47 files changed, 5079 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 nasbench_asr/__init__.py create mode 100644 nasbench_asr/dataset.py create mode 100644 nasbench_asr/graph_utils.py create mode 100644 nasbench_asr/model/__init__.py create mode 100644 nasbench_asr/model/tf/__init__.py create mode 100644 nasbench_asr/model/tf/mean_variance_normalization.py create mode 100644 nasbench_asr/model/tf/model.py create mode 100644 nasbench_asr/model/tf/ops.py create mode 100644 nasbench_asr/model/torch/__init__.py create mode 100644 nasbench_asr/model/torch/model.py create mode 100644 nasbench_asr/model/torch/ops.py create mode 100644 nasbench_asr/quiet_tensorflow.py create mode 100644 nasbench_asr/search_space.py create mode 100644 nasbench_asr/training/__init__.py create mode 100644 nasbench_asr/training/tf/__init__.py create mode 100644 nasbench_asr/training/tf/callbacks/__init__.py create mode 100644 nasbench_asr/training/tf/callbacks/lrscheduler.py create mode 100644 nasbench_asr/training/tf/callbacks/reset_states.py create mode 100644 nasbench_asr/training/tf/callbacks/tensorboard.py create mode 100644 nasbench_asr/training/tf/datasets/audio_feature.py create mode 100644 
nasbench_asr/training/tf/datasets/audio_featurizer.py create mode 100644 nasbench_asr/training/tf/datasets/audio_sentence_timit.py create mode 100644 nasbench_asr/training/tf/datasets/cache_shard_shuffle_batch.py create mode 100644 nasbench_asr/training/tf/datasets/phoneme_encoder.py create mode 100644 nasbench_asr/training/tf/datasets/preprocess.py create mode 100644 nasbench_asr/training/tf/datasets/text_encoder.py create mode 100644 nasbench_asr/training/tf/datasets/timit_foldings.py create mode 100644 nasbench_asr/training/tf/main.py create mode 100644 nasbench_asr/training/tf/metrics/__init__.py create mode 100644 nasbench_asr/training/tf/metrics/ctc.py create mode 100644 nasbench_asr/training/tf/metrics/ler.py create mode 100644 nasbench_asr/training/tf/metrics/ratio.py create mode 100644 nasbench_asr/training/tf/metrics/roll.py create mode 100644 nasbench_asr/training/tf/metrics/wer.py create mode 100644 nasbench_asr/training/tf/trainer.py create mode 100644 nasbench_asr/training/timit_folding.txt create mode 100644 nasbench_asr/training/timit_train_stats.npz create mode 100644 nasbench_asr/training/torch/__init__.py create mode 100644 nasbench_asr/training/torch/encoder.py create mode 100644 nasbench_asr/training/torch/timit.py create mode 100644 nasbench_asr/training/torch/trainer.py create mode 100644 nasbench_asr/utils.py create mode 100644 nasbench_asr/version.py create mode 100644 setup.py create mode 100644 train.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d8bac4a --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__ +graphs/ +*.pyc +autocaml/exps +autocaml/data +autocaml/plots +*.egg-info +tools/plots +results/ +.vscode diff --git a/README.md b/README.md index f6d236e..e032e04 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,205 @@ -# NAS-Bench-ASR +# NASBench-ASR + Code for the "NAS-Bench-ASR: Reproducible Neural Architecture Search for Speech Recognition" paper published at ICLR'21: https://openreview.net/forum?id=CU0APx9LMaL +Content: + * [Installing](#Installing) + * [Dataset format](#Dataset-format) + * [Using the dataset](#using-the-dataset) + * [Creating and training models](#Creating-and-training-models) + +## Installing + +You can type `pip install ` to install our tool (optionally with the `-e` argument for an in-place installation) and its dependencies. +We recommend using pipenv (or some other virtual environment management tool) to avoid problems with TF/PyTorch versions. + +> **Note:** please let us know if you run into any problems related to missing dependencies. + + +## Dataset format + +We split our dataset into multiple pickle files containing information about the models from the search space in different settings. +Each dataset contains two parts: a header and the actual data. Those are serialized with pickle one after another, so when reading them a similar sequence of calls +needs to be used: +```python +with open('nb-asr-e40-1234.pickle', 'rb') as f: + header = pickle.load(f) + data = pickle.load(f) +``` + +The header contains the usual meta-information about the dataset file, including things like: the search space used when generating the dataset, column names, dataset type and version. +The following chunk of data is a Python list with raw values; the order of values follows the order of column names in the header.
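+
+As a quick illustration (a minimal sketch, assuming each raw row follows the column order described above), a single entry can be mapped back to named values like this:
+```python
+import pickle
+
+with open('nb-asr-e40-1234.pickle', 'rb') as f:
+    header = pickle.load(f)  # metadata, e.g. 'columns', 'search_space', 'dataset_type', 'version'
+    data = pickle.load(f)    # list of rows; values ordered as in header['columns']
+
+# Pair the first row with the column names (for a training dataset the first
+# columns are 'model_hash', 'val_per' and 'test_per').
+row = dict(zip(header['columns'], data[0]))
+print(row['model_hash'])
+```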
+ +We distinguish three main types of datasets: + * training datasets, which contain information about accuracy of models during/after training + * benchmarking datasets, which contain information about on-device performance of models + * static information datasets, which constain static information about models, such as number of parameters + +Please see the following section to see a brief summary of how to use the dataset conveniently. + +## Using the dataset + +All relevant files can be downaloded from the `releases` page in this repo. + +We provide a `Dataset` class as the top-level interface to querying the datasets (although the user is free to read the pickle files on their own). +The `Dataset` class is primarily used to deal with training datasets but it has an option to also piggy-back benchmarking and static datasets for convenience. +If needed the user can also use `BenchmarkingDataset` and `StaticInfoDataset` classes to access only benchmarking and/or static information about models without loading the data realted to training. + +Assuming all NB-ASR files are in the same directory and no special behaviour is needed, the user can also use a high-level `nasbench_asr.from_folder` function which automatically searches for files in the given directory and creates a `Dataset` object from them. + +The rest of the section presents some typical use cases. + +Creating the dataset: +```python +>>> import nasbench_asr as nbasr +>>> d = nbasr.from_folder('~/data/nasbench_asr/', include_static_info=True) +``` + +Querying all information, returned as dict or list: +```python +>>> d.full_info([[1,0], [1,0,0], [1,0,0,0]], include_static_info=True) +{'val_per': [0.47851497, 0.32516438, 0.27674836, 0.25589427, 0.24639702, 0.23125456, 0.22919573, 0.228598, 0.22308561, 0.21856944, 0.22109318, 0.2183702, 0.21451816, 0.21498306, 0.21458457, 0.21239291, 0.21431892, 0.21418609, 0.21584645, 0.21584645, 0.21578003, 0.21704191, 0.21664342, 0.21843661, 0.2188351, 0.22003055, 0.22109318, 0.22149166, 0.23816165, 0.23643488, 0.22886366, 0.22082752, 0.2207611, 0.22142525, 0.22169091, 0.22056186, 0.22149166, 0.22182374, 0.22142525, 0.22202298], 'test_per': 0.242688849568367, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1235, 'gtx-1080ti-fp32': {'latency': 0.04320073127746582}, 'jetson-nano-fp32': {'latency': 0.5421140193939209}, 'info': {'params': 26338848}} +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], include_static_info=True, return_dict=False) +['36855332a5778e0df5114305bc3ce238', [0.4840938, 0.31912068, 0.27867436, 0.25908217, 0.24433818, 0.23291492, 0.22713688, 0.22979344, 0.22288637, 0.22036262, 0.22056186, 0.21637776, 0.21823737, 0.21637776, 0.21272498, 0.21245933, 0.21318989, 0.21458457, 0.21591286, 0.2169755, 0.21797171, 0.21863586, 0.22036262, 0.22129242, 0.22129242, 0.2216245, 0.23152022, 0.24480309, 0.23450887, 0.22554293, 0.22268713, 0.221226, 0.22175732, 0.2216245, 0.22202298, 0.22182374, 0.22149166, 0.22222222, 0.22242147, 0.22228864], 0.23728343844413757, [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 1236, [0.04320073127746582], [0.5421140193939209], [26338848]] +``` + +Removing static information: +```python +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], include_static_info=False) +{'val_per': [0.47851497, 0.32516438, 0.27674836, 0.25589427, 0.24639702, 0.23125456, 0.22919573, 0.228598, 0.22308561, 0.21856944, 0.22109318, 0.2183702, 0.21451816, 0.21498306, 0.21458457, 0.21239291, 0.21431892, 0.21418609, 0.21584645, 0.21584645, 0.21578003, 
0.21704191, 0.21664342, 0.21843661, 0.2188351, 0.22003055, 0.22109318, 0.22149166, 0.23816165, 0.23643488, 0.22886366, 0.22082752, 0.2207611, 0.22142525, 0.22169091, 0.22056186, 0.22149166, 0.22182374, 0.22142525, 0.22202298], 'test_per': 0.242688849568367, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1235, 'gtx-1080ti-fp32': {'latency': 0.04320073127746582}, 'jetson-nano-fp32': {'latency': 0.5421140193939209}} +``` + +Asking for a particular device performance only: +```python +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], devices='jetson-nano-fp32') +{'val_per': [0.4840938, 0.31912068, 0.27867436, 0.25908217, 0.24433818, 0.23291492, 0.22713688, 0.22979344, 0.22288637, 0.22036262, 0.22056186, 0.21637776, 0.21823737, 0.21637776, 0.21272498, 0.21245933, 0.21318989, 0.21458457, 0.21591286, 0.2169755, 0.21797171, 0.21863586, 0.22036262, 0.22129242, 0.22129242, 0.2216245, 0.23152022, 0.24480309, 0.23450887, 0.22554293, 0.22268713, 0.221226, 0.22175732, 0.2216245, 0.22202298, 0.22182374, 0.22149166, 0.22222222, 0.22242147, 0.22228864], 'test_per': 0.23728343844413757, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1236, 'jetson-nano-fp32': {'latency': 0.5421140193939209}, 'info': {'params': 26338848}} +``` + +Do not include any benchmarking results: +```python +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], devices=False, include_static_info=False) +{'val_per': [0.48522282, 0.32031614, 0.28338978, 0.25430033, 0.24128312, 0.23942353, 0.22547652, 0.22733612, 0.22527727, 0.22109318, 0.21670984, 0.21929999, 0.21551438, 0.21458457, 0.21226008, 0.21305706, 0.2137876, 0.21352194, 0.2127914, 0.21491665, 0.21597928, 0.21777247, 0.21996413, 0.2249452, 0.2412167, 0.23484094, 0.23152022, 0.22281995, 0.21890152, 0.21870227, 0.21896791, 0.21896791, 0.21810454, 0.21863586, 0.21923357, 0.21896791, 0.21923357, 0.2198313, 0.21996413, 0.22056186], 'test_per': 0.23395703732967377, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1234} +``` + +Querying test accuracy: +```python +>>> d.test_acc([[1,0], [1,0,0], [1,0,0,0]]) +0.242688849568367 +``` + +Querying validation accuracy: +```python +>>> d.val_acc([[1,0], [1,0,0], [1,0,0,0]]) +0.21245933 +>>> d.val_acc([[1,0], [1,0,0], [1,0,0,0]], best=False) +0.22056186 +>>> d.val_acc([[1,0], [1,0,0], [1,0,0,0]], epoch=8) +0.22547652 +>>> d.val_acc([[1,0], [1,0,0], [1,0,0,0]], epoch=8, best=False) +0.22979344 +``` + +Querying latency alone: +```python +>>> d.latency([[1,0], [1,0,0], [1,0,0,0]], devices='gtx-1080ti-fp32') +[[0.04320073127746582]] +>>> d.latency([[1,0], [1,0,0], [1,0,0,0]], devices='gtx-1080ti-fp32', return_dict=True) +{'gtx-1080ti-fp32': {'latency': 0.04320073127746582}} +>>> d.bench_info.latency([[1,0], [1,0,0], [1,0,0,0]], devices='gtx-1080ti-fp32', return_dict=True) +{'gtx-1080ti-fp32': {'latency': 0.04320073127746582}} +``` + +Asking for missing information will result in an error: +```python +>>> d = nbasr.from_folder('~/data/nasbench_asr/', include_static_info=False) +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], include_static_info=True) +Traceback (most recent call last): + File "", line 1, in + File "/home/SERILOCAL/l.dudziak/dev/asr/pytorch-asr/nasbench_asr/dataset.py", line 345, in full_info + return self._query(model_hash, seed, devices, include_static_info, return_dict) + File "/home/SERILOCAL/l.dudziak/dev/asr/pytorch-asr/nasbench_asr/dataset.py", line 304, in _query + raise 
ValueError('No static information attached') +ValueError: No static information attached +``` + +Default values will always include data only if available: +```python +>>> d = nbasr.from_folder('~/data/nasbench_asr/', max_epochs=5, devices=False) +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]]) +{'val_per': [0.4846915, 0.3614266, 0.32323837, 0.31241283, 0.3053065], 'test_per': 0.3227997124195099, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1234} +>>> d = nbasr.from_folder('~/data/nasbench_asr/', max_epochs=5, devices=None) +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]]) +{'val_per': [0.4846915, 0.3614266, 0.32323837, 0.31241283, 0.3053065], 'test_per': 0.3227997124195099, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1234, 'gtx-1080ti-fp32': {'latency': 0.04320073127746582}, 'jetson-nano-fp32': {'latency': 0.5421140193939209}} +``` + +`nasbench_asr.dataset.from_folder` silently fails to include requested data if it doesn't exist, if this is undesired please consider using `Dataset` directly. + +```python +>>> d = nbasr.from_folder('~/data/nasbench_asr/', max_epochs=5, devices='non-existing-device') +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]]) +{'val_per': [0.4846915, 0.3614266, 0.32323837, 0.31241283, 0.3053065], 'test_per': 0.3227997124195099, 'arch_vec': [(1, 0), (1, 0, 0), (1, 0, 0, 0)], 'model_hash': '36855332a5778e0df5114305bc3ce238', 'seed': 1234} +>>> d.full_info([[1,0], [1,0,0], [1, 0,0,0]], devices='non-existing-device') +Traceback (most recent call last): + File "", line 1, in + File "/home/SERILOCAL/l.dudziak/dev/asr/pytorch-asr/nasbench_asr/dataset.py", line 346, in full_info + return self._query(model_hash, seed, devices, include_static_info, return_dict) + File "/home/SERILOCAL/l.dudziak/dev/asr/pytorch-asr/nasbench_asr/dataset.py", line 295, in _query + raise ValueError('No benchmarking information attached') +ValueError: No benchmarking information attached +``` + +## Creating and training models +Alongside the dataset, we provide code to create models and training environment to train them in a reproducible way. +We anticipate that this might be especially useful for people working on differentiable NAS. + +We support two training backends - Tensorflow or PyTorch. **However, please bear in mind that only the TF backend is meant to reproduce results from the dataset as we were using TF during our experiments. +PyTorch implementation is provided as a courtesy and comes with no guarantees about achievable results** + +The following is a short summary of top-level function exposed by the package, most of those functions take an extra argument called `backend` which can be used to explicitly specify which implementation should be called. +If unused, the functions will try to use the default backend - which can be either set explicitly or it can be deduced. +The default backend is deduced based on available packages and prefers to use TF, falling back to PyTorch if TF is unavailable. + * `set_default_backend(name)` sets the default backend + * `get_backend_name()` gets the name of the backend in use (unless overwritten by a function-specific argument) + * `set_seed(seed)` sets random seed(s) to the specific value + * `prepare_devices(devices)` prepared the specified GPUs for training (only relevant for TF backend, this e.g. 
turns on dynamic memory growth) + * `get_model(arch)` returns an implementation of a model with the specified architecture (`arch` should come from the search space, see e.g. `nasbench_asr.search_space.get_all_architectures`) + * `get_dataloaders(timit_root, batch_size)` returns a tuple of 5 values, in order: encoder object (used to encode phonemes), iterable yielding training examples, iterable yielding validation examples, iterable yielding testing examples, backend-specific data + * `get_loss()` returns a callable objective, the signature is `(logits, logits_size, targets, targets_size)` + * `get_trainer(dataloaders, gpus, save_dir, verbose)` returns a trainer class which can be used to train models. + + For more information on how to use those functions, please see for example `train.py`, which uses them to run model trainings. + You can also take a look at the `Trainer` abstract class defined in `nasbench_asr/training/__init__.py`. + Briefly speaking, a sequence of calls like the following should do the trick: + + ```python + set_default_backend(args.backend) + set_seed(args.seed) + prepare_devices(args.gpus) + + print(f'Using backend: {get_backend_name()}') + print(f' Model vec: {args.model}') + print(f' Training for {args.epochs} epochs') + print(f' Batch size: {args.batch_size}') + print(f' Learning rate: {args.lr}') + print(f' Dropout: {args.dropout}') + print(f' GPUs: {args.gpus}') + + results_folder = pathlib.Path(args.exp_folder) / args.backend + + first_gpu = None + if args.gpus: + first_gpu = args.gpus[0] + + dataloaders = get_dataloaders(args.data, batch_size=args.batch_size) + model = get_model(args.model, use_rnn=args.rnn, dropout_rate=args.dropout, gpu=first_gpu) + trainer = get_trainer(dataloaders, gpus=args.gpus, save_dir=results_folder, verbose=True) + trainer.train(model, epochs=args.epochs, lr=args.lr, reset=args.reset, model_name=args.exp_name) + ``` + + ## Citation + +Please consider citing NB-ASR if you find our work useful! ``` @inproceedings{ mehrotra2021nasbenchasr, diff --git a/nasbench_asr/__init__.py b/nasbench_asr/__init__.py new file mode 100644 index 0000000..7e7cc2f --- /dev/null +++ b/nasbench_asr/__init__.py @@ -0,0 +1,77 @@ +import functools + +from . import dataset +from . import model +from .
import training + +from .utils import add_module_properties +from .utils import staticproperty + + +Dataset = dataset.Dataset +BenchmarkingDataset = dataset.BenchmarkingDataset +StaticInfoDataset = dataset.StaticInfoDataset + +from_folder = dataset.from_folder + + +def set_default_backend(backend): + from .model import set_default_backend as impl1 + from .training import set_default_backend as impl2 + name1 = impl1(backend) + name2 = impl2(backend) + return name1, name2 + +def get_backend_name(): + from .model import get_backend_name as impl1 + from .training import get_backend_name as impl2 + return impl1(), impl2() + +@functools.wraps(training.set_seed) +def set_seed(*args, **kwargs): + return training.set_seed(*args, **kwargs) + +@functools.wraps(training.prepare_devices) +def prepare_devices(*args, **kwargs): + return training.prepare_devices(*args, **kwargs) + +@functools.wraps(model.get_model) +def get_model(*args, **kwargs): + return model.get_model(*args, **kwargs) + +@functools.wraps(training.get_dataloaders) +def get_dataloaders(*args, **kwargs): + return training.get_dataloaders(*args, **kwargs) + +@functools.wraps(training.get_loss) +def get_loss(*args, **kwargs): + return training.get_loss(*args, **kwargs) + +@functools.wraps(training.get_trainer) +def get_trainer(*args, **kwargs): + return training.get_trainer(*args, **kwargs) + + +def _get_version(): + from . import version + return version.version + +def _get_has_repo(): + from . import version + return version.has_repo + +def _get_repo(): + from . import version + return version.repo + +def _get_commit(): + from . import version + return version.commit + + +add_module_properties(__name__, { + '__version__': staticproperty(staticmethod(_get_version)), + '__has_repo__': staticproperty(staticmethod(_get_has_repo)), + '__repo__': staticproperty(staticmethod(_get_repo)), + '__commit__': staticproperty(staticmethod(_get_commit)) +}) diff --git a/nasbench_asr/dataset.py b/nasbench_asr/dataset.py new file mode 100644 index 0000000..bb119b3 --- /dev/null +++ b/nasbench_asr/dataset.py @@ -0,0 +1,532 @@ +from multiprocessing import Value +import re +import pickle +import random +import pathlib +import functools +import collections.abc as cabc + +from . import graph_utils +from . 
import search_space + + +class _Dataset(): + def __init__(self, dataset_files, validate_data, db_type): + if isinstance(dataset_files, str): + dataset_files = [dataset_files] + + self.dbs = [] + self.header = None + if db_type == 'training': + self.seeds = [] + elif db_type == 'benchmarking': + self.devices = [] + elif db_type == 'static': + if len(dataset_files) != 1: + raise ValueError('Expected exactly one dataste file') + + for db_file in dataset_files: + with open(db_file, 'rb') as f: + header = pickle.load(f) + if header['dataset_type'] != db_type: + raise ValueError(f'Expected a dataset file with {db_type} information') + + if db_type == 'training': + seed = header.pop('seed') + elif db_type == 'benchmarking': + device = header.pop('device') + + if self.header is None: + self.header = header + if self.header != header: + raise ValueError('Different dataset files contain data for different settings') + + # TODO: we could relax this if needed + if db_type == 'training': + if header['columns'][:3] != ['model_hash', 'val_per', 'test_per']: + raise ValueError('In the current implementation we expect the dataset to contain information in order: model hash, val PER, test PER') + elif db_type == 'benchmarking': + if header['columns'][:2] != ['model_hash', 'latency']: + raise ValueError('In the current implementation we expect the dataset to contain information in order: model hash, latency') + elif db_type == 'static': + if header['columns'][:2] != ['model_hash', 'params']: + print(header['columns']) + raise ValueError('In the current implementation we expect the dataset to contain information in order: model hash, number of parameters') + + if db_type == 'training': + self.seeds.append(seed) + elif db_type == 'benchmarking': + self.devices.append(device) + data = pickle.load(f) + data_dict = { model_hash: rest for model_hash, *rest in data } + self.dbs.append(data_dict) + + if not self.dbs: + raise ValueError('At least one dataset should be read') + + if validate_data and len(self.dbs) > 1: + #if db_type == 'training': + models = { model_hash: model_pt for model_hash, (*_, model_pt) in self.dbs[0].items() } + for fidx, db in enumerate(self.dbs[1:]): + if len(db) != len(models): + raise ValueError(f'Dataset file at position {fidx+1} has {len(db)} entries but the one at position 0 has {len(models)}') + for model_hash, (*_, model_pt) in db.items(): + if model_hash not in models: + raise ValueError(f'{model_hash} is present in dataset file {fidx+1} but no in 0') + if db_type == 'training': + # even if this is not true, the same model hash should guarantee that the architectures are the same + # however, internally we'd expect the points to be the same + assert model_pt == models[model_hash] + + @property + def version(self): + ''' Version of the dataset. + ''' + return self.header['version'] + + @property + def search_space(self): + ''' Search space shape. A (potentially nested) list of integers identifying + different choices and their related number of options. + ''' + return self.header['search_space']['shape'] + + @property + def ops(self): + ''' List of the operations which were considered when creating the dataset. + ''' + return self.header['search_space']['ops'] + + @property + def nodes(self): + ''' Number of nodes which was considered when creating the dataset. + ''' + return self.header['search_space']['nodes'] + + @property + def columns(self): + ''' Names of values stored in the dataset, in-order. 
+ Can be used to identify specific information from values returned by + functions which do not convert their results to dictionaries. + See the remaining API for more information. + ''' + return self.header['columns'] + + def __contains__(self, arch): + h = search_space.get_model_hash(arch, ops=self.ops) + return h in self.dbs[0] + + +class StaticInfoDataset(_Dataset): + def __init__(self, dataset_file): + super().__init__([dataset_file], False, 'static') + + def _get(self, model_hash, return_dict): + r = self.dbs[0].get(model_hash) + if return_dict and r is not None: + return dict(zip(self.columns[1:], r)) + return r + + def params(self, arch): + ''' Return the number of parameters in a specific architecture. + + Arguments: + arch - a point from the search space identifying a model + return_dict - (optional) determinates if the returned values will be provided + as a ``dict`` or a scalar value. A ``dict`` contains the same values as + the ``list`` but allows the user to extract them by their names, whereas + a list can be thought of as a single row in a table containing values only. + The user can map particular elements of the returned ``list`` by considering + the values in provided ``devices`` argument. Default: ``False``. + + Returns: + ``None`` if information about a given ``arch`` cannot be found in the dataset, + otherwise a ``dict`` or a ``list`` containing information about the model. + ''' + model_hash = search_space.get_model_hash(arch, ops=self.ops) + ret = self._get(model_hash, False) + return ret[0] + + +class BenchmarkingDataset(_Dataset): + ''' An object representing a queryable dataset containing benchmarking information + of Nasbench-ASR models. + + The dataset is constructed by reading a set of pickle files containing + information about models benchmarked on different devices. + + All the files used to create a single ``BenchmarkingDataset`` object should contian information + about models coming from the same search space and can only differ by the type of device used. + If you want to compare performance of models from different search spaces you'd need to create + different objects for each case. + ''' + def __init__(self, dataset_files, validate_data=True): + ''' Create a new dataset by loading data from the provided list of files. + + If multiple files are given, they should contain information about models + from the same search space, benchmarked on different devices. + + If ``validate_data`` is set to ``True``, the data from the files will be validated + to check if it's consistent. If the files are known to be ok, the argument can be + set to ``False`` to speed up loading time a little bit (or to hack the code if you know + what you are doing). + ''' + super().__init__(dataset_files, validate_data, 'benchmarking') + + def _get(self, model_hash, devices, ret_dict): + if devices is None: + devices = self.devices + indices = list(range(len(self.devices))) + else: + if isinstance(devices, str): + devices = [devices] + indices = [self.devices.index(d) for d in devices] + + raw = [] if not ret_dict else {} + for didx, device_name in zip(indices, devices): + value = self.dbs[didx].get(model_hash) + if value is None: + return None + if not ret_dict: + raw.append(value) + else: + value = dict(zip(self.columns[1:], value)) + raw[device_name] = value + + return raw + + def latency(self, arch, devices=None, return_dict=False): + ''' Return benchmarking information about a specific architecture on the provided + devices from the dataset. 
+ + Arguments: + arch - a point from the search space identifying a model + device - (optional) if provided, the returned will be information about + the model's performance when run on the device with the given name(s), + otherwise latency on all devices will be returned; accepted values are: + Str, List[Str] and None + return_dict - (optional) determinates if the returned values will be provided + as a ``dict`` or a simple ``list``. A ``dict`` contains the same values as + the ``list`` but allows the user to extract them by their names, whereas + a list can be thought of as a single row in a table containing values only. + The user can map particular elements of the returned ``list`` by considering + the values in provided ``devices`` argument. Default: ``False``. + + Returns: + ``None`` if information about a given ``arch`` cannot be found in the dataset, + otherwise a ``dict`` or a ``list`` containing information about the model. + + Raises: + ValueError - if invalid ``device`` is given + ''' + model_hash = search_space.get_model_hash(arch, ops=self.ops) + return self._get(model_hash, devices, return_dict) + + +class Dataset(_Dataset): + ''' An object representing a queryable NasBench-ASR dataset. + + The dataset is constructed by reading a set of pickle files containing training + information about models using different configurations (different initialization + seed and/or total number of epochs). + + The training information can be optionally extended with benchmarking and static + (e.g. number of parameters) information. + + All the files used to create a single ``Dataset`` object should contian information + about models trained in the same setting and can only differ by the initialization seed. + If you want to compare performance of models in different settings, e.g. using full training + or reduced training of 10 epochs, you'd need to create different objects for each case. + ''' + def __init__(self, dataset_files, devices_files=None, static_info=None, validate_data=True): + ''' Create a new dataset by loading data from the provided list of files. + + If multiple files are given, they should contain information about models + trained in the same setting, differing only by their initialization seed. + + If ``validate_data`` is set to ``True``, the data from the files will be validated + to check if it's consistent. If the files are known to be ok, the argument can be + set to ``False`` to speed up loading time a little bit (or to hack the code if you know + what you are doing). + ''' + super().__init__(dataset_files, validate_data, 'training') + self.bench_info = None + self.static_info = None + if devices_files: + self.bench_info = BenchmarkingDataset(devices_files, validate_data=validate_data) + if static_info: + self.static_info = StaticInfoDataset(static_info) + + @property + def epochs(self): + ''' Total number of epochs for which the models were trained when creating the dataset. 
+ ''' + return self.header['epochs'] + + def _get_raw_info(self, seed_idx, model_hash): + raw = self.dbs[seed_idx].get(model_hash) + if raw is None: + return None + return [model_hash] + list(raw) + [self.seeds[seed_idx]] + + def _get_info_dict(self, seed_idx, model_hash): + raw = self.dbs[seed_idx].get(model_hash) + if raw is not None: + raw = dict(zip(self.columns[1:], raw)) + raw[self.columns[0]] = model_hash + raw['seed'] = self.seeds[seed_idx] + return raw + + def _get_info(self, seed_idx, model_hash, return_dict): + if return_dict: + return self._get_info_dict(seed_idx, model_hash) + else: + return self._get_raw_info(seed_idx, model_hash) + + def _query(self, model_hash, seed, devices, include_static_info, return_dict): + if seed is None: + seed_idx = random.randrange(len(self.seeds)) + else: + seed_idx = self.seeds.index(seed) + + ret = self._get_info(seed_idx, model_hash, return_dict) + if devices != False and (devices is not None or self.bench_info): + if not self.bench_info: + raise ValueError('No benchmarking information attached') + lat = self.bench_info._get(model_hash, devices, return_dict) + if return_dict: + ret.update(lat) + else: + ret.extend(lat) + + if include_static_info: + if not self.static_info: + raise ValueError('No static information attached') + info = self.static_info._get(model_hash, return_dict) + if return_dict: + ret['info'] = info + else: + ret.append(info) + + return ret + + def full_info(self, arch, seed=None, devices=None, include_static_info=None, return_dict=True): + ''' Return all information about a specific architecture from the dataset. + If multiple seeds are available, the can either return information about + a specific one or a random one. + + Arguments: + arch - a point from the search space identifying a model + seed - (optional) if provided, the returned will be information about + the model's performance when initialized with this particular seed, + otherwise information related to a randomly chosen seed from the list + if available ones will be used. Default: random seed + devices - (optional) add information about benchmarking on the provided devices, + if ``None`` all available devices are included, otherwise should be a name of + the device or a list of names, can also be exactly ``False`` to avoid including + benchmarking information even when they are available + include_static_info - (optional) include static information about the model, + such as number of parameters, if set to ``None`` static information will be + added only if available + return_dict - (optional) determinates if the returned values will be provided + as a ``dict`` or a simple ``list``. A ``dict`` contains the same values as + the ``list`` but alolws the user to extract them by their names, whereas + a list can be thought of as a single row in a table containing values only. + The user can map particular elements of the returned ``list`` by considering + the values in ``columns``. Default: ``True``. + + Returns: + ``None`` if information about a given ``arch`` cannot be found in the dataset, + otherwise a ``dict`` or a ``list`` containing information about the model. 
+ + Raises: + ValueError - if invalid ``seed`` is given + ''' + model_hash = search_space.get_model_hash(arch, ops=self.ops) + return self._query(model_hash, seed, devices, include_static_info, return_dict) + + def full_info_by_graph(self, graph, seed=None, devices=None, include_static_info=None, return_dict=True): + ''' Return all information about an architecture identified by the provided model + graph. + If multiple seeds are available, the can either return information about + a specific one or a random one. + + Arguments: + graph - a graph of a model from the search space, obtained by calling + ``nasbench_asr.graph_utils.get_model_graph(arch)`` + seed - (optional) if provided, the returned will be information about + the model's performance when initialized with this particular seed, + otherwise information related to a randomly chosen seed from the list + if available ones will be used. Default: random seed + devices - (optional) add information about benchmarking on the provided devices, + if ``None`` all available devices are included, otherwise should be a name of + the device or a list of names, can also be exactly ``False`` to avoid including + benchmarking information even when they are available + include_static_info - (optional) include static information about the model, + such as number of parameters, if set to ``None`` static information will be + added only if available + return_dict - (optional) determinates if the returned values will be provided + as a ``dict`` or a simple ``list``. A ``dict`` contains the same values as + the ``list`` but allows the user to extract them by their names, whereas + a list can be thought of as a single row in a table containing values only. + The user can map particular elements of the returned ``list`` by considering + the values in ``columns``. Default: ``True``. + + Returns: + ``None`` if information about a given ``arch`` cannot be found in the dataset, + otherwise a ``dict`` or a ``list`` containing information about the model. + + Raises: + ValueError - if invalid ``seed`` is given + ''' + model_hash = graph_utils.graph_hash(graph) + return self._query(model_hash, seed, devices, include_static_info, return_dict) + + def test_acc(self, arch, seed=None): + ''' Return test PER of a model. + + Test PER is currently defined as the test PER of the model at epoch + with the lowest validation PER. + + Arguments: + arch - a point from the search space identifying a model + seed - (optional) an initialization seed to use, if not provided information + will be queried for a random seed (default: ``None``) + + Returns: + ``None`` if the dataset does not contain information about a model ``arch``, + otherwise a scalar ``float``. + ''' + info = self.full_info(arch, seed=seed, devices=False, include_static_info=False, return_dict=False) + if info is None: + return None + return info[2] + + def val_acc(self, arch, epoch=None, best=True, seed=None): + ''' Return validation PER of a model. + + The returned PER can be either the best PER or the PER at the last epoch. + The maximum number of epochs to consider can be controlled by ``epoch``. 
+ + If ``vals`` is a list of validation PERs, the returned value can be + defined as: + + epoch = epoch if epoch is not None else len(vals) + return min(vals[:epoch]) if best else vals[epoch-1] + + Arguments: + arch - a point from the search space identifying a model + epoch - (optional) number of epochs to consider, if not provided + all epochs will be considered (default: ``None``) + best - (optional) return best validation PER from epoch 1 to the + maximum considered epochs, otherwise return PER at the last + considered epoch (default: ``True``) + seed - (optional) an initialization seed to use, if not provided information + will be queried for a random seed (default: ``None``) + + ''' + info = self.full_info(arch, seed=seed, devices=False, include_static_info=False, return_dict=False) + if info is None: + return None + if epoch is None: + epoch = len(info[1]) + if best: + return min(info[1][:epoch]) + else: + return info[1][epoch-1] + + @functools.wraps(BenchmarkingDataset.latency) + def latency(self, *args, **kwargs): + if not self.bench_info: + raise ValueError('No benchmarking information attached') + + return self.bench_info.latency(*args, **kwargs) + + @functools.wraps(StaticInfoDataset.params) + def params(self, *args, **kwargs): + if not self.static_info: + raise ValueError('No static information attached') + + return self.static_info.params(*args, **kwargs) + + + +def from_folder(folder, max_epochs=None, seeds=None, devices=None, include_static_info=False, validate_data=True): + ''' Create a ``Dataset`` object from files in a given directory. + Arguments control what subset of the files will be used. + + Recognizable files should have names following the pattern:: + + - nb-asr-e{max_epochs}-{seed}.pickle for training datasets + - nb-asr-bench-{device}.pickle for benchmarking datasets + - nb-asr-info.pickle for static information dataset + + Arguments: + max_epochs - load dataset files related to accuracy of models + when trained with at most ``max_epochs`` of training. + The related files should have a 'e{max_epochs}' component + in their name. If the argument is ``None``, load the dataset + related to full training. + seeds - if not provided the created dataset will use all available + seeds (each file should hold information about one seed only). + Otherwise it can be a single value or a list seeds to use. + The function will not check if the file(s) for the provided seed(s) + exist(s) and will fail silently (i.e., the resulting + dataset simply won't include results for the provided seed) + devices - (optional) add information about benchmarking on the provided devices, + if ``None`` all available devices are included, otherwise should be a name of + the device or a list of names, can also be exactly ``False`` to avoid including + benchmarking information even when they are available + include_static_info - (optional) include static information about the model, + such as number of parameters + validate_data - passed to ``Dataset`` constructor, if ``True`` the dataset + will be validated to check consistency of the data. Can be set to ``False`` + to speed up loading if the data is known to be valid. 
+ + Raises: + ValueError - if ``folder`` is not a directory or does not exist + ValueError - if any of the loaded dataset files contain + ''' + f = pathlib.Path(folder).expanduser() + if not f.exists() or not f.is_dir(): + raise ValueError(f'{folder} is not a directory') + + if max_epochs is None: + max_epochs = 40 + + max_epochs = f'e{max_epochs}-' + + if seeds is not None: + if isinstance(seeds, cabc.Sequence) and not isinstance(seeds, str): + seeds = '(' + '|'.join(map(str, seeds)) + ')' + else: + seeds = str(seeds) + else: + seeds = '[0-9]+' + + if devices != False: + if devices is not None: + if isinstance(devices, cabc.Sequence) and not isinstance(devices, str): + devices = '(' + '|'.join(map(str, devices)) + ')' + else: + devices = str(devices) + else: + devices = '[a-zA-Z0-9-]+' + + datasets = [] + bench_info = [] + static_info = None + + regex = re.compile(f'nb-asr-{max_epochs}{seeds}.pickle') + regex2 = re.compile(f'nb-asr-bench-{devices}.pickle') if devices else None + for ff in f.iterdir(): + if ff.is_file(): + if regex.fullmatch(ff.name): + datasets.append(str(ff)) + if devices and regex2.fullmatch(ff.name): + bench_info.append(str(ff)) + if include_static_info and ff.name == 'nb-asr-info.pickle': + static_info = str(ff) + + + return Dataset(datasets, bench_info, static_info, validate_data=validate_data) diff --git a/nasbench_asr/graph_utils.py b/nasbench_asr/graph_utils.py new file mode 100644 index 0000000..ab7fdb8 --- /dev/null +++ b/nasbench_asr/graph_utils.py @@ -0,0 +1,383 @@ +import copy +import random +import hashlib +import tempfile +import subprocess +import collections.abc + +import tqdm +import numpy as np +import networkx as nx + +from .utils import flatten, count, get_first_n + +_use_np = True + + +def get_model_graph_np(arch_vec, ops=None, minimize=True, keep_dims=False): + if ops is None: + from . import search_space as ss + ops = ss.all_ops + num_nodes = len(arch_vec) + mat = np.zeros((num_nodes+2, num_nodes+2)) + labels = ['input'] + prev_skips = [] + for nidx, node in enumerate(arch_vec): + op = node[0] + labels.append(ops[op]) + mat[nidx, nidx+1] = 1 + for i, sc in enumerate(prev_skips): + if sc: + mat[i, nidx+1] = 1 + prev_skips = node[1:] + labels.append('output') + mat[num_nodes, num_nodes+1] = 1 + for i, sc in enumerate(prev_skips): + if sc: + mat[i, num_nodes+1] = 1 + orig = None + if minimize: + orig = copy.copy(mat), copy.copy(labels) + for n in range(len(mat)): + if labels[n] == 'zero': + for n2 in range(len(mat)): + if mat[n,n2]: + mat[n,n2] = 0 + if mat[n2,n]: + mat[n2,n] = 0 + def bfs(src, mat, backward): + visited = np.zeros(len(mat)) + q = [src] + visited[src] = 1 + while q: + n = q.pop() + for n2 in range(len(mat)): + if visited[n2]: + continue + if (backward and mat[n2,n]) or (not backward and mat[n,n2]): + q.append(n2) + visited[n2] = 1 + return visited + vfw = bfs(0, mat, False) + vbw = bfs(len(mat)-1, mat, True) + v = vfw + vbw + dangling = (v < 2).nonzero()[0] + if dangling.size: + if keep_dims: + mat[dangling, :] = 0 + mat[:, dangling] = 0 + for i in dangling: + labels[i] = None + else: + mat = np.delete(mat, dangling, axis=0) + mat = np.delete(mat, dangling, axis=1) + for i in sorted(dangling, reverse=True): + del labels[i] + return (mat, labels), orig + +def get_model_graph_nx(arch_vector, ops=None, minimize=True, keep_dims=False): + ''' Get :class:`netwworkx.DiGraph` object from an arch vector. + If ``minimize`` is ``True``, the graph will be minimized by removing + "zero" operations and consequently any dangling nodes. 
+ ''' + if ops is None: + from . import search_space as ss + ops = ss.all_ops + num_nodes = len(arch_vector) + g = nx.DiGraph() + g.add_node(0, label='input') + prev_skips = [] + for nidx, node in enumerate(arch_vector): + op = node[0] + g.add_node(nidx+1, label=ops[op]) + g.add_edge(nidx, nidx+1) + for i, sc in enumerate(prev_skips): + if sc: + g.add_edge(i, nidx+1) + prev_skips = node[1:] + g.add_node(num_nodes+1, label='output') + g.add_edge(num_nodes, num_nodes+1) + for i, sc in enumerate(prev_skips): + if sc: + g.add_edge(i, num_nodes+1) + orig = None + if minimize: + orig = copy.deepcopy(g) + for n in dict(g.nodes): + if g.nodes[n]['label'] == 'zero': + g.remove_node(n) + for _i in range(2): + if 0 in g.nodes: + from_source = nx.descendants(g, 0) + else: + from_source = [] + for n in dict(g.nodes): + keep = True + desc = nx.descendants(g, n) + if n != num_nodes+1: + if num_nodes+1 not in desc: + keep = False + if n > 0: + if n not in from_source: + keep = False + if not keep: + if not _i: + if keep_dims: + edges = list(g.in_edges(n)) + list(g.out_edges(n)) + g.remove_edges_from(edges) + g.nodes[n]['label'] = None + else: + g.remove_node(n) + else: + print(_i, n, desc) + show_graph(g) + show_graph(orig) + assert False + return g, orig + +def get_model_graph(arch_vector, ops=None, minimize=True, keep_dims=False): + if _use_np: + return get_model_graph_np(arch_vector, ops, minimize, keep_dims) + else: + return get_model_graph_nx(arch_vector, ops, minimize, keep_dims) + + +def graph_hash_np(g): + from . import search_space as ss + m, l = g + def hash_module(matrix, labelling): + """Computes a graph-invariance MD5 hash of the matrix and label pair. + Args: + matrix: np.ndarray square upper-triangular adjacency matrix. + labelling: list of int labels of length equal to both dimensions of + matrix. + Returns: + MD5 hash of the matrix and labelling. + """ + vertices = np.shape(matrix)[0] + in_edges = np.sum(matrix, axis=0).tolist() + out_edges = np.sum(matrix, axis=1).tolist() + assert len(in_edges) == len(out_edges) == len(labelling), f'{labelling} {matrix}' + hashes = list(zip(out_edges, in_edges, labelling)) + hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes] + # Computing this up to the diameter is probably sufficient but since the + # operation is fast, it is okay to repeat more times. 
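+        # Each pass below re-hashes every node together with the sorted hashes of its in- and
+        # out-neighbours (Weisfeiler-Lehman style), so isomorphic graphs end up with the same fingerprint.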
+ for _ in range(vertices): + new_hashes = [] + for v in range(vertices): + in_neighbours = [hashes[w] for w in range(vertices) if matrix[w, v]] + out_neighbours = [hashes[w] for w in range(vertices) if matrix[v, w]] + new_hashes.append(hashlib.md5( + (''.join(sorted(in_neighbours)) + '|' + + ''.join(sorted(out_neighbours)) + '|' + + hashes[v]).encode('utf-8')).hexdigest()) + hashes = new_hashes + fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest() + return fingerprint + labels = [] + if l: + labels = [-1] + [ss.all_ops.index(op) for op in l[1:-1]] + [-2] + return hash_module(m, labels) + +def graph_hash_nx(g): + return nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr='label') + +def graph_hash(g): + if _use_np: + return graph_hash_np(g) + else: + return graph_hash_nx(g) + + +_op_to_node_color = { + 'linear': 'tomato', + 'conv5': 'cadetblue1', + 'conv5d2': 'deepskyblue1', + 'conv7': 'olivedrab2', + 'conv7d2': 'seagreen4' +} + +_op_to_label = { + 'linear': 'Linear', + 'conv5': 'Conv(5)', + 'conv5d2': 'Conv(5,d=2)', + 'conv7': 'Conv(7)', + 'conv7d2': 'Conv(7,d=2)', + 'input': 'Input', + 'output': 'Output', + 'zero': 'Zero' +} + + +def _make_nice(agraph): + positions = {} + agraph.node_attr['shape'] = 'rectangle' + agraph.node_attr['style'] = 'rounded' + agraph.graph_attr['splines'] = 'true' + agraph.graph_attr['esep'] = 0.17 + #agraph.graph_attr['overlap'] = 'false' + for node in agraph.nodes(): + op = node.attr['label'] + node.attr['label'] = _op_to_label.get(op, op) + node.attr['width'] = 1.2 + node.attr['height'] = 0.3 + if op in _op_to_node_color: + node.attr['fillcolor'] = _op_to_node_color[op] + node.attr['style'] = 'filled,rounded' + + positions[2*int(node)] = node + + + outputs = {} + removed = set() + for e in agraph.edges(): + if e in removed: + continue + if int(e[0]) + 1 != int(e[1]): + e.attr['group'] = 'branches' + e.attr['style'] = 'dashed' + d = int(e[1]) + prev = str(d-1) + if prev not in outputs: + onode = outputs.setdefault(prev, f'o{prev}') + agraph.add_node(onode, label='+', shape='circle', width=0.3, height=0.3, fixedsize=True, fontsize=16) + positions[2*int(prev)+1] = onode + for e2 in agraph.edges(): + if e2[0] == prev and e2[1] == e[1]: + agraph.remove_edge(e2) + removed.add(e2) + agraph.add_edge(prev, onode, group='main', arrowsize=0.5) + agraph.add_edge(onode, e[1], group='main', arrowsize=0.5) + else: + onode = outputs[prev] + + agraph.add_edge(outputs.get(e[0], e[0]), onode, group='branches', style='dashed', arrowsize=0.5) + agraph.remove_edge(e) + removed.add(e) + else: + e.attr['group'] = 'main' + e.attr['arrowsize'] = 0.5 + + _pos = sorted(positions.keys()) + p = 0 + next_half = False + is_next_sc = [_pos[i+1] % 2 != 0 for i in range(len(_pos)-1)] + [False] + is_prev_sc = [False] + is_next_sc[:-1] + for pos, nsc, psc in zip(_pos, is_next_sc, is_prev_sc): + node = agraph.get_node(positions[pos]) + node.attr['pos'] = f'0,{p}!' + if not nsc and not psc: + p -= 0.47 + else: + p -= 0.47 + + +def show_graph(g, aid=None, show=True, out_dir=None): + ''' Renders graph ``g`` using graphiviz. + ``aid`` is an optional architecture id, if provided, + the rendered graph will be stored under "{out_dir}/nb_graph.{aid}.png". + (If ``out_dir`` is ``None``, it will default to ``graphs``). + Otherwise, it will be saved in a temporary file. + If ``show`` is ``True``, the rendered file will be opened with "xdg-open". 
+ ''' + if _use_np: + a, l = g + g = nx.from_numpy_array(a, create_using=nx.DiGraph) + for idx, label in enumerate(l): + g.nodes[idx]['label'] = label + a = nx.nx_agraph.to_agraph(g) + _make_nice(a) + a.layout('dot', '-Kfdp') + if aid is None: + fname = tempfile.mktemp('.png', 'nb_graph.') + else: + dname = out_dir if out_dir is not None else "graphs" + fname = f'{dname}/nb_graph.{aid}.png' + a.draw(fname) + if show: + subprocess.run(['xdg-open', fname], check=True) + + +def show_model(arch_vec, aid=None, show=True, inc_full=True, out_dir=None): + ''' Renders graphs constructed from arch vector (both minimal and full). + Full graph is only rendered if different from minimal. + ``aid`` is an architecture id which will be used when saving rendered graphs, + if not provided it will be derived from ``arch_vec``. + ''' + g, full = get_model_graph(arch_vec) + if aid is None: + aid = '_'.join(map(str, flatten(arch_vec))) + show_graph(g, aid=aid, show=show, out_dir=out_dir) + if full is not None: + if graph_hash(g) != graph_hash(full): + assert 5 in flatten(arch_vec) + show_graph(full, aid=f'{aid}_full', show=show, out_dir=out_dir) + else: + assert 5 not in flatten(arch_vec) + + +def compare_nx_and_np(): + from .search_space import get_all_architectures, all_ops, default_nodes + global _use_np + all_count = count(get_all_architectures(all_ops, default_nodes)) + _use_np = False + all_hashes = set() + without_zero = set() + unique_graphs = [] + conflicts = {} + for m in tqdm.tqdm(get_all_architectures(all_ops, default_nodes), total=all_count): + has_zero = 5 in flatten(m) + g, _ = get_model_graph(m) + h = graph_hash(g) + if h not in all_hashes: + unique_graphs.append(m) + else: + conflicts[h] = m + all_hashes.add(h) + if not has_zero: + without_zero.add(h) + _use_np = True + np_hashes = set() + invalid = [] + for m in tqdm.tqdm(get_all_architectures(all_ops, default_nodes), total=all_count): + has_zero = 5 in flatten(m) + g, _ = get_model_graph(m) + h = graph_hash(g) + if h not in np_hashes: + if m not in unique_graphs: + invalid.append(m) + np_hashes.add(h) + print('Core:', len(without_zero)) + print('With zeros:', len(all_hashes)) + print('Unique:', len(unique_graphs)) + print('Np unique:', len(np_hashes)) + print('Invalid:', len(invalid)) + _use_np = False + if invalid: + inv = invalid[0] + g, _ = get_model_graph(inv) + h = graph_hash(g) + conflicting = conflicts[h] + show_model(invalid[0]) + show_model(conflicting) + + +def main(): + from .search_space import get_all_architectures, all_ops, default_nodes + all_count = count(get_all_architectures(all_ops, default_nodes)) + all_hashes = set() + without_zero = set() + for m in tqdm.tqdm(get_all_architectures(all_ops, default_nodes), total=all_count): + has_zero = 5 in flatten(m) + g, _ = get_model_graph(m) + h = graph_hash(g) + all_hashes.add(h) + if not has_zero: + without_zero.add(h) + # show_model([[0,1], [5,1,0], [3,1,1,1]]) + print('Core:', len(without_zero)) + print('With zeros:', len(all_hashes)) + + +if __name__ == '__main__': + main() diff --git a/nasbench_asr/model/__init__.py b/nasbench_asr/model/__init__.py new file mode 100644 index 0000000..b9e542d --- /dev/null +++ b/nasbench_asr/model/__init__.py @@ -0,0 +1,24 @@ +from .. 
import utils + + +_backends = utils.BackendsAccessor(__file__, __name__) + + +def get_available_backends(): + return list(_backends.available_backends) + + +def set_default_backend(backend): + return _backends.get_backend(backend, set_default=True).__name__.rsplit('.')[-1] + + +def get_backend_name(backend=None): + return _backends.get_backend(backend).__name__.rsplit('.')[-1] + + +def get_model(*args, backend=None, **kwargs): + return _backends.get_backend(backend).get_model(*args, **kwargs) + + +def print_model_summary(model): + return _backends.get_backend(model.backend).print_model_summary(model) diff --git a/nasbench_asr/model/tf/__init__.py b/nasbench_asr/model/tf/__init__.py new file mode 100644 index 0000000..c14b89a --- /dev/null +++ b/nasbench_asr/model/tf/__init__.py @@ -0,0 +1,51 @@ +import pathlib +import contextlib + +import numpy as np + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def _call_once(func): + _called = False + _cache = None + def impl(*args, **kwargs): + nonlocal _called, _cache + if _called: + return _cache + _cache = func(*args, **kwargs) + _called = True + return _cache + return impl + + + +@_call_once +def _get_data_norm(): + stats_file = pathlib.Path(__file__).parents[2].joinpath('training', 'timit_train_stats.npz') + norm_stats = np.load(stats_file) + mean = norm_stats['moving_mean'] + variance = norm_stats['moving_variance'] + return mean, variance + + +def get_model(arch_vec, use_rnn, dropout_rate, gpu=None): + from .model import ASRModel + + with contextlib.ExitStack() as stack: + if gpu is not None: + stack.enter_context(tf.device(f'/GPU:{gpu}')) + + model = ASRModel(arch_vec, + num_classes=48, + use_rnn=use_rnn, + dropout_rate=dropout_rate, + input_shape=[None, 80], + data_norm=_get_data_norm(), + epsilon=0.001) + + return model + + +def print_model_summary(model): + model._model.summary() diff --git a/nasbench_asr/model/tf/mean_variance_normalization.py b/nasbench_asr/model/tf/mean_variance_normalization.py new file mode 100644 index 0000000..e5a31e7 --- /dev/null +++ b/nasbench_asr/model/tf/mean_variance_normalization.py @@ -0,0 +1,47 @@ +# pylint: skip-file +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +class MeanVarianceNormalization(tf.keras.layers.Layer): + def __init__(self, epsilon, mean_initializer, variance_initializer, + **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + self.mean_initializer = mean_initializer + self.variance_initializer = variance_initializer + + def build(self, input_shape): + self.mean = self.add_weight( + name="mean", + shape=(input_shape[-1]), + initializer=self.mean_initializer, + trainable=False, + ) + self.variance = self.add_weight( + name="variance", + shape=(input_shape[-1]), + initializer=self.variance_initializer, + trainable=False, + ) + + def call(self, inputs, mask=None): + outputs = (inputs - self.mean) / tf.math.sqrt(self.variance + + self.epsilon) + + if mask is not None: + outputs = tf.where(tf.expand_dims(mask, axis=-1), outputs, 0) + + return outputs + + def compute_mask(self, inputs, mask=None): + return mask + + def get_config(self): + config = super().get_config() + config.update({ + "epsilon": self.epsilon, + "mean_initializer": self.mean_initializer, + "variance_initializer": self.variance_initializer, + }) + + return config diff --git a/nasbench_asr/model/tf/model.py b/nasbench_asr/model/tf/model.py new file mode 100644 index 0000000..e17885c --- /dev/null +++ b/nasbench_asr/model/tf/model.py @@ -0,0 +1,96 @@ +import random +import itertools 
as itr + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + +from .ops import OPS_LIST, BRANCH_OPS_LIST, norm_op, PadConvRelu +from .mean_variance_normalization import MeanVarianceNormalization + + +class Node(tf.keras.Model): + def __init__(self, filters, op_idx, branch_op_idx_list): + super().__init__() + self._op = OPS_LIST[op_idx](filters) + self.branch_ops = [BRANCH_OPS_LIST[i] for i in branch_op_idx_list] + + def call(self, input_list, training=None): + assert len(input_list) == len(self.branch_ops), 'Branch op and input list have different lenghts' + + output = self._op(input_list[-1], training=training) + edges = [output] + for i in range(len(self.branch_ops)): + x = self.branch_ops[i](input_list[i]) + edges.append(x) + + return tf.math.add_n(edges) + + +class SearchCell(tf.keras.Model): + def __init__(self, filters, config, num_nodes=3): + super().__init__() + + self._nodes = list() + for n_config in config: + node = Node(filters=filters, op_idx=n_config[0], branch_op_idx_list=n_config[1:]) + self._nodes.append(node) + + self.norm_layer = norm_op() + + def call(self, input, training=None): + outputs = [input] # input is the output coming from node 0 + for node in self._nodes: + n_out = node(outputs, training=training) + outputs.append(n_out) + + output = self.norm_layer(outputs[-1]) #use layer norm at the end of a search cell + return output + + +class ASRModel(tf.keras.Model): + def __init__(self, arch_desc, num_classes=48, use_rnn=False, use_norm=True, dropout_rate=0.0, input_shape=None, data_norm=None, epsilon=0.001): + super().__init__() + + self.arch_desc = list(arch_desc) + self.num_classes = num_classes + self.use_rnn = use_rnn + self.use_norm = use_norm + self.dropout_rate = dropout_rate + + cnn_time_reduction_kernels = [8, 8, 8, 8] + cnn_time_reduction_strides = [1, 1, 2, 2] + filters = [600, 800, 1000, 1200] + scells_per_block = [3, 4, 5, 6] + + zipped_params = zip(cnn_time_reduction_kernels, + cnn_time_reduction_strides, + filters, + scells_per_block) + + layers = [] + + if input_shape is not None: + layers.append(tf.keras.layers.Masking(input_shape=input_shape)) + else: + layers.append(tf.keras.layers.Masking()) + + if data_norm is not None: + mean, variance = data_norm + layers.append(MeanVarianceNormalization(epsilon, tf.keras.initializers.Constant(mean), tf.keras.initializers.Constant(variance))) + + for i, (kernel, stride, filters, cells) in enumerate(zipped_params): + layers.append(PadConvRelu(kernel_size=kernel, strides=stride, filters=filters, dialation=1, name=f'conv_{i}')) + layers.append(norm_op()) + + for j in range(cells): + layers.append(SearchCell(filters=filters, config=arch_desc)) + + if use_rnn: + layers.append(tf.keras.layers.LSTM(units=500, dropout=self.dropout_rate, time_major=False, return_sequences=True)) + + layers.append(tf.keras.layers.Dense(self.num_classes+1)) + + self._model = tf.keras.Sequential(layers) + + def call(self, input, training=None): + return self._model(input, training=training) + diff --git a/nasbench_asr/model/tf/ops.py b/nasbench_asr/model/tf/ops.py new file mode 100644 index 0000000..ab6c4fe --- /dev/null +++ b/nasbench_asr/model/tf/ops.py @@ -0,0 +1,91 @@ +from nasbench_asr.quiet_tensorflow import tensorflow as tf + +FUTURE_CONTEXT = 4 # 40ms look ahead + +def get_activation(params_activation): + def activation(inputs): + return getattr(tf.keras.activations, params_activation["name"])( + inputs, **params_activation["kwargs"]) + + return activation + +class PadConvRelu(tf.keras.Model): + def __init__(self, 
kernel_size, dialation, filters, strides, groups=1, dropout_rate=0, name='PadConvRelu'): + super(PadConvRelu, self).__init__(name=name) + + if int(FUTURE_CONTEXT / strides) >= (kernel_size-strides): + rpad = kernel_size-strides + lpad = 0 + else: + rpad = int(FUTURE_CONTEXT / strides) + lpad = int(kernel_size - 1 - rpad) + + padding = tf.keras.layers.ZeroPadding1D(padding=(lpad, rpad)) + conv1d = tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel_size, strides=strides, groups=groups, kernel_regularizer=tf.keras.regularizers.L2()) + #activation = tf.keras.layers.Activation('relu') + activation = tf.keras.layers.Activation(get_activation({"name": "relu", "kwargs": {"max_value": 20}})) + dropout = tf.keras.layers.Dropout(rate=dropout_rate) + self.layer = tf.keras.Sequential([padding, conv1d, activation, dropout]) + + def call(self, x, training=None): + return self.layer(x, training=training) + + +class Linear(tf.keras.Model): + def __init__(self, units, dropout_rate=0, name='Linear'): + super(Linear, self).__init__(name=name) + dense = tf.keras.layers.Dense(units=units) + activation = tf.keras.layers.Activation(get_activation({"name": "relu", "kwargs": {"max_value": 20}})) + dropout = tf.keras.layers.Dropout(rate=dropout_rate) + self.layer = tf.keras.Sequential([dense, activation, dropout]) + + def call(self, x, training=None): + return self.layer(x, training=training) + + +class Identity(tf.keras.Model): + def __init__(self, name='Identity'): + super(Identity, self).__init__(name=name) + + def call(self, x): + return x + +class Zero(tf.keras.Model): + def __init__(self, name='Zero'): + super(Zero, self).__init__(name=name) + + def call(self, x, training=None): + return x*0.0 + +DROPOUT_RATE=0.2 +# OPS_old = { +# 'linear' : lambda filters: Linear(units=filters, dropout_rate=DROPOUT_RATE), +# 'conv3' : lambda filters: PadConvRelu(kernel_size=3, dialation=1, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv3'), +# 'conv3d2' : lambda filters: PadConvRelu(kernel_size=3, dialation=2, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv3d2'), +# 'conv5' : lambda filters: PadConvRelu(kernel_size=5, dialation=1, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv5'), +# 'conv5d2' : lambda filters: PadConvRelu(kernel_size=5, dialation=2, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv5d2'), +# } + +OPS = { + 'linear' : lambda filters: Linear(units=filters, dropout_rate=DROPOUT_RATE), + 'conv5' : lambda filters: PadConvRelu(kernel_size=5, dialation=1, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv3'), + 'conv5d2' : lambda filters: PadConvRelu(kernel_size=5, dialation=2, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv3d2'), + 'conv7' : lambda filters: PadConvRelu(kernel_size=7, dialation=1, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv5'), + 'conv7d2' : lambda filters: PadConvRelu(kernel_size=7, dialation=2, filters=filters, strides=1, groups=100, dropout_rate=DROPOUT_RATE, name='conv5d2'), + 'none': lambda filters: Zero(name='none') + } + +BRANCH_OPS = { + 'none' : Zero(name='none'), + 'skip_connect' : Identity(name='skip_connect') +} + +def norm_op(): + return tf.keras.layers.LayerNormalization() + +OPS_LIST = [OPS['linear'], OPS['conv5'], OPS['conv5d2'], OPS['conv7'], OPS['conv7d2'], OPS['none']] +BRANCH_OPS_LIST = [BRANCH_OPS['skip_connect'], BRANCH_OPS['none']] + + + + diff --git 
a/nasbench_asr/model/torch/__init__.py b/nasbench_asr/model/torch/__init__.py new file mode 100644 index 0000000..84fff44 --- /dev/null +++ b/nasbench_asr/model/torch/__init__.py @@ -0,0 +1,47 @@ +import torch.nn +import torch.nn.init + +from .. import utils + + +def get_model(arch_vec, use_rnn, dropout_rate, gpu=None): + from . import model + from ... import search_space as ss + arch_desc = ss.arch_vec_to_names(arch_vec) + model = model.ASRModel(arch_desc, use_rnn=use_rnn, dropout_rate=dropout_rate) + + def init_weights(m): + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + torch.nn.init.zeros_(m.bias) + elif isinstance(m, torch.nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight) + torch.nn.init.zeros_(m.bias) + elif isinstance(m, torch.nn.LSTM): + for l in range(m.num_layers): + wi = getattr(m, f'weight_ih_l{l}') + wh = getattr(m, f'weight_hh_l{l}') + bi = getattr(m, f'bias_ih_l{l}') + bh = getattr(m, f'bias_hh_l{l}') + torch.nn.init.xavier_uniform_(wi) + torch.nn.init.xavier_uniform_(wh) + torch.nn.init.zeros_(bi) + torch.nn.init.zeros_(bh) + + model.apply(init_weights) + if gpu is not None: + model.to(device=f'cuda:{gpu}') + + return model + + +def print_model_summary(model): + print(model) + print('======================') + def _print(m, level=0): + for n, child in m.named_children(): + print(' '*level + type(child).__name__, ' ', n, ' ', sum(p.numel() for p in child.parameters())) + _print(child, level+1) + _print(model.model) + print('======================') + print('Trainable parameters:', utils.make_nice_number(sum(p.numel() for p in model.parameters()))) diff --git a/nasbench_asr/model/torch/model.py b/nasbench_asr/model/torch/model.py new file mode 100644 index 0000000..6ec9c78 --- /dev/null +++ b/nasbench_asr/model/torch/model.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn + +from .ops import PadConvRelu, _ops, _branch_ops + + +class Node(nn.Module): + def __init__(self, filters, op_ctor, branch_op_ctors, dropout_rate=0.0): + super().__init__() + self.op = op_ctor(filters, filters, dropout_rate=dropout_rate) + self.branch_ops = [ctor() for ctor in branch_op_ctors] + + def forward(self, input_list): + assert len(input_list) == len(self.branch_ops), 'Branch op and input list have different lenghts' + + output = self.op(input_list[-1]) + edges = [output] + for i in range(len(self.branch_ops)): + x = self.branch_ops[i](input_list[i]) + edges.append(x) + + return sum(edges) + + +class SearchCell(nn.Module): + def __init__(self, filters, node_configs, dropout_rate=0.0, use_norm=True): + super().__init__() + + self.nodes = nn.ModuleList() + for node_config in node_configs: + node_op_name, *node_branch_ops = node_config + try: + node_op_ctor = _ops[node_op_name] + except KeyError: + raise ValueError(f'Operation "{node_op_name}" is not implemented') + + try: + node_branch_ctors = [_branch_ops[branch_op] for branch_op in node_branch_ops] + except KeyError: + raise ValueError(f'Invalid branch operations: {node_branch_ops}, expected is a vector of 0 (no skip-con.) and 1 (skip-con. 
present)') + + node = Node(filters=filters, op_ctor=node_op_ctor, branch_op_ctors=node_branch_ctors, dropout_rate=dropout_rate) + self.nodes.append(node) + + self.use_norm = use_norm + if self.use_norm: + self.norm_layer = nn.LayerNorm(filters, eps=0.001) + + def forward(self, input): + outputs = [input] # input is the output coming from node 0 + for node in self.nodes: + n_out = node(outputs) + outputs.append(n_out) + output = outputs[-1] #last node is the output + if self.use_norm: + output = output.permute(0,2,1) + output = self.norm_layer(output) + output = output.permute(0,2,1) + return output + + +class ASRModel(nn.Module): + def __init__(self, arch_desc, num_classes=48, use_rnn=False, use_norm=True, dropout_rate=0.0, **kwargs): + super().__init__() + + self.arch_desc = arch_desc + self.num_classes = num_classes + self.use_rnn = use_rnn + self.use_norm = use_norm + self.dropout_rate = dropout_rate + + num_blocks = 4 + features = 80 + filters = [600, 800, 1000, 1200] + cnn_time_reduction_kernels = [8, 8, 8, 8] + cnn_time_reduction_strides = [1, 1, 2, 2] + scells_per_block = [3, 4, 5, 6] + + layers = nn.ModuleList() + + for i in range(num_blocks): + layers.append(PadConvRelu( + in_channels= features if i==0 else filters[i-1], + out_channels=filters[i], + kernel_size=cnn_time_reduction_kernels[i], + dilation=1, + strides=cnn_time_reduction_strides[i], + groups=1, + name=f'conv_{i}')) + + # TODO: normalize axis=1 + layers.append(nn.LayerNorm(filters[i], eps=0.001)) + + for j in range(scells_per_block[i]): + cell = SearchCell(filters=filters[i], node_configs=arch_desc, use_norm=use_norm, dropout_rate=dropout_rate) + layers.append(cell) + + if use_rnn: + layers.append(nn.Dropout(dropout_rate)) + layers.append(nn.LSTM(input_size=filters[num_blocks-1], hidden_size=500, batch_first=True, dropout=0.0)) + layers.append(nn.Linear(in_features=500, out_features=num_classes+1)) + else: + layers.append(nn.Linear(in_features=filters[num_blocks-1], out_features=num_classes+1)) + + # self._model = nn.Sequential(*layers) + self.model = layers + + def get_prunable_copy(self, bn=False, masks=None): + # bn, masks are not used in this func. 
+ # Keeping them to make the code work with predictive.py + model_new = ASRModel(arch_desc=self.arch_desc, num_classes=self.num_classes, use_rnn=self.use_rnn, use_norm=bn, dropout_rate=self.dropout_rate) + model_new.load_state_dict(self.state_dict(), strict=False) + model_new.train() + return model_new + + def forward(self, input): # input is (B, F, T) + for xx in self.model: + if isinstance(xx, nn.LSTM): + input = input.permute(0,2,1) + input = xx(input)[0] + input = input.permute(0,2,1) + elif isinstance(xx, nn.Linear): + input = input.permute(0,2,1) + input = xx(input) + elif isinstance(xx, nn.LayerNorm): + input = input.permute(0,2,1) + input = xx(input) + input = input.permute(0,2,1) + else: + input = xx(input) + return input + + @property + def backend(self): + return 'torch' diff --git a/nasbench_asr/model/torch/ops.py b/nasbench_asr/model/torch/ops.py new file mode 100644 index 0000000..9870b46 --- /dev/null +++ b/nasbench_asr/model/torch/ops.py @@ -0,0 +1,84 @@ +import functools + +import torch +import torch.nn as nn + + +class PadConvRelu(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation, strides, groups=1, dropout_rate=0, context=4, name='PadConvRelu'): + super().__init__() + self.name = name + + if int(context / strides) >= (kernel_size*dilation-strides): + rpad = kernel_size*dilation-strides + lpad = 0 + else: + rpad = int(context / strides) + lpad = int((kernel_size - 1)*dilation - rpad) + + self.pad = nn.ZeroPad2d((lpad, rpad, 0, 0)) + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=strides, dilation=dilation, groups=groups) + self.relu = nn.ReLU(inplace=False) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + x = self.pad(x) + x = self.conv(x) + x = self.relu(x) + x = torch.clamp_max_(x, 20) + x = self.dropout(x) + return x + + +class Linear(nn.Module): + def __init__(self, in_features, out_features, dropout_rate=0, name='Linear'): + super().__init__() + self.name = name + + self.linear = nn.Linear(in_features, out_features) + self.relu = nn.ReLU(inplace=False) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, x): + shape = x.shape + x = x.permute(0,2,1) + x = self.linear(x) + x = self.relu(x) + x = torch.clamp_max_(x, 20) + x = self.dropout(x) + x = x.permute(0,2,1) + return x + + +class Identity(nn.Module): + def __init__(self, name='Identity'): + super().__init__() + self.name = name + + def forward(self, x): + return x + + +class Zero(nn.Module): + def __init__(self, name='Zero'): + super(Zero, self).__init__() + self.name = name + + def forward(self, x): + return torch.zeros_like(x) + + +_ops = { + 'linear': Linear, + 'conv5': functools.partial(PadConvRelu, kernel_size=5, dilation=1, strides=1, groups=100, name='conv5'), + 'conv5d2': functools.partial(PadConvRelu, kernel_size=5, dilation=2, strides=1, groups=100, name='conv52d'), + 'conv7': functools.partial(PadConvRelu, kernel_size=7, dilation=1, strides=1, groups=100, name='conv7'), + 'conv7d2': functools.partial(PadConvRelu, kernel_size=7, dilation=2, strides=1, groups=100, name='conv52d'), + 'zero': lambda *args, **kwargs: Zero(name='zero') +} + +_branch_ops = { + 0: Zero, # branch not present + 1: Identity # branch present +} + diff --git a/nasbench_asr/quiet_tensorflow.py b/nasbench_asr/quiet_tensorflow.py new file mode 100644 index 0000000..e4203f6 --- /dev/null +++ b/nasbench_asr/quiet_tensorflow.py @@ -0,0 +1,35 @@ +def disable_warnings(): + import os + import logging + import warnings + 
warnings.filterwarnings('ignore',category=FutureWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + os.environ['TF_DETERMINISTIC_OPS'] = '1' + logging.getLogger('tensorflow').setLevel(logging.ERROR) + + try: + from tensorflow.python.util import module_wrapper as deprecation + except ImportError: + try: + from tensorflow.python.util import deprecation_wrapper as deprecation + except ImportError: + from tensorflow.python.util import deprecation + try: + deprecation._PRINT_DEPRECATION_WARNINGS = False + except: + pass + + import tensorflow as tf + try: + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + except: + pass + + try: + tf.get_logger().setLevel('ERROR') + except: + pass + +disable_warnings() +import tensorflow diff --git a/nasbench_asr/search_space.py b/nasbench_asr/search_space.py new file mode 100644 index 0000000..eea48bd --- /dev/null +++ b/nasbench_asr/search_space.py @@ -0,0 +1,93 @@ +import random + +from .utils import recursive_iter, flatten, copy_structure + + +all_ops = ['linear', 'conv5', 'conv5d2', 'conv7', 'conv7d2', 'zero'] +ops_no_zero = all_ops[:-1] +default_nodes = 3 + + +def get_search_space(ops=None, nodes=None): + ''' Return boundaries of the search space for the given list + of available operations and number of nodes. + ''' + ops = ops if ops is not None else all_ops + nodes = nodes if nodes is not None else default_nodes + search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)] + return search_space + + +def get_model_hash(arch_vec, ops=None, minimize=True): + ''' Get hash of the architecture specified by arch_vec. + Architecture hash can be used to determine if two + configurations from the search space are in fact the + same (graph isomorphism). + ''' + from .graph_utils import get_model_graph, graph_hash + g, _ = get_model_graph(arch_vec, ops=ops, minimize=minimize) + return graph_hash(g) + + +def get_all_architectures(ops=None, nodes=None): + ''' Yields all architecture configurations in the search space + ''' + search_space = get_search_space(ops, nodes) + flat = flatten(search_space) + cfg = [0 for _ in range(len(flat))] + end = False + while not end: + yield copy_structure(cfg, search_space) + for dim in range(len(flat)): + cfg[dim] += 1 + if cfg[dim] != flat[dim]: + break + cfg[dim] = 0 + if dim+1 >= len(flat): + end = True + + +def get_random_architectures(num, ops=None, nodes=None, seed=None): + ''' Get random architecture configurations from the search space + ''' + ops = ops if ops is not None else all_ops + nodes = nodes if nodes is not None else default_nodes + if seed is not None: + random.seed(seed) + search_space = [[len(ops)] + [2]*(idx+1) for idx in range(nodes)] + flat = flatten(search_space) + models = [] + while len(models) < num: + m = [random.randrange(opts) for opts in flat] + m = copy_structure(m, search_space) + models.append(m) + return models + + +def get_archs_with_zero(): + models_with_zero = {} + for m in get_all_architectures(all_ops, default_nodes): + if 5 in flatten(m): + h = get_model_hash(m) + models_with_zero[h] = m + new_model_archs = [models_with_zero[k] for k in sorted(models_with_zero.keys())] + return new_model_archs + + +def arch_vec_to_names(arch_vec, ops=None): + ''' Translates identifiers of operations in ``arch_vec`` to their names. + ``ops`` can be provided externally to avoid relying on the current definition + of available ops. Otherwise canonical ``all_ops`` will be used. 
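+
+ For example, a hypothetical vector for the default 3-node space (one
+ sub-vector per node, in the form ``[op_idx, branch_op_idx...]``) would
+ translate as::
+
+ >>> arch_vec_to_names([[1, 0], [2, 0, 1], [5, 0, 0, 1]])
+ [['conv5', 0], ['conv5d2', 0, 1], ['zero', 0, 0, 1]]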
+ ''' + + if ops is None: + ops = all_ops + + # current approach is to have an arch vector contain sub-vectors for node in a cell, + # each subvector has a form of: + # [op_idx, branch_op_idx...] + # where op_idx points to an operation from ``all_ops`` and ``branch_op_idx`` is + # either 0 (no skip connection) or 1 (identity skip connection) + # since skip connects are already quite self-explanatory we leave them as they are + # and only change numbers of the main operations to their respective names + return [[all_ops[op_idx]] + branches for op_idx, *branches in arch_vec] diff --git a/nasbench_asr/training/__init__.py b/nasbench_asr/training/__init__.py new file mode 100644 index 0000000..61f32fd --- /dev/null +++ b/nasbench_asr/training/__init__.py @@ -0,0 +1,90 @@ +from .. import utils + + +_backends = utils.BackendsAccessor(__file__, __name__) + + +class Trainer(): + def __init__(self, dataloaders, gpus=None, save_dir=None, verbose=True): + raise NotImplementedError() + + def train(self, model, epochs=40, lr=0.0001, reset=False, model_name=None): + raise NotImplementedError() + + def step(self, input, training=True): + raise NotImplementedError() + + def save(self, checkpoint): + raise NotImplementedError() + + def load(self, checkpoint): + raise NotImplementedError() + + def remember(self): + raise NotImplementedError() + + def recall(self): + raise NotImplementedError() + + +def get_available_backends(): + return list(_backends.available_backends) + + +def set_default_backend(backend): + return _backends.get_backend(backend, set_default=True).__name__.rsplit('.')[-1] + + +def get_backend_name(backend=None): + return _backends.get_backend(backend).__name__.rsplit('.')[-1] + + +def set_seed(*args, backend=None, **kwargs): + return _backends.get_backend(backend).set_seed(*args, **kwargs) + + +def prepare_devices(devices, backend=None): + return _backends.get_backend(backend).prepare_devices(devices) + + +def get_dataloaders(timit_root, batch_size=64, backend=None): + ''' Prepare dataset. + + Arguments: + timit_root (os.PathLike) - root folder holding TIMIT dataset + batch_size (int) - batch size to use when training + backend (str) - optional, specifies backend to use + + Returns: + tuple: a tuple of 5 values, in order: + + - encoder object (used to encode phonemes) + - iterable yielding training examples + - iterable yielding validation examples + - iterable yielding testing examples + - backend-specific data + ''' + return _backends.get_backend(backend).get_dataloaders(timit_root, batch_size=batch_size) + + +def get_loss(backend=None): + return _backends.get_backend(backend).get_loss() + + +def get_trainer(dataloaders, loss, gpus=None, save_dir=None, verbose=False, backend=None) -> Trainer: + ''' Return a :py:class:`Trainer` object which implements training functionality. 
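+
+ A minimal end-to-end sketch (the TIMIT path and the architecture vector
+ below are only illustrative)::
+
+ from nasbench_asr.model import get_model
+ from nasbench_asr.training import get_dataloaders, get_loss, get_trainer
+
+ dataloaders = get_dataloaders('./data/TIMIT', batch_size=64)
+ model = get_model([[1, 0], [2, 0, 1], [5, 0, 0, 1]], use_rnn=True, dropout_rate=0.2)
+ trainer = get_trainer(dataloaders, get_loss(), gpus=[0], verbose=True)
+ trainer.train(model, epochs=40, lr=0.0001)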
+ + Arguments: + dataloaders - a set of data-related objects obtained by calling :py:func:`nasbench_asr.get_dataloaders` + epochs (int) - number of epochs to train, default: 40 + lr (float) - learning rate to use, default: 0.0001 + gpus (list[int]) - a list of GPUs to use when training a model, by default use CPU only + save_dir (str) - an optional directory name where the model will be save, if the directory exists + when the training begins and ``reset`` is ``False``, the trainer will try to continue training + from a checkpoint stored in the directory + reset (bool) - specifies if training script should ignore existing checkpoints in ``save_dir`` at + the beginning of training + verbose (bool) - whether training should print to standard output + backend (str) - optional, specifies backend to use + ''' + return _backends.get_backend(backend).get_trainer(dataloaders, loss, gpus=gpus, save_dir=save_dir, verbose=verbose) diff --git a/nasbench_asr/training/tf/__init__.py b/nasbench_asr/training/tf/__init__.py new file mode 100644 index 0000000..84d2fec --- /dev/null +++ b/nasbench_asr/training/tf/__init__.py @@ -0,0 +1,143 @@ +import os +import pathlib +import random + +import numpy as np +from nasbench_asr.quiet_tensorflow import tensorflow as tf +from attrdict import AttrDict + +from . import trainer +from .datasets.audio_featurizer import AudioFeaturizer +from .datasets.audio_sentence_timit import get_timit_audio_sentence_ds +from .datasets.preprocess import preprocess +from .datasets.text_encoder import TextEncoder +from .datasets.cache_shard_shuffle_batch import cache_shard_shuffle_batch + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + tf.random.set_seed(seed) + + +def prepare_devices(devices): + physical_devices = [gpu for idx, gpu in enumerate(tf.config.list_physical_devices('GPU')) if idx in devices] + if len(physical_devices) != len(devices): + raise ValueError('Could not find all devices!') + + try: + tf.config.experimental.set_visible_devices(physical_devices, 'GPU') + except RuntimeError: + return + for pd in physical_devices: + tf.config.experimental.set_memory_growth(pd, True) + + +def get_dataloaders(timit_root, batch_size): + # hidden arguments + encoder_class = 'phoneme' + num_parallel_calls = tf.data.experimental.AUTOTUNE + deterministic = True + + curriculum_learnings = [[16000, 2], [32000, 2]] + + splits = ['TRAIN', 'VAL', 'TEST'] + + + # Create common objects + + featurizer = AudioFeaturizer( + sample_rate=16000, + feature_type='lmel', + normalize_full_scale=False, + window_len_in_sec=0.025, + step_len_in_sec=0.010, + num_feature_filters=80, + mel_weight_mat=None, + verbose=False + ) + + encoder = TextEncoder(encoder_class=encoder_class) + + # helper function to apply common transformations for different + # parts of timit + def get_timit_ds(split_name, max_len): + ds = get_timit_audio_sentence_ds(timit_root, + split_name, + remove_sa=True, + encoder_class=encoder_class, + num_parallel_calls=num_parallel_calls, + deterministic=deterministic, + max_audio_size=max_len) + + # stats_file = str(pathlib.Path(__file__).parents[1].joinpath(f'timit_train_stats.npz')) + stats_file = None + + ds = preprocess(ds=ds, + encoder=encoder, + featurizer=featurizer, + norm_stats=stats_file, + epsilon=0.001, + num_parallel_calls=num_parallel_calls, + deterministic=deterministic, + max_feature_size=0) + + ds = cache_shard_shuffle_batch(ds=ds, + ds_cache_in_disk=False, + path_ds_cache=None, + ds_cache_in_memory=False, + shard_num_shards=None, + shard_index=None, + 
shuffle=(split_name == 'TRAIN'), + shuffle_buffer_size=2048, + num_feature_filters=80, + pad_strategy='bucket_by_sequence_length', + batch_size=batch_size, + padded_shapes=([None, 80], [], [None], []), + drop_remainder=False, + bucket_boundaries=[300], + bucket_batch_sizes=[min(batch_size, 64), min(batch_size, 48)], + device=None, + prefetch_buffer_size=1) + + steps = 0 + for _ in ds: + steps += 1 + + #print('!!!!!', split_name, steps, batch_size) + + ds = AttrDict({ + 'ds': ds, + 'encoder': encoder, + 'featurizer': featurizer, + 'steps': steps + }) + + return ds + + all_ds = [] + for split in splits: + curriculum = [] + if split == 'TRAIN': + for max_len, epochs in curriculum_learnings: + c_ds = get_timit_ds(split, max_len) + c_ds.ds = c_ds.ds.repeat(epochs) + curriculum.append(c_ds) + + ds = get_timit_ds(split, 0) + ds.ds = ds.ds.repeat() + + if curriculum: + t = curriculum[0].ds + for c in curriculum[1:]: + t = t.concatenate(c.ds) + t = t.concatenate(ds.ds) + ds.ds = t + + all_ds.append(ds) + + return (encoder, *all_ds) + + +get_trainer = trainer.get_trainer +get_loss = trainer.get_loss \ No newline at end of file diff --git a/nasbench_asr/training/tf/callbacks/__init__.py b/nasbench_asr/training/tf/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nasbench_asr/training/tf/callbacks/lrscheduler.py b/nasbench_asr/training/tf/callbacks/lrscheduler.py new file mode 100644 index 0000000..00a3b1a --- /dev/null +++ b/nasbench_asr/training/tf/callbacks/lrscheduler.py @@ -0,0 +1,63 @@ +# coding=utf-8 +import os +import pickle + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +class ExponentialDecay(tf.keras.callbacks.Callback): + """ + Exponential decay-base learning rate scheduler. + """ + def __init__(self, + decay_factor, + start_epoch=None, + min_lr= None, + verbose=0): + """ + Args: + decay_factor : A float value (< 1.0) indicating the factor by which the LR should be reduced + start_epoch : At which epoch to start applying the decay, default None => start from first epoch + min_lr : What's the lowest value for the LR allowed, default None => 0 + """ + + super().__init__() + self.decay_factor = decay_factor + self.start_epoch = start_epoch + if self.start_epoch is None: + self.start_epoch = 1 + + self.min_lr = min_lr + if self.min_lr is None: + self.min_lr = 0.0 + + self.verbose = verbose + self.epoch = 0 + + def _schedule(self): + + """ + Allows the learning rate to cycle linearly within a range. + """ + lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate)) + if self.epoch < self.start_epoch: + return tf.math.reduce_max((self.min_lr, lr)) + else: + return tf.math.reduce_max((self.min_lr, lr * self.decay_factor)) + + def on_epoch_begin(self, epoch, logs=None): + if not hasattr(self.model.optimizer, 'lr'): + raise ValueError('Optimizer must have a "lr" attribute.') + self.epoch += 1 + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + + # Call schedule function to get the scheduled learning rate. 
+ scheduled_lr = self._schedule() + + # Set the value back to the optimizer + tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr) + + # Adding the current LR to the logs so that it will show up on tensorboard + logs['lr'] = float(tf.keras.backend.get_value(self.model.optimizer.lr)) diff --git a/nasbench_asr/training/tf/callbacks/reset_states.py b/nasbench_asr/training/tf/callbacks/reset_states.py new file mode 100644 index 0000000..aa8ec6d --- /dev/null +++ b/nasbench_asr/training/tf/callbacks/reset_states.py @@ -0,0 +1,16 @@ +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +class ResetStatesCallback(tf.keras.callbacks.Callback): + def __init__(self, trackers={}): + super().__init__() + self.trackers = trackers + + def on_epoch_begin(self, epoch, logs=None): + for metric in self.trackers["train"]: + self.trackers["train"][metric].reset_states() + + def on_test_begin(self, logs=None): + for metric in self.trackers["test"]: + self.trackers["test"][metric].reset_states() + diff --git a/nasbench_asr/training/tf/callbacks/tensorboard.py b/nasbench_asr/training/tf/callbacks/tensorboard.py new file mode 100644 index 0000000..9b01e58 --- /dev/null +++ b/nasbench_asr/training/tf/callbacks/tensorboard.py @@ -0,0 +1,28 @@ +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +class Tensorboard(tf.keras.callbacks.Callback): + """ + A simple TensorBoard callback + """ + def __init__(self, log_dir, update_freq=10): + super().__init__() + self.log_dir = log_dir + self.update_freq = update_freq + self.file_writer_train = tf.summary.create_file_writer(str(log_dir / "train")) + self.file_writer_val = tf.summary.create_file_writer(str(log_dir / "val")) + self.step = 0 + + def on_train_batch_end(self, batch, logs=None): + logs = logs or {} + + if self.step % self.update_freq == 0: + with self.file_writer_train.as_default(): + for k, val in logs.items(): + tf.summary.scalar("batch/" + k, data=val, step=self.step) + self.step += 1 + + def on_epoch_end(self, epoch, logs=None): + with self.file_writer_val.as_default(): + for k, val in logs.items(): + tf.summary.scalar("epoch/" + k, data=val, step=epoch+1) diff --git a/nasbench_asr/training/tf/datasets/audio_feature.py b/nasbench_asr/training/tf/datasets/audio_feature.py new file mode 100644 index 0000000..726b63f --- /dev/null +++ b/nasbench_asr/training/tf/datasets/audio_feature.py @@ -0,0 +1,475 @@ +# coding=utf-8 + +""" +Collection of function to extract features from from audio signal in TensorFlow +Author: SpeechX team, Cambrigde +""" + +import numpy as np +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + + +def normalize_audio_full_scale(audio_pcm): + """ + Normalizes audio data to the full scale + """ + max_abs_val = tf.math.reduce_max(tf.math.abs(audio_pcm)) + + # Estimating the scale factor + gain = tf.constant(1.0) / (max_abs_val + tf.constant(1e-5)) + + return audio_pcm * gain + + +def normalize_spectrogram(magnitude_spectrogram): + """ + Performs mean and variance normalization of a spectrogram + """ + + magnitude_spectrogram -= tf.math.reduce_mean(magnitude_spectrogram) + magnitude_spectrogram /= tf.math.reduce_std(magnitude_spectrogram) + + return magnitude_spectrogram + + +def convert_to_dB(magnitude_spectrogram, normalize=False, ref_level_db=20.0, min_level_db=-100.0): + """ + Converts spectrograms to dB + + Args: + magnitude_spectrogram : A tensor containing magnitude spectrogram + normalize : A bool indicating if spectrogram normalization is needed + ref_level_db: Ref db level required 
[default 20]. Pass None if this not to be used. + min_level_db: Minimum db level required [default -100]. Pass None if this not to be used. + """ + + # Removing small values before taking log + magnitude_spectrogram = tf.clip_by_value(magnitude_spectrogram, + clip_value_min=1e-30, + clip_value_max=tf.math.reduce_max(magnitude_spectrogram)) + + magnitude_spectrogram = 10.0 * tf.math.log(magnitude_spectrogram) / tf.math.log(10.0) + + if normalize: + magnitude_spectrogram = normalize_spectrogram(magnitude_spectrogram) + + # this is from amp-to-db function in librosa + # Source: https://librosa.org/librosa/master/_modules/librosa/core/spectrum.html#amplitude_to_db + # Source: https://github.com/mindslab-ai/voicefilter/blob/master/utils/audio.py + if ref_level_db is not None: + magnitude_spectrogram -= ref_level_db + if min_level_db is not None: + magnitude_spectrogram /= -min_level_db + magnitude_spectrogram = tf.clip_by_value(magnitude_spectrogram, clip_value_min=-1.0, clip_value_max=0.0) + 1.0 + + return magnitude_spectrogram + + +def get_stft(audio_pcm, + normalize=False, + fft_length=512, + window_len=None, + step_len=None, + center=True, + verbose=0): + + """ + Performs short time fourier transformation of a time domain audio signal + + Parameters + ---------- + audio_pcm : A 1D tensor (float32) holding the input audio + fft_length : (int in samples) length of the windowed signal after padding, + which will be used to extract FFT + window_len : (int > 0 and <= fft_length) length of each audio frame in samples [default: fft_length] + step_len : (int > 0) length of hop / stride in samples [default: window_length // 4] + center : (Bool) Type of padding to be used to match librosa + verbose : Verbosity level, 0 = no ouput, > 0 debug prints + + This function returns a complex-valued matrix stfts + """ + + # Checking the input type and perform casting if necessary + if audio_pcm.dtype != 'float32': + audio_pcm = tf.cast(audio_pcm, tf.float32) + + # Performing audio normalization + if normalize: + audio_pcm = normalize_audio_full_scale(audio_pcm) + + if window_len is None: + window_len = fft_length + + if step_len is None: + step_len = int(window_len // 4) + + # Perform padding of the original signal + if center: + pad_amount = int(window_len // 2) # As used by Librosa + + if verbose > 0: + print(f'[INFO] (audio_feature.get_stft)] pad_amount = {pad_amount}') + + audio_pcm = tf.pad(audio_pcm, [[pad_amount, pad_amount]], 'REFLECT') + + # Extracting frames from sudio signal + frames = tf.signal.frame(audio_pcm, window_len, step_len, pad_end=False) + + if verbose > 0: + print(f'[INFO] (audio_feature.get_stft)] frames.shape = {frames.shape}') + + # Generating hanning window + fft_window = tf.signal.hann_window(window_len, periodic=True) + + # Computing the spectrogram, the output is an array of complex number + stfts = tf.signal.rfft(frames * fft_window, fft_length=[fft_length]) + + return stfts + + +def get_magnitude_spectrogram(audio_pcm, + sample_rate, + window_len_in_sec=0.025, + step_len_in_sec=0.010, + exponent=2.0, + nfft=None, + normalize_full_scale=False, + compute_phase=False, + verbose=0): + + """ + Computes the magnitude spectrogram of an audio signal + + Parameters: + ----------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + exponent: Int, 1 for energy and 2 for power [default 2] + normalize_full_scale: If full scale power normalization to be + 
performed, default is False + compute_phase Compute and return phase for all frames [default False] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + """ + + # Full-scale normalization of audio + if normalize_full_scale: + audio_pcm = normalize_audio_full_scale(audio_pcm) + + # Estimating parameters for STFT + frame_length_in_sample = int(window_len_in_sec * sample_rate) + frame_step_in_sample = int(step_len_in_sec * sample_rate) + + if nfft is None: + nfft = frame_length_in_sample + + stfts = tf.signal.stft( + signals=audio_pcm, + frame_length=frame_length_in_sample, + frame_step=frame_step_in_sample, + fft_length=nfft, + window_fn=tf.signal.hann_window, + pad_end=False) + + magnitude_spectrograms = tf.abs(stfts) + + if exponent != 1.0: + magnitude_spectrograms = tf.math.pow(magnitude_spectrograms, exponent) + + phases = None + if compute_phase: + phases = tf.math.angle(stfts) + + return magnitude_spectrograms, phases + + +def get_magnitude_spectrogram_dB(audio_pcm, + sample_rate, + normalize_spec=False, + ref_level_db=20.0, + min_level_db=-100.0, + **kwargs): + """ + Computes the magnitude spectrogram in dB + + Parameters: + ----------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + exponent: Int, 1 for energy and 2 for power [default 2] + normalize_full_scale: If full scale power normalization to be + performed, default is False + normalize_spec Is magnitude spectogram to be normalized [default False] + ref_level_db: Ref db level required [default 20] + min_level_db: Minimum db level required [default -100] + compute_phase Compute and return phase for all frames [default False] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + """ + + + magnitude_spectrogram, phase = get_magnitude_spectrogram(audio_pcm, sample_rate, **kwargs) + magnitude_spectrogram_dB = convert_to_dB(magnitude_spectrogram, normalize=normalize_spec, ref_level_db=ref_level_db, min_level_db=min_level_db) + + return magnitude_spectrogram_dB, phase + + +def wav2spec(audio_pcm, + sample_rate, + ref_level_db=20.0, + min_level_db=-100.0, + **kwargs): + """ + Computes the magnitude spectrogram in dB and phase + + Parameters: + ----------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + ref_level_db: Ref db level required [default 20] + min_level_db: Minimum db level required [default -100] + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + exponent: Int, 1 for energy and 2 for power [default 2] + normalize_full_scale: If full scale power normalization to be + performed, default is False + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + """ + + magnitude_spectrogram, phase = get_magnitude_spectrogram(audio_pcm, sample_rate, **dict(kwargs, compute_phase=True)) + magnitude_spectrogram_dB = convert_to_dB(magnitude_spectrogram, normalize=False, ref_level_db=ref_level_db, min_level_db=min_level_db) + return magnitude_spectrogram_dB, phase + + +def spec2wav(magnitude_spectrogram, + phase, + sample_rate, + nfft=None, + ref_level_db=20, + min_level_db=-100, + window_len_in_sec=0.025, + step_len_in_sec=0.010, + exponent=2.0): + """ + Computes the audio pcm from magnitude spectrogram and phase + + Parameters: + ----------- + magnitude_spectrogram: Magnitude spectogram of audio pcm + phase: Phase obtained from stfts of audio pcm + sample_rate: Samling frequency of the 
recorded audio + ref_level_db: Ref db level required [defaul 20] + min_level_db: Minimum db level required [default -100] + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + exponent: Int, 1 for energy and 2 for power [default 2] + """ + + magnitude_spectrogram = tf.clip_by_value(magnitude_spectrogram, clip_value_min=0.0, clip_value_max=1.0) + magnitude_spectrogram = (magnitude_spectrogram - 1.0) * - min_level_db + magnitude_spectrogram += ref_level_db + magnitude_spectrogram = tf.math.pow(tf.constant(10.0), magnitude_spectrogram / (exponent*10)) + magnitude_spectrogram = tf.cast(magnitude_spectrogram, dtype=tf.complex64) + + phase = tf.complex(tf.zeros(tf.shape(phase)), phase) + phase = tf.math.exp(phase) + stfts = magnitude_spectrogram * phase + + # Estimating parameters for STFT + frame_length_in_sample = int(window_len_in_sec * sample_rate) + frame_step_in_sample = int(step_len_in_sec * sample_rate) + if nfft is None: + nfft = frame_length_in_sample + + W = tf.signal.inverse_stft( + stfts=stfts, + frame_length=frame_length_in_sample, + frame_step=frame_step_in_sample, + fft_length=nfft, + window_fn=tf.signal.inverse_stft_window_fn( + frame_step=frame_step_in_sample, + forward_window_fn=tf.signal.hann_window + ) + ) + return W + +def get_mel_filterbank(audio_pcm, + sample_rate, + window_len_in_sec=0.025, + step_len_in_sec=0.010, + nfft=None, + normalize_full_scale=False, + num_feature_filters=40, + lower_edge_hertz=0.0, + upper_edge_hertz=8000.0, + exponent=2.0, + mel_weight_mat=None, + verbose=0): + """ + Computes Mel-filterbank features from an audio signal using TF operations. + + Parameters + ---------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + num_feature_filters: int, e.g., 40 + lower_edge_hertz: Lower frequency to consider in the mel scale + upper_edge_hertz: Upper frequency to consider in the mel scale + exponent: Int, 1 for energy and 2 for power [default 2] + mel_weight_mat: Accepts a mel_weight_matrix [defult None and generates using the HTK algo] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + + Returns: + Tensor (audio_len // int(step_len * sample_rate), num_feature_filters), float32 + """ + + magnitude_spectrograms, _ = \ + get_magnitude_spectrogram( + audio_pcm=audio_pcm, + sample_rate=sample_rate, + window_len_in_sec=window_len_in_sec, + step_len_in_sec=step_len_in_sec, + nfft=nfft, + compute_phase=False, + exponent=exponent, + normalize_full_scale=normalize_full_scale, + verbose=verbose) + + if verbose: + print('[INFO] (audio_feature.get_mel_filterbank) magnitude_spectrograms.shape', magnitude_spectrograms.shape) + + num_spectrogram_bins = tf.shape(magnitude_spectrograms)[-1] + + if mel_weight_mat is None: + if verbose: + print('[INFO] (audio_feature.get_mel_filterbank) mel_weight_mat not provided and generating using TF2.') + mel_weight_mat = tf.signal.linear_to_mel_weight_matrix(num_feature_filters, + num_spectrogram_bins, + sample_rate, + lower_edge_hertz, + upper_edge_hertz) + + if verbose: + print('[INFO] (audio_feature.get_mel_filterbank) linear_to_mel_weight_matrix.shape', mel_weight_mat.shape) + + + mel_spectrograms = tf.tensordot(magnitude_spectrograms, mel_weight_mat, 1) + mel_spectrograms.set_shape( + magnitude_spectrograms.shape[:-1].concatenate(mel_weight_mat.shape[-1:]) + ) + + if verbose: + print('[INFO] (audio_feature.get_mel_filterbank) mel_spectrograms.shape', 
mel_spectrograms.shape) + + return mel_spectrograms + + +def get_log_mel_filterbank(audio_pcm, + sample_rate, + **kwargs): + """ + Computes Log-Mel-filterbank features from an audio signal using TF operations. + + Parameters + ---------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + num_feature_filters: Int, e.g., 40 + lower_edge_hertz: Lower frequency to consider in the mel scale + upper_edge_hertz: Upper frequency to consider in the mel scale + exponent: Int, 1 for energy and 2 for exponent [default 2] + mel_weight_mat: Accepts a mel_weight_matrix [defult None and generates using the HTK algo] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + + Returns: + numpy.ndarray, (audio_len // int(step_len * sample_rate), num_feature_filters), float32 + """ + return tf.math.log(get_mel_filterbank(audio_pcm, sample_rate, **kwargs) + 1e-10) + +def get_mfcc(audio_pcm, + sample_rate, + **kwargs): + """ + Extracts mfcc feature from audio_pcm measurements + + Parameters + ---------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + num_feature_filters: int, e.g., 40 + lower_edge_hertz: Lower frequency to consider in the mel scale + upper_edge_hertz: Upper frequency to consider in the mel scale + exponent: Int, 1 for energy and 2 for power [default 2] + mel_weight_mat: Accepts a mel_weight_matrix [defult None and generates using the HTK algo] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + + """ + + log_mel_spectrograms = get_log_mel_filterbank(audio_pcm, sample_rate, **kwargs) + + mfccs = tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms) + + return mfccs + + +def get_power_mel_filterbank(audio_pcm, + sample_rate, + power_coeff=1.0 / 15.0, + **kwargs): + """ + Computes power Mel-filterbank features (PNCC) from an audio signal. + + References: + https://github.sec.samsung.net/STAR/speech/blob/master/speech/trainer/returnn_based_end_to_end_trainer/ver0p2/GeneratingDataset.py + + Parameters + ---------- + audio_pcm: A 1D tensor (float32) holding the input audio + sample_rate: Samling frequency of the recorded audio + window_len_in_sec: float, in seconds + step_len_in_sec: float, in seconds + num_feature_filters: Int, e.g., 40 + lower_edge_hertz: Lower frequency to consider in the mel scale + upper_edge_hertz: Upper frequency to consider in the mel scale + exponent: Int, 1 for energy and 2 for power [default 2] + mel_weight_mat: Accepts a mel_weight_matrix [defult None and generates using the HTK algo] + verbose: Verbosity level, 0 = no ouput, > 0 debug prints + + Returns: + Tensor, (audio_len // int(step_len * sample_rate), num_feature_filters), float32 + """ + assert power_coeff > 0.0 and power_coeff < 1.0, 'Invalid power_coeff!!!' 
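+
+ # power-law compression of the mel energies: e.g. with the default
+ # power_coeff of 1.0/15.0 a filterbank value of 1000 is squashed to
+ # roughly 1000 ** (1.0/15.0) ~= 1.58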
+ + mel_filterbank = get_mel_filterbank(audio_pcm, sample_rate, **kwargs) + + feature_vector = mel_filterbank ** power_coeff + + return feature_vector + +def get_feature(audio, sample_rate, feature_type='pmel', **kwargs): + """ + A wrapper function for audio features + """ + if feature_type == 'spec': + return get_magnitude_spectrogram(audio, sample_rate, **kwargs) + elif feature_type == 'spec_dB': + return get_magnitude_spectrogram_dB(audio, sample_rate, **kwargs) + elif feature_type == 'pmel': + return get_power_mel_filterbank(audio, sample_rate, **kwargs) + elif feature_type == 'lmel': + return get_log_mel_filterbank(audio, sample_rate, **kwargs) + elif feature_type == 'mel': + return get_mel_filterbank(audio, sample_rate, **kwargs) + elif feature_type == 'mfcc': + return get_mfcc(audio, sample_rate, **kwargs) + else: + raise NotImplementedError(f'Unsupported audio feature type {feature_type}') diff --git a/nasbench_asr/training/tf/datasets/audio_featurizer.py b/nasbench_asr/training/tf/datasets/audio_featurizer.py new file mode 100644 index 0000000..71fca19 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/audio_featurizer.py @@ -0,0 +1,69 @@ +# pylint: skip-file +import os +import sys + +from .audio_feature import get_feature + + +class AudioFeaturizer: + def __init__(self, sample_rate, feature_type, normalize_full_scale, + window_len_in_sec, step_len_in_sec, num_feature_filters, + mel_weight_mat, verbose): + self.sample_rate = sample_rate + self.feature_type = feature_type + self.normalize_full_scale = normalize_full_scale + self.window_len_in_sec = window_len_in_sec + self.step_len_in_sec = step_len_in_sec + self.num_feature_filters = num_feature_filters + self.mel_weight_mat = mel_weight_mat + self.verbose = verbose + + def namespace(self): + """ + This function returns a list of the hyper-parameters related to + transformation of audio into features, which impacts the creation of + the caches of the datasets that we support (grep also for "def + get_path_ds_cache" to find where this is used toward that end). 
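+
+ With the settings used by get_dataloaders in nasbench_asr/training/tf/__init__.py
+ this evaluates to:
+ "sample_rate_16000/feature_type_lmel/normalize_full_scale_False/window_len_in_sec_0.025/step_len_in_sec_0.01/num_feature_filters_80/mel_weight_mat_None/"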
+ """ + output = "" + + params = [ + "sample_rate", + "feature_type", + "normalize_full_scale", + "window_len_in_sec", + "step_len_in_sec", + "num_feature_filters", + "mel_weight_mat", + ] + + for param in params: + output += param + "_" + \ + str(getattr(self, param)).replace( + "/", "_").replace("*", "_") + "/" + + return output + + def __call__(self, audio): + if self.feature_type in ["spec", "spec_dB"]: + audio_feature, _ = get_feature( + audio, + sample_rate=self.sample_rate, + feature_type=self.feature_type, + normalize_full_scale=self.normalize_full_scale, + window_len_in_sec=self.window_len_in_sec, + step_len_in_sec=self.step_len_in_sec, + verbose=self.verbose) + else: + audio_feature = get_feature( + audio, + sample_rate=self.sample_rate, + feature_type=self.feature_type, + normalize_full_scale=self.normalize_full_scale, + window_len_in_sec=self.window_len_in_sec, + step_len_in_sec=self.step_len_in_sec, + num_feature_filters=self.num_feature_filters, + mel_weight_mat=self.mel_weight_mat, + verbose=self.verbose) + + return audio_feature diff --git a/nasbench_asr/training/tf/datasets/audio_sentence_timit.py b/nasbench_asr/training/tf/datasets/audio_sentence_timit.py new file mode 100644 index 0000000..f047824 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/audio_sentence_timit.py @@ -0,0 +1,116 @@ +# pylint: skip-file +import os +import glob + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def get_paths_to_wav(folder): + """ + Args: + folder: string with location of TIMIT dataset, should end in "TRAIN" or + "TEST" + Returns: + paths_to_wav: sorted list "*.RIFF.WAV" files in folder or in any of its + children + """ + folder = os.path.expanduser(folder) + assert os.path.exists(folder) + pattern = folder + "/**/*.RIFF.WAV" + paths_to_wav = glob.glob(pattern, recursive=True) + # sort() ensures different calls to this function always return the same + # list + paths_to_wav.sort() + + return paths_to_wav + + +def get_audio_and_sentence_fn(encoder_class): + @tf.function + def get_audio_and_sentence(path_to_wav): + """ + Args: + path_to_wav: tf.string with path to a wav file + Returns: + (audio, sentence): + - audio is tf.float32 of shape [None] + - sentence is a tf.string of shape [], without '\n', without '.', upper + case + """ + # Original TIMIT audio files have Header NIST, but tf.audio.decode_wav + # expects RIFF. Use sox to fix this issue, namely run: + # apt-get install -y parallel + # find TIMIT -name '*.WAV' | parallel -P20 sox {} '{.}.RIFF.WAV' + audio, _ = tf.audio.decode_wav(tf.io.read_file(path_to_wav), + desired_channels=1, + desired_samples=-1) + audio = tf.squeeze(audio) + + if encoder_class == "phoneme": + path_to_phn = tf.strings.join( + [tf.strings.split(path_to_wav, sep=".RIFF.WAV")[0], ".PHN"]) + + def get_last_column(sentence): + return tf.strings.split(sentence, sep=" ")[-1] + + phonemes = tf.io.read_file(path_to_phn) + phonemes = tf.strings.strip(phonemes) + phonemes = tf.strings.split(phonemes, sep="\n") + phonemes = tf.map_fn(get_last_column, phonemes) + + return audio, phonemes + else: + path_to_txt = tf.strings.join( + [tf.strings.split(path_to_wav, sep=".RIFF.WAV")[0], ".TXT"]) + sentence = tf.strings.reduce_join( + tf.strings.split(tf.io.read_file(path_to_txt), sep=" ")[2:], + separator=" ", + ) + # Remove '\n' from end of sentence + sentence = tf.strings.strip(sentence) + # Replace '.' 
with '' + sentence = tf.strings.regex_replace(sentence, "\.+", "") + # Change sentence to upper + sentence = tf.strings.upper(sentence) + + return audio, sentence + + return get_audio_and_sentence + + +def get_timit_audio_sentence_ds(folder, + ds_name, + remove_sa=True, + encoder_class='phoneme', + num_parallel_calls=-1, + deterministic=True, + max_audio_size=0): + """ + Returns: + - ds: yields (audio, sentence) tuples + where + - audio has shape [None] and is of type tf.float32 + - sentence has shape [] and is of type tf.string + from the TIMIT dataset indicated in the params. + """ + paths_to_wav = get_paths_to_wav(folder=os.path.join(folder, ds_name)) + if remove_sa: + paths_to_wav = [ + path_to_wav for path_to_wav in paths_to_wav + if path_to_wav.split("/")[-1][:2] != "SA" + ] + ds = tf.data.Dataset.from_tensor_slices(paths_to_wav) + get_audio_and_sentence = get_audio_and_sentence_fn(encoder_class=encoder_class) + ds = ds.map( + get_audio_and_sentence, + num_parallel_calls=num_parallel_calls, + deterministic=deterministic, + ) + + if max_audio_size > 0: + def filter_fn(audio, sentence): + return tf.size(audio) < tf.saturate_cast(max_audio_size, tf.int32) + + ds = ds.filter(filter_fn) + + return ds diff --git a/nasbench_asr/training/tf/datasets/cache_shard_shuffle_batch.py b/nasbench_asr/training/tf/datasets/cache_shard_shuffle_batch.py new file mode 100644 index 0000000..564758d --- /dev/null +++ b/nasbench_asr/training/tf/datasets/cache_shard_shuffle_batch.py @@ -0,0 +1,94 @@ +# pylint: skip-file +import os +import sys + +import numpy as np +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def cache_shard_shuffle_batch(*, + ds, + ds_cache_in_disk=False, + path_ds_cache="", + ds_cache_in_memory=False, + shard_num_shards=None, + shard_index=None, + shuffle=False, + shuffle_buffer_size=1, + num_feature_filters=None, + pad_strategy="padded_batch", + batch_size=2, + padded_shapes=([None, None], [], [None], []), + drop_remainder=True, + bucket_boundaries=[sys.maxsize], + bucket_batch_sizes=[2, 1], + device=None, + prefetch_buffer_size=1): + """ + Args: + - ds: yields (feature, feature_size, encoded, encoded_size) tuples where + - feature has shape [time, channels], and is of type tf.float32 + - feature_size has shape [], and is of type tf.int32, and represents + the number of time frames + - encoded has shape [None], and is of type tf.int32, and represents a + text encoded version of the original sentence; it contains values in + the range [1, encoder.vocab_size) + - encoded_size has shape [], and is of type tf.int32, and represents + the number of tokens in each text encoded version of the original + sentence + Returns: + - ds: yields (features, features_size, encodeds, encodeds_size) tuples + where + - features has shape [batch_size, time, channels] + - features_size has shape [batch_size] + - encodeds has shape [batch_size, None] + - encodeds_size has shape [batch_size] + """ + + if ds_cache_in_disk: + # cache to disk + ds = ds.cache(path_ds_cache) + + if shard_num_shards is not None and shard_index is not None: + ds = ds.shard(num_shards=shard_num_shards, index=shard_index) + + if ds_cache_in_memory: + # cache to memory + ds = ds.cache() + + if shuffle: + ds = ds.shuffle(shuffle_buffer_size) + + if pad_strategy == "padded_batch": + ds = ds.padded_batch( + batch_size=batch_size, + padded_shapes=padded_shapes, + drop_remainder=drop_remainder, + ) + elif pad_strategy == "bucket_by_sequence_length": + + def element_length_func(feature, feature_size, encoded, encoded_size): + 
return feature_size + + transformation_func = tf.data.experimental.bucket_by_sequence_length( + element_length_func=element_length_func, + bucket_boundaries=bucket_boundaries, + bucket_batch_sizes=bucket_batch_sizes, + padded_shapes=padded_shapes, + padding_values=None, + pad_to_bucket_boundary=False, + no_padding=False, + drop_remainder=drop_remainder, + ) + ds = ds.apply(transformation_func) + else: + assert False + + if device is not None: + ds = ds.apply( + tf.data.experimental.prefetch_to_device( + device, buffer_size=prefetch_buffer_size)) + else: + ds = ds.prefetch(prefetch_buffer_size) + + return ds diff --git a/nasbench_asr/training/tf/datasets/phoneme_encoder.py b/nasbench_asr/training/tf/datasets/phoneme_encoder.py new file mode 100644 index 0000000..7e21ae0 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/phoneme_encoder.py @@ -0,0 +1,40 @@ +# pylint: skip-file + +from .timit_foldings import * + + +class PhonemeEncoder: + def __init__(self): + source_phonemes, source_encodes, dest_phonemes, dest_encodes, _ = get_phoneme_mapping(source_enc_name='p61', dest_enc_name='p48') + + self.source_phonemes = source_phonemes + self.dest_phonemes = dest_phonemes + + self.phoneme_to_index = {} + for phoneme, encoding in zip(source_phonemes, dest_encodes): + self.phoneme_to_index[phoneme] = encoding + + # As in tfds.features.text.SubwordTextEncoder, we assume Size of the + # vocabulary. Decode produces ints [1, vocab_size). Hence the addition + # of one to the len of the phonemes + self.vocab_size = len(dest_phonemes) + 1 + + def encode(self, sentence): + indices = [] + for phoneme in sentence: + phoneme = phoneme.decode("utf-8") + assert (phoneme in self.phoneme_to_index), f"{phoneme} not present in the encoder list" + if phoneme == "q": + # there is no "q" in p48 nor in p39 and # + # self.phoneme_to_index["q"] is 0 which is not a valid index + continue + indices.append(self.phoneme_to_index[phoneme]) + + return indices + + def decode(self, indices): + if len(indices) != 0: + assert max(indices) < self.vocab_size + sentence = "" + + return sentence diff --git a/nasbench_asr/training/tf/datasets/preprocess.py b/nasbench_asr/training/tf/datasets/preprocess.py new file mode 100644 index 0000000..f987ea1 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/preprocess.py @@ -0,0 +1,63 @@ +# pylint: skip-file +import numpy as np +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def preprocess( + *, + ds, + encoder, + featurizer, + norm_stats=None, + epsilon=0.001, + num_parallel_calls=tf.data.experimental.AUTOTUNE, + deterministic=True, + max_feature_size=0 +): + """ + Args: + - ds: yields (audio, sentence) tuples where + - audio has shape [None], and is of type tf.float32 + - sentence has shape [], and is of type tf.string + Returns: + - ds: yields (feature, feature_size, encoded, encoded_size) tuples where + - feature has shape [time, channels], and is of type tf.float32 + - feature_size has shape [], and is of type tf.int32, and represents + the number of time frames + - encoded has shape [None], and is of type tf.int32, and represents a + text encoded version of the original sentence; it contains values in + the range [1, encoder.vocab_size) + - encoded_size has shape [], and is of type tf.int32, and represents + the number of tokens in each text encoded version of the original + sentence + - featurizer + - encoder + """ + if norm_stats: + norm_stats = np.load(norm_stats) + mean = norm_stats['moving_mean'] + variance = norm_stats['moving_variance'] + norm_stats = True + + 
def preprocess_map_func(audio, sentence): + feature = featurizer(audio) + feature_size = tf.shape(feature)[0] + encoded = encoder.get_encoded_from_sentence(sentence) + encoded_size = tf.shape(encoded)[0] + + if norm_stats: + feature = (feature - mean) / tf.math.sqrt(variance + epsilon) + + return feature, feature_size, encoded, encoded_size + + ds = ds.map(preprocess_map_func, + num_parallel_calls=num_parallel_calls, + deterministic=deterministic) + + if max_feature_size > 0: + def filter_fn(feature, feature_size, encoded, encoded_size): + return feature_size < tf.saturate_cast(max_feature_size, tf.int32) + + ds = ds.filter(filter_fn) + + return ds diff --git a/nasbench_asr/training/tf/datasets/text_encoder.py b/nasbench_asr/training/tf/datasets/text_encoder.py new file mode 100644 index 0000000..98262f2 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/text_encoder.py @@ -0,0 +1,66 @@ +# pylint: skip-file + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + +from .phoneme_encoder import PhonemeEncoder + + +def get_utf8_valid_sentence(sentence): + return sentence.numpy() + + +def get_corpus_generator(ds): + + for _, sentence in ds: + yield get_utf8_valid_sentence(sentence) + + +def get_encoded_from_sentence_fn(encoder): + def get_encoded_from_sentence_helper(sentence): + # the following [] are essential! + encoded = [encoder.encode(get_utf8_valid_sentence(sentence))] + + return encoded + + def get_encoded_from_sentence(sentence): + # the following [] are essential! + encoded = tf.py_function(get_encoded_from_sentence_helper, [sentence], + tf.int32) + + return encoded + + return get_encoded_from_sentence + + +def get_decoded_from_encoded_fn(encoder): + def get_decoded_from_encoded_helper(encoded): + # the following [] are essential! + decoded = [ + get_utf8_valid_sentence( + tf.constant(encoder.decode(encoded.numpy().tolist()))) + ] + + return decoded + + def get_decoded_from_encoded(encoded): + # the following [] are essential! 
+ decoded = tf.py_function(get_decoded_from_encoded_helper, [encoded], + tf.string) + + return decoded + + return get_decoded_from_encoded + + +class TextEncoder: + def __init__( + self, + encoder_class, + ): + if encoder_class != 'phoneme': + raise ValueError('Unsupported encoder type {!r}'.format(encoder_class)) + + self.encoder_class = encoder_class + self.encoder = PhonemeEncoder() + self.get_encoded_from_sentence = get_encoded_from_sentence_fn(self.encoder) + self.get_decoded_from_encoded = get_decoded_from_encoded_fn(self.encoder) diff --git a/nasbench_asr/training/tf/datasets/timit_foldings.py b/nasbench_asr/training/tf/datasets/timit_foldings.py new file mode 100644 index 0000000..70fa479 --- /dev/null +++ b/nasbench_asr/training/tf/datasets/timit_foldings.py @@ -0,0 +1,95 @@ + +import pathlib + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def old_to_new_indices(table, old_indices): + """ + This function uses lookup table to convert given indexices to corresponding ones + + The input to this function has shape [batch_size, max_num_indices] and is given by + old_indices = = [old_indices[0], ..., old_indices[batch_size - 1]] + and the output to this function has the same shape as is given by + tf.map_fn(fn=fn, elems=old_indices) = [fn(old_indices[0]), ..., fn(old_indices[batch_size - 1])] + As an example, let's say that table was constructed from + - keys: [1, 2, 3] + - vals: [3, 5, 2] + - missing elements replaced with value 0 + and that the input is + old_indices = [[1, 2, 44, 2, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1]] + Then for the first row of the input, we'll have + x = old_indices[0] = [1, 2, 44, 2, 1, 0, 0] + y = [3, 5, 0, 5, 3, 0, 0] + tf.boolean_mask(y, y > 0) = [3, 5, 5, 3] + tf.zeros(tf.reduce_sum(tf.cast(y <= 0, dtype=tf.int32)) = [0, 0, 0] + z = [3, 5, 5, 3, 0, 0, 0] + whereas for the second row of the input we'll have + x = old_indices[1] = [1, 1, 1, 1, 1, 1, 1] + y = [3, 3, 3, 3, 3, 3, 3] + tf.boolean_mask(y, y > 0) = [3, 3, 3, 3, 3, 3, 3] + tf.zeros(tf.reduce_sum(tf.cast(y <= 0, dtype=tf.int32)) = [] + z = [3, 3, 3, 3, 3, 3, 3] + Therefore the output will be + [[3, 5, 5, 3, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3]] + """ + def fn(x): + y = table.lookup(x) + z = tf.concat( + [ + tf.boolean_mask(y, y > 0), + tf.zeros(tf.reduce_sum(tf.cast(y <= 0, dtype=tf.int32)), + dtype=y.dtype), + ], + axis=0, + ) + + return z + + return tf.map_fn(fn=fn, elems=old_indices) + +def get_phoneme_mapping(source_enc_name='p61', dest_enc_name='p48'): + #creates a mapping [defined in timit_foldings.txt] for a given source to destination folding + #also returns a lookup table for mapping source to dest indices + file_path = pathlib.Path(__file__).parents[2].joinpath('timit_folding.txt') + assert file_path.exists(), f'Timit mapping file not found: {file_path}' + + with file_path.open('r') as f: + mapping = f.readlines() + mapping = [m.strip().split('\t') for m in mapping] + + #remove phonemes with no mapping + no_map_phonemes = [m[0] for m in mapping if len(m)<2] + mapping = [m for m in mapping if m[0] not in no_map_phonemes] + + foldings = ['p61', 'p48', 'p39'] #don't change the order. It is same as the order of mapping. 
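The worked example in the `old_to_new_indices` docstring above can be checked directly; the snippet below is illustrative only, reuses the toy keys/values from that docstring, and assumes `old_to_new_indices` from this file is in scope:

```python
import tensorflow as tf

# Toy reproduction of the docstring example for old_to_new_indices.
keys = tf.constant([1, 2, 3], dtype=tf.int32)
vals = tf.constant([3, 5, 2], dtype=tf.int32)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, vals), default_value=0)

old_indices = tf.constant([[1, 2, 44, 2, 1, 0, 0],
                           [1, 1, 1, 1, 1, 1, 1]], dtype=tf.int32)
print(old_to_new_indices(table, old_indices).numpy())
# [[3 5 5 3 0 0 0]
#  [3 3 3 3 3 3 3]]
```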
+ assert source_enc_name in foldings and dest_enc_name in foldings, 'Encoding name is incorrect' + ph61 = sorted(list(set([m[0] for m in mapping] + no_map_phonemes))) + ph48 = sorted(list(set([m[1] for m in mapping]))) + ph39 = sorted(list(set([m[2] for m in mapping]))) + phonemes = [ph61, ph48, ph39] + + source_idx = foldings.index(source_enc_name) + dest_idx = foldings.index(dest_enc_name) + source_phonemes = phonemes[source_idx] + dest_phonemes = phonemes[dest_idx] + + source_encodes = [] + dest_encodes = [] + for idx, ph in enumerate(source_phonemes): + source_encodes.append(idx + 1) + ph_idx_in_map = [i for i, _map in enumerate(mapping) if _map[source_idx] == ph] + + if len(ph_idx_in_map) == 0: #phoneme not resent in mapping [special case for q] + dest_encodes.append(0) + else: + dest_ph = mapping[ph_idx_in_map[0]][dest_idx] + dest_encodes.append(dest_phonemes.index(dest_ph) + 1) + + # create the hash table (must have 0 to 0 mapping) + source_encodes = tf.constant(source_encodes, dtype=tf.int32) + dest_encodes = tf.constant(dest_encodes, dtype=tf.int32) + table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(source_encodes, dest_encodes), 0) + + return source_phonemes, source_encodes, dest_phonemes, dest_encodes, table diff --git a/nasbench_asr/training/tf/main.py b/nasbench_asr/training/tf/main.py new file mode 100644 index 0000000..7a635d8 --- /dev/null +++ b/nasbench_asr/training/tf/main.py @@ -0,0 +1,172 @@ +# pylint: skip-file +# import argparse +import logging +import os +import sys +import random +import numpy as np +from attrdict import AttrDict +import pickle + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL +os.environ['TF_DETERMINISTIC_OPS'] = '1' + + +from nasbench_asr.quiet_tensorflow import tensorflow as tf +#tf.config.experimental_run_functions_eagerly(True) + +logging.getLogger("tensorflow").setLevel(logging.FATAL) +from .core.trainer.callbacks.model_checkpoint_max_to_keep import ModelCheckpointMaxToKeep +from .core.trainer.callbacks import lrscheduler +from .core.trainer.callbacks.nni_report import NNIReport +from .core.trainer.callbacks.tensorboard_image import Tensorboard +from .core.utils import (expand_path_if_nni, read_yaml, + update_config_with_nni_tuned_params) + +from .datasets.datasets import get_data +from .core.learner.asr_ctc import ASRCTCLearner + +from ...model.tf.model import * +from ...model.tf.ops import * + +try: + import nni +except ImportError: + print("Microsoft NNI is not installed, and will therefore not be used") + +def main(config): + random.seed(config.seed) + np.random.seed(config.seed) + tf.random.set_seed(config.seed) + + # hvd_util.init_gpus() + physical_devices = tf.config.list_physical_devices('GPU') + try: + tf.config.experimental.set_memory_growth(physical_devices[0], True) + except: + # Invalid device or cannot modify virtual devices once initialized. 
+ pass + + verbose = config.verbose #if hvd.rank() == 0 else 0 + + data_train, data_validate, data_test = get_data(config.dataset) + + model_configs = gen_model_configs(shuffle=False) + model_config = model_configs[config.counter] + #model_config = [[3,1], [4,1,1], [0,0,0,1]] + model = ASRModel(model_config=config, sc_config=model_config, num_classes=data_train.encoder.encoder.vocab_size, use_rnn=config.use_rnn, + **{ + "input_shape": [None, config.dataset.featurizer.num_feature_filters], + "epsilon": 0.001, + "stats": data_train.stats, + "mask_time_fraction": 0.0, + "mask_channels_fraction": 0.0, + "seed": config.seed + } + ) + + learning_rate = config.train.lr #* hvd.size() + optimizer = getattr(tf.keras.optimizers, config.train.optimizer)(learning_rate) + + # Adding learning rate scheduler callback + lr_scheduler = getattr(lrscheduler, config.lrscheduler.name).from_dict(config) + callbacks = [lr_scheduler] + + ckpt_dir = expand_path_if_nni(config.callbacks.model_checkpoint.log_dir) + # if hvd.rank() == 0: + if config.callbacks.model_checkpoint.max_to_keep >= 1: + # use cb ModelCheckpoint before cb NNIReport so that the weights saved + # by the former get copied by the latter at the end of each epoch + callbacks.append( + ModelCheckpointMaxToKeep( + folder=ckpt_dir, + monitor=config.callbacks.monitor, + mode=config.callbacks.mode, + max_to_keep=config.callbacks.model_checkpoint.max_to_keep, + )) + callbacks.append( + NNIReport(nni_chosen_metric=config.callbacks.nni.nni_chosen_metric, + report_final_result=False)) + tensorboard_dir = expand_path_if_nni(config.callbacks.tb.log_dir) + callbacks.append(Tensorboard(log_dir=tensorboard_dir, update_freq=10)) + # profile_batch=0 required on mlp to show train on TensorBoard + # https://stackoverflow.com/questions/58300465/tensorboard-2-0-0-not-updating-train-scalars + # callbacks.append( + # tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, + # update_freq=10, + # profile_batch=0)) + + asr_ctc_learner = ASRCTCLearner( + model=model, + optimizer=optimizer, + get_decoded_from_encoded=data_train.encoder.get_decoded_from_encoded, + fp16_allreduce=config.train.fp16_allreduce, + greedy_decoder=config.test.greedy_decoder, + beam_width=config.test.beam_width, + encoder_class=config.dataset.encoder.encoder_class) + + asr_ctc_learner.compile() + asr_ctc_learner.model._model.summary() + + history_fit = asr_ctc_learner.fit( + data_train.ds, + epochs=config.train.epochs, + steps_per_epoch=data_train.steps, + callbacks=callbacks, + validation_data=data_validate.ds, + validation_steps=data_validate.steps, + verbose=verbose, + ) + + # if hvd.rank() == 0: + tf.print(history_fit.history) + with open("./history_fit.pickle", "wb") as fp: + pickle.dump(history_fit.history, fp) + + if config.callbacks.model_checkpoint.max_to_keep >= 1: + # load best model seen so far for the test eval + lers = history_fit.history[config.callbacks.monitor] + best_model_epoch = lers.index(min(lers)) + checkpoint_path = ckpt_dir + '/cp-' + str(best_model_epoch).zfill(4) + '.ckpt' + asr_ctc_learner.model.load_weights(checkpoint_path) + + tmp = asr_ctc_learner.evaluate(data_test.ds, + verbose=verbose, + steps=data_test.steps, + return_dict=True) + + # if hvd.rank() == 0: + history_evaluate = {} + for key, val in tmp.items(): + history_evaluate['val_'+key] = val + if key == config.callbacks.nni.nni_chosen_metric: + history_evaluate["default"] = val + + if "nni" in sys.modules: + nni.report_intermediate_result(history_evaluate) + nni.report_final_result(history_evaluate['default']) + 
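The per-epoch history dumped to `history_fit.pickle` above can be inspected offline; a minimal, illustrative snippet (not part of the patch) for reading it back:

```python
import pickle

# Illustrative only: load the fit history written by main() above.
with open('history_fit.pickle', 'rb') as f:
    history = pickle.load(f)

print(sorted(history.keys()))              # per-epoch metrics recorded by Keras
print(len(next(iter(history.values()))))   # number of epochs that were run
```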
tf.print(history_evaluate) + with open("./history_evaluate.pickle", "wb") as fp: + pickle.dump(history_evaluate, fp) + + os.removedirs(ckpt_dir) + +if __name__ == "__main__": + # parser = argparse.ArgumentParser() + # parser.add_argument( + # '-f', '--config_file', + # type=str, + # required=True, + # help='Input configuration file name.' + # ) + # cmd_data, _ = parser.parse_known_args() + # config = read_yaml(cmd_data.config_file) + curr_dir = os.path.dirname(os.path.realpath(__file__)) + config = read_yaml(os.path.join(curr_dir, 'config.yaml')) + config.dataset = read_yaml(os.path.join(curr_dir, config.dataset.config)) + + if "nni" in sys.modules: + tuned_params = nni.get_next_parameter() + config = update_config_with_nni_tuned_params(config, tuned_params) + + main(config) diff --git a/nasbench_asr/training/tf/metrics/__init__.py b/nasbench_asr/training/tf/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nasbench_asr/training/tf/metrics/ctc.py b/nasbench_asr/training/tf/metrics/ctc.py new file mode 100644 index 0000000..c4ff8d6 --- /dev/null +++ b/nasbench_asr/training/tf/metrics/ctc.py @@ -0,0 +1,105 @@ +# pylint: skip-file +import os +import sys + +from nasbench_asr.quiet_tensorflow import tensorflow as tf + +from .roll import roll + + +def get_normalized_ctc_loss_without_reduce(*, logits_transposed, logits_size, + encodeds, encodeds_size): + ctc_loss_without_reduce = tf.nn.ctc_loss( + labels=encodeds, + logits=logits_transposed, + label_length=encodeds_size, + logit_length=logits_size, + logits_time_major=True, + blank_index=0, + ) + + # tf.nn.ctc_loss returns a tensor of shape [batch_size] with negative log + # probabilities, but each probability may have been computed with an + # argument with different length (which turn into sums, each with different + # number of summands in the case of independence). For this reason we + # divide each negative log probability by the logits_size + # replacing "logits_size" with "logits_size + 1" to avoid division by zero + ctc_loss_without_reduce /= tf.cast(logits_size + 1, + ctc_loss_without_reduce.dtype) + + ctc_loss_without_reduce = tf.debugging.check_numerics( + tensor=ctc_loss_without_reduce, + message="nan or inf in ctc_loss", + name="ctc_loss_without_reduce", + ) + + return ctc_loss_without_reduce + + +def get_normalized_ctc_loss(*, logits_transposed, logits_size, encodeds, + encodeds_size): + ctc_loss_without_reduce = get_normalized_ctc_loss_without_reduce( + logits_transposed=logits_transposed, + logits_size=logits_size, + encodeds=encodeds, + encodeds_size=encodeds_size, + ) + + # Finally, average across the samples of the batch + ctc_loss = tf.reduce_mean(ctc_loss_without_reduce) + + return ctc_loss + + +def get_logits_encodeds( + *, + logits_transposed, + logits_size, + greedy_decoder, + beam_width, +): + # Unlike tf.nn.ctc_loss, the functions + # tf.nn.ctc_greedy_decoder and tf.nn.ctc_beam_search_decoder don't have + # a parameter to signal which is the blank_index. 
In fact, in the + # tf.nn.ctc_greedy_decoder the documentation mentions that blank index + # (num_classes - 1) + + # To account for the fact that the text encoder + # https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text/TextEncoder + # encodes to the range [1, + # vocab_size), and we took advantage of that by setting blank_index=0 + # in the get_normalized_ctc_loss, we now roll the logits_transposed + # with shift=-1, axis=-1, so that the blank_index is moved from the + # 0-th position to the last + logits_transposed = roll(logits_transposed) + + if greedy_decoder: + logits_encodeds, _ = tf.nn.ctc_greedy_decoder( + inputs=logits_transposed, + sequence_length=logits_size, + merge_repeated=True, + ) + else: + logits_encodeds, _ = tf.nn.ctc_beam_search_decoder( + inputs=logits_transposed, + sequence_length=logits_size, + beam_width=beam_width, + top_paths=1, + ) + logits_encodeds = logits_encodeds[0] + + # Given that the text encoder + # https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text/TextEncoder + # encodes to and decodes from the range [1, vocab_size), we shift the + # output of the ctc decoder which is in the range [0, vocab_size - 1) + # to the correct range [1, vocab_size) by adding one each index + logits_encodeds = tf.sparse.SparseTensor( + indices=logits_encodeds.indices, + values=logits_encodeds.values + 1, + dense_shape=logits_encodeds.dense_shape, + ) + + logits_encodeds = tf.sparse.to_dense(logits_encodeds) + logits_encodeds = tf.cast(logits_encodeds, tf.int32) + + return logits_encodeds diff --git a/nasbench_asr/training/tf/metrics/ler.py b/nasbench_asr/training/tf/metrics/ler.py new file mode 100644 index 0000000..59f6cdb --- /dev/null +++ b/nasbench_asr/training/tf/metrics/ler.py @@ -0,0 +1,34 @@ +# pylint: skip-file +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def count_non_zero_indices(x): + return tf.cast(tf.math.reduce_sum(tf.cast(x != 0, tf.int64), axis=1), + dtype=tf.float32) + + +def get_ler_numerator_denominator(*, encodeds, logits_encodeds): + """ + encodeds: vector of shape [batch_size, time_1] where each row corresponds + to a sample, whose zero elements come from padded_batch and whose non-zero + elements are in the range [1, vocab_size) which correspond to valid indices + from the tfds.features.text.SubwordTextEncoder + + logits_encodeds: vector of shape [batch_size, time_2] where each row + corresponds to a sample, whose zero elements come from padded_batch and + whose non-zero elements are in the range [1, vocab_size) which correspond + to valid indices from the tfds.features.text.SubwordTextEncoder + """ + ler_numerator = tf.edit_distance( + hypothesis=tf.sparse.from_dense(logits_encodeds), + truth=tf.sparse.from_dense(encodeds), + normalize=False, + name="ler_numerator", + ) + + encodeds_len = count_non_zero_indices(encodeds) + # logits_encodeds_len = count_non_zero_indices(logits_encodeds) + + ler_denominator = tf.identity(encodeds_len, name="ler_denominator") + + return ler_numerator, ler_denominator diff --git a/nasbench_asr/training/tf/metrics/ratio.py b/nasbench_asr/training/tf/metrics/ratio.py new file mode 100644 index 0000000..38256dd --- /dev/null +++ b/nasbench_asr/training/tf/metrics/ratio.py @@ -0,0 +1,41 @@ +# pylint: skip-file +# coding=utf-8 +import os +import sys +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +class Ratio(tf.keras.metrics.Metric): + def __init__(self, name="ratio", **kwargs): + super().__init__(name=name, **kwargs) + self.numerator = 
self.add_weight(name="numerator", + initializer="zeros", + dtype=tf.float32) + self.denominator = self.add_weight(name="denominator", + initializer="zeros", + dtype=tf.float32) + + def update_state(self, numerator_denominator): + numerator = numerator_denominator[0] + denominator = numerator_denominator[1] + self.numerator.assign_add( + tf.reduce_sum(tf.cast(numerator, dtype=tf.float32))) + self.denominator.assign_add( + tf.reduce_sum(tf.cast(denominator, dtype=tf.float32))) + + def result(self): + # use of / rather than tf.math.divide_no_nan is intentional + + return self.numerator / self.denominator + + # def reset_states(self): + # # reset across hvd works at the same time + # # hvd_util.apply_hvd_allreduce_np2np(0) + + # super().reset_states() + + # def sync_across_hvd_workers(self): + # # for var in self.variables: + # # tmp = hvd.size() * hvd_util.apply_hvd_allreduce_np2np(var.numpy()) + # # var.assign(tmp) + # pass diff --git a/nasbench_asr/training/tf/metrics/roll.py b/nasbench_asr/training/tf/metrics/roll.py new file mode 100644 index 0000000..3ee28c8 --- /dev/null +++ b/nasbench_asr/training/tf/metrics/roll.py @@ -0,0 +1,6 @@ +# pylint: skip-file +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def roll(logits_transposed): + return tf.roll(logits_transposed, shift=-1, axis=-1) diff --git a/nasbench_asr/training/tf/metrics/wer.py b/nasbench_asr/training/tf/metrics/wer.py new file mode 100644 index 0000000..173350d --- /dev/null +++ b/nasbench_asr/training/tf/metrics/wer.py @@ -0,0 +1,41 @@ +# pylint: skip-file +from nasbench_asr.quiet_tensorflow import tensorflow as tf + + +def separate_sentences_into_words_fn(x): + return tf.strings.split(x, sep=" ") + + +def count_words_in_separated_sentences(x): + return tf.cast( + tf.sparse.reduce_sum( + tf.ragged.map_flat_values( + lambda y: tf.cast(y != "", dtype=tf.int64), x).to_sparse(), + axis=1, + ), + dtype=tf.float32, + ) + + +def get_wer_numerator_denominator(*, sentences, logits_sentences): + """ + sentences: vector of shape [batch_size] of type tf.string + + logits_sentences: vector of shape [batch_size] of type tf.string + """ + words = separate_sentences_into_words_fn(sentences) + logits_words = separate_sentences_into_words_fn(logits_sentences) + + wer_numerator = tf.edit_distance( + hypothesis=logits_words.to_sparse(), + truth=words.to_sparse(), + normalize=False, + name="wer_numerator", + ) + + words_len = count_words_in_separated_sentences(words) + # logits_words_len = count_words_in_separated_sentences(logits_words) + + wer_denominator = tf.identity(words_len, name="wer_denominator") + + return wer_numerator, wer_denominator diff --git a/nasbench_asr/training/tf/trainer.py b/nasbench_asr/training/tf/trainer.py new file mode 100644 index 0000000..42e1389 --- /dev/null +++ b/nasbench_asr/training/tf/trainer.py @@ -0,0 +1,521 @@ +import pickle +import pathlib +import functools +import collections.abc as cabc + +import numpy as np +from nasbench_asr.quiet_tensorflow import tensorflow as tf + +from .callbacks.tensorboard import Tensorboard +from .callbacks.lrscheduler import ExponentialDecay +from .callbacks.reset_states import ResetStatesCallback +from .metrics.ratio import Ratio +from .metrics.ler import get_ler_numerator_denominator +from .metrics.wer import get_wer_numerator_denominator +from .metrics.ctc import get_logits_encodeds, get_normalized_ctc_loss_without_reduce +from .datasets.timit_foldings import old_to_new_indices, get_phoneme_mapping + + +def get_logits_size(features, features_size, logits): + 
time_reduction = tf.cast(tf.shape(features)[1], + dtype=tf.float32) / tf.cast(tf.shape(logits)[1], + dtype=tf.float32) + logits_size = tf.cast(tf.cast(features_size, dtype=tf.float32) / + time_reduction, + dtype=features_size.dtype) + + return logits_size + + +def get_loss(): + def loss(logits, logits_size, encodeds, encodeds_size, metrics=None): + logits_transposed = tf.transpose(logits, [1, 0, 2]) + ctc_loss_without_reduce = get_normalized_ctc_loss_without_reduce( + logits_transposed=logits_transposed, + logits_size=logits_size, + encodeds=encodeds, + encodeds_size=encodeds_size, + ) + + ctc_loss_without_reduce_numerator = ctc_loss_without_reduce + ctc_loss_without_reduce_denominator = tf.ones_like(ctc_loss_without_reduce) + + if metrics is not None: + metrics.update({ + "ctc_loss": ( + ctc_loss_without_reduce_numerator, + ctc_loss_without_reduce_denominator, + ) + }) + + return tf.reduce_mean(ctc_loss_without_reduce) + + return loss + + +class Trainer(): + class RememberBestCallback(tf.keras.callbacks.Callback): + def __init__(self, trainer, checkpoint_name): + self.trainer = trainer + self.best_so_far = None + self.checkpoint_name = checkpoint_name + + def on_epoch_end(self, epoch, logs=None): + if 'val_ler' in logs: + value = logs['val_ler'] + if self.best_so_far is None or value <= self.best_so_far: + self.best_so_far = value + self.trainer.remember_best() + self.trainer.save(self.checkpoint_name) + else: + print('Missing validation LER') + + class SaveLatestCallback(tf.keras.callbacks.Callback): + def __init__(self, trainer, checkpoint_name): + self.trainer = trainer + self.checkpoint_name = checkpoint_name + + def on_epoch_end(self, epoch, logs=None): + self.trainer.save(self.checkpoint_name) + + def __init__(self, dataloaders, loss, gpus=None, save_dir=None, verbose=True): + encoder, data_train, data_validate, data_test = dataloaders + + self.encoder = encoder + self.data_train = data_train + self.data_validate = data_validate + self.data_test = data_test + + self.save_dir = save_dir + if self.save_dir: + pathlib.Path(self.save_dir).mkdir(exist_ok=True) + self.verbose = verbose + + self.model = None + self.optimizer = None + self.trackers = {} + self.loss = loss + + self.get_decoded_from_encoded = self.data_train.encoder.get_decoded_from_encoded + self.fp16_allreduce = True + self.greedy_decoder = False + self.beam_width = 12 + + self._best_weights = None + + if gpus is not None and (not isinstance(gpus, cabc.Sequence) or bool(gpus)): + if not isinstance(gpus, cabc.Sequence): + gpus = [gpus] + else: + gpus = [] + + if len(gpus) != 1: + raise ValueError('TF implementation only supports running on a single GPU') + + # + # API + # + + def train(self, model, epochs=40, lr=0.0001, reset=False, model_name=None): + metrics = { + "ctc_loss": Ratio, + "wer": Ratio, + "ler": Ratio + } + self.init_trackers(metrics) + + self.model = model + self.optimizer = tf.keras.optimizers.Adam(lr) + + # Adding learning rate scheduler callback + self.lr_scheduler = ExponentialDecay(0.9, start_epoch=5, min_lr=0.0, verbose=False) + + callbacks = [self.lr_scheduler] + + this_model_save_dir = None + if self.save_dir is not None: + this_model_save_dir = self.save_dir + if model_name is not None: + this_model_save_dir = this_model_save_dir / model_name + + latest_ckpt = this_model_save_dir / 'latest.ckpt' + best_ckpt = this_model_save_dir / 'best.ckpt' + tensorboard_dir = this_model_save_dir / 'tensorboard' + callbacks.append(Trainer.SaveLatestCallback(self, latest_ckpt)) + 
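To make the time-axis bookkeeping concrete, here is a small, illustrative check of `get_logits_size` defined above (the shapes are made up; the helper simply rescales per-example feature lengths by the model's temporal downsampling factor):

```python
import tensorflow as tf

# Toy check: a model that reduces the time axis 4x (400 -> 100 frames).
features = tf.zeros([2, 400, 80])                        # [batch, time, channels]
features_size = tf.constant([400, 320], dtype=tf.int32)  # valid frames per example
logits = tf.zeros([2, 100, 49])                          # [batch, time/4, classes]
print(get_logits_size(features, features_size, logits).numpy())  # [100  80]
```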
callbacks.append(Trainer.RememberBestCallback(self, best_ckpt)) + callbacks.append(Tensorboard(log_dir=tensorboard_dir, update_freq=10)) + + # TODO: is that enough to restore state? or maybe we should always start from the beginnin + if best_ckpt.exists(): + if reset: + best_ckpt.unlink() + else: + self.load(best_ckpt) + self.remember_best() + if latest_ckpt.exists(): + if reset: + latest_ckpt.unlink() + else: + self.load(latest_ckpt) + + self.compile() + if self.verbose: + self.model._model.summary() + + history_fit = self.fit( + self.data_train.ds, + epochs=epochs, + steps_per_epoch=self.data_train.steps, + callbacks=callbacks, + validation_data=self.data_validate.ds, + validation_steps=self.data_validate.steps, + verbose=self.verbose + ) + if self.verbose: + tf.print(history_fit.history) + + self.recall_best() + + test_res = self.evaluate(self.data_test.ds, + verbose=self.verbose, + steps=self.data_test.steps, + return_dict=True) + + history_evaluate = {} + for key, val in test_res.items(): + history_evaluate['val_'+key] = val + if self.verbose: + tf.print(history_evaluate) + + if self.save_dir: + with open(this_model_save_dir / 'scores.pickle', "wb") as fp: + pickle.dump(history_fit.history, fp) + with open(this_model_save_dir / 'test_scores.pickle', "wb") as fp: + pickle.dump(history_evaluate, fp) + + self.model = None + self.optimizer = None + self.trackers = {} + + def step(self, input, training=True): + if training: + return self._train_step(input) + else: + return self._test_step(input) + + def save(self, checkpoint): + self.model.save_weights(filepath=checkpoint, overwrite=True, save_format='tf') + + def load(self, checkpoint): + self.model.load_weights(checkpoint) + + def remember_best(self): + self._best_weights = self.model.get_weights() + + def recall_best(self): + self.model.set_weights(self._best_weights) + + # + # Implementation + # + + def init_trackers(self, metrics, track_modes=["train", "test"]): + """ + Initializing metric-trackers for e.g., training and validation datasets + Args: + metrics : A dictionary containing the tracker callable classes as values + track_mode : A list containing the dataset partitions, e.g., ["train", "val", "test"] + """ + self.trackers.clear() + + if metrics is None: + return + + assert isinstance(metrics, dict), "Metrics should be a dictionary with callable values" + assert isinstance(track_modes, list), "Expecting a list of tracking modes, e.g., [train, test]" + + for mode in track_modes: + self.trackers[mode] = {} + + for metric_name, metric_fn in metrics.items(): + if not callable(metric_fn): continue + + name = mode + "_" + metric_name + self.trackers[mode][metric_name] = metric_fn(name=name) + self.trackers[mode][metric_name].reset_states() + + def get_tracker_results(self): + """ + Prints all tracker result on screen + """ + results = {} + + # Looping over all trackers + for key in self.trackers.keys(): + metrics = self.trackers.get(key) + + for met in metrics.keys(): + val = metrics.get(met).result() + + # Only printing tracker results if not NaN + if not tf.math.is_nan(val): + results[f"{key}-{met}"] = val.numpy() + + return results + + def _train_step(self, data): + """ + A wrapper to update the train trackers + """ + logs = self.train_step(data) + + for key in logs.keys() & self.trackers["train"].keys(): + self.trackers["train"][key].update_state(logs[key]) + logs[key] = self.trackers["train"][key].result() + + # the returned logs will appear in the arguments to the methods of + # the classes inheriting from 
tf.keras.callbacks.Callback + + return logs + + def _test_step(self, data): + """ + A wrapper to update the test trackers + """ + logs = self.test_step(data) + + for key in logs.keys() & self.trackers["test"].keys(): + self.trackers["test"][key].update_state(logs[key]) + logs[key] = self.trackers["test"][key].result() + + # the returned logs will appear in the arguments to the methods of + # the classes inheriting from tf.keras.callbacks.Callback + + return logs + + def compile(self): + """ + Overrides the tf.keras.Model train_step/test_step functions and + compiles the model + """ + self.model.train_step = functools.partial(self.step, training=True) + self.model.test_step = functools.partial(self.step, training=False) + + self.model.compile(optimizer=self.optimizer) + + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + max_queue_size=10, + workers=1, + use_multiprocessing=False): + + """ + Wrapper for + https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit + method of self.model, which ensures that HorovodCallback is prepended + to callbacks before calling built-in fit. Parameters and default + values are the same as those in built-in fit. + """ + + # Adding ResetStateCallback as default + if callbacks is None: + callbacks = [] + callbacks.insert(0, ResetStatesCallback(self.trackers)) + + return self.model.fit(x=x, + y=y, + batch_size=batch_size, + epochs=epochs, + verbose=self.verbose, + callbacks=callbacks, + validation_split=validation_split, + validation_data=validation_data, + shuffle=shuffle, + class_weight=class_weight, + sample_weight=sample_weight, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + validation_batch_size=validation_batch_size, + validation_freq=validation_freq, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + + + def evaluate(self, + x=None, + y=None, + batch_size=None, + verbose=1, + sample_weight=None, + steps=None, + callbacks=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + return_dict=False): + + """ + Wrapper for + https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate + method of self.model, which ensures that HorovodCallback is prepended + to callbacks before calling built-in evaluate. Parameters and default + values are the same as those in built-in evaluate. + """ + + # Adding ResetStateCallback as default + if callbacks is None: + callbacks = [] + callbacks.insert(0, ResetStatesCallback(self.trackers)) + + return self.model.evaluate(x=x, + y=y, + batch_size=batch_size, + verbose=self.verbose, + sample_weight=sample_weight, + steps=steps, + callbacks=callbacks, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + return_dict=return_dict) + + def train_step(self, data): + """ + The argument data represents what is yielded from tf.data.Dataset. 
It + is expected to be a tuple with four elements, namely: + + features, features_size, encodeds, encodeds_size = data + + where + + - features has shape [batch_size, time, channels], and is of type + tf.float32 + - features_size has shape [batch_size], and is of type tf.int32, and + represents the number of time frames per example in the batch + - encodeds has shape [batch_size, None], and is of type tf.int32, and + represents a text encoded version of the original sentence per + example in the batch; it contains values in the range [1, + encoder.vocab_size) + - encodeds_size has shape [batch_size], and is of type tf.int32, and + represents the number of tokens in each text encoded version of the + original sentence + + In all of the above, batch_size and time are determined at run time, whereas + channels is defined at compile time + """ + + metrics = {} + + features, features_size, encodeds, encodeds_size = data + with tf.GradientTape() as tape: + logits = self.model(features, training=True) + logits_size = get_logits_size(features, features_size, logits) + ctc_loss = self.loss(logits, logits_size, encodeds, encodeds_size, metrics=metrics) + total_loss = tf.math.add_n([ctc_loss] + self.model.losses) + + # Horovod: (optional) compression algorithm. + # compression = hvd.Compression.fp16 if self.fp16_allreduce else hvd.Compression.none + # # Horovod: add Horovod Distributed GradientTape. + # tape = hvd.DistributedGradientTape(tape, compression=compression) + grads = tape.gradient(total_loss, self.model.trainable_variables) + grads, _ = tf.clip_by_global_norm(grads, 5.0) + grads = [ + tf.debugging.check_numerics(tensor=grad, + message="nan or inf in grad", + name="grad") for grad in grads + ] + self.model.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) + unused_trainable_variables = [ + tf.debugging.check_numerics(tensor=var, + message="nan or inf in train_var", + name="train_var") + for var in self.model.trainable_variables + ] + + + return metrics + + def test_step(self, data): + """ + The argument data represents what is yielded from tf.data.Dataset.
It + is expected to be a tuple with four elements, namely: + + features, features_size, encodeds, encodeds_size = data + + where + + - features has shape [batch_size, time, channels], and is of type + tf.float32 + - features_size has shape [batch_size], and is of type tf.int32, and + represents the number of time frames per example in the batch + - encodeds has shape [batch_size, None], and is of type tf.int32, and + represents a text encoded version of the original sentence per + example in the batch; it contains values in the range [1, + encoder.vocab_size) + - encodeds_size has shape [batch_size], and is of type tf.int32, and + represents the number of tokens in each text encoded version of the + original sentence + + In all above batch_size and time and determined at run time, whereas + channels is defined at compile time + """ + + metrics = {} + + features, features_size, encodeds, encodeds_size = data + logits = self.model(features, training=False) + logits_size = get_logits_size(features, features_size, logits) + _ = self.loss(logits, logits_size, encodeds, encodeds_size, metrics=metrics) + logits_transposed = tf.transpose(logits, [1, 0, 2]) + logits_encodeds = get_logits_encodeds( + logits_transposed=logits_transposed, + logits_size=logits_size, + greedy_decoder=self.greedy_decoder, + beam_width=self.beam_width, + ) + # tfds.features.text.SubwordTextEncoder can only run on CPU + with tf.device("/CPU:0"): + sentences = tf.map_fn(self.encoder.get_decoded_from_encoded, + encodeds, + dtype=tf.string) + logits_sentences = tf.map_fn(self.encoder.get_decoded_from_encoded, + logits_encodeds, + dtype=tf.string) + + _, _, _, _, hash_table = get_phoneme_mapping(source_enc_name='p48', dest_enc_name='p39') + encodeds = old_to_new_indices(hash_table, encodeds) + logits_encodeds = old_to_new_indices(hash_table, logits_encodeds) + + wer_numerator, wer_denominator = get_wer_numerator_denominator( + sentences=sentences, logits_sentences=logits_sentences) + + ler_numerator, ler_denominator = get_ler_numerator_denominator( + encodeds=encodeds, logits_encodeds=logits_encodeds) + + metrics.update({ + "wer": (wer_numerator, wer_denominator), + "ler": (ler_numerator, ler_denominator), + }) + + return metrics + + +def get_trainer(*args, **kwargs): + return Trainer(*args, **kwargs) diff --git a/nasbench_asr/training/timit_folding.txt b/nasbench_asr/training/timit_folding.txt new file mode 100644 index 0000000..946f3be --- /dev/null +++ b/nasbench_asr/training/timit_folding.txt @@ -0,0 +1,61 @@ +aa aa aa +ae ae ae +ah ah ah +ao ao aa +aw aw aw +ax ax ah +ax-h ax ah +axr er er +ay ay ay +b b b +bcl vcl sil +ch ch ch +d d d +dcl vcl sil +dh dh dh +dx dx dx +eh eh eh +el el l +em m m +en en n +eng ng ng +epi epi sil +er er er +ey ey ey +f f f +g g g +gcl vcl sil +h# sil sil +hh hh hh +hv hh hh +ih ih ih +ix ix ih +iy iy iy +jh jh jh +k k k +kcl cl sil +l l l +m m m +n n n +ng ng ng +nx n n +ow ow ow +oy oy oy +p p p +pau sil sil +pcl cl sil +q +r r r +s s s +sh sh sh +t t t +tcl cl sil +th th th +uh uh uh +uw uw uw +ux uw uw +v v v +w w w +y y y +z z z +zh zh sh \ No newline at end of file diff --git a/nasbench_asr/training/timit_train_stats.npz b/nasbench_asr/training/timit_train_stats.npz new file mode 100644 index 0000000000000000000000000000000000000000..7a1d0f28e68aab282dc617bc9d2012bc42bb3362 GIT binary patch literal 1178 zcmd5+YfO`882x&I00Eh`W@<|9MN^bUQBgVu`W>@uQAFxB0>N6e3=j)z85!XcjL4=l z=;A$iC5U(Yp(I|gc^=1RVq{#DW$K(V66??vR?&^B(^g}=B>vg2oj330dERsWo#&X7 
zWpbtP5?UczEsZ9LAOv=#pcGaW6zAGMx2&>R?Gx;UB@*F-7heH-fD*vwxr&E7SkeA*qx2RGiO%9UQacGyGtaqV>TW(ReB*-CRASE)_=GYNmSQEqw* zr7B&t&=f^*y;hP@QqJe*?`S=r>);mal>U?yrbC8CH`)KX9Cp7B(Giur*VMU{q zH|5KCdasPHtK=MVTgs|#88<3q9N#JDTNi>@Ru;lJy`g;PyD%>A59O$A@y>lyKhqMp8tv&oPtq zy#6O0f6=AoirG3IFzWf_YCS*tTgP*bY1#ai0W_(ts#gXW_Syi&R!6SYzTEbIH9^nj z4AFZiL#$jMg`?}D;I8`wJ6v<1_%RKocUL1Kwh9}*UW2~KYTSvg#-8{Zj8N``GpP<& z^bJ`2?*W8dXh2%gejIJ8g|DI(k_TI%oKb^BTQzPEDTh1OiIFY&xRjEI8=gEId76Vm z?p!SOWMRpWEGWlXa52e(?(j_1m!xCz?!`!LUx+31=A&<43S#bk2GcTN9VNni)`Vt* z2~R6Ng7;n$+GskqWG7;lYbv5jCgQ8LF_^t~92y#9U~HO%mAm3GIPqhgxNJgMQXFc^ zCScEf1HPG}NBMRw8cvOdu||dO6(6E+-cYDpB9JMB;fN&|qC$#3g@A=ckHuYu55-wy z`o&UPuUM7f78`SVM9UA|qB&V2Rm$F%11|?Xa8|w;`gc;{UGdu= self.num_classes: + return encodeds + if num_classes not in PhonemeEncoder.all_encodings: + raise ValueError(num_classes) + + new_class_idx = PhonemeEncoder.all_encodings.index(num_classes) + for old_idx, new_idx in self.idx_mappings[self.class_idx][new_class_idx].items(): + encodeds[encodeds == old_idx] = new_idx + + return encodeds + + + def encode(self, phonemes): + phonemes_folded = self._fold(phonemes) + enc = [self.encodeds[self.class_idx].index(p)+1 if p else 0 for p in phonemes_folded] #start from 1, 0 is used for blank + return enc + + def decode(self, encodeds): + return [self.encodeds[self.class_idx][idx-1] if idx else '' for idx in encodeds] + diff --git a/nasbench_asr/training/torch/timit.py b/nasbench_asr/training/torch/timit.py new file mode 100644 index 0000000..38f8feb --- /dev/null +++ b/nasbench_asr/training/torch/timit.py @@ -0,0 +1,131 @@ +import pathlib + +import torch +import torchaudio +import torchvision +import numpy as np + +from .encoder import PhonemeEncoder + + +torchaudio.set_audio_backend('sox_io') + + +class TimitDataset(torch.utils.data.Dataset): + def __init__(self, root_folder, encoder, subset='TRAIN', ignore_sa=True, transforms=None): + root = pathlib.Path(root_folder).expanduser() + wavs = list(root.rglob(f'{subset}/**/*.RIFF.WAV')) + wavs = sorted(wavs) + if ignore_sa: + wavs = [w for w in wavs if not w.name.startswith('SA')] + phonemes = [(f.parent / f.stem).with_suffix('.PHN') for f in wavs] + + self.audio = [] + self.audio_len = [] + for wav in wavs: + tensor, sample_rate = torchaudio.load(str(wav)) + self.audio.append(tensor) + self.audio_len.append(tensor.shape[1] / sample_rate) + + def load_sentence(f): + lines = f.read_text().strip().split('\n') + last = [l.rsplit(' ', maxsplit=1)[-1] for l in lines] + last = encoder.encode(last) + return last + + self.root_folder = root_folder + self.encoder = encoder + self.sentences = [load_sentence(f) for f in phonemes] + self.transforms = transforms + + assert len(self.audio) == len(self.sentences) + + def __len__(self): + return len(self.audio) + + def __getitem__(self, idx): + audio = self.audio[idx] + sentence = self.sentences[idx] + if self.transforms is not None: + audio = self.transforms(audio) + return audio, sentence + + def get_indices_shorter_than(self, time_limit): + return [i for i, audio_len in enumerate(self.audio_len) if time_limit is None or audio_len < time_limit] + + +def pad_sequence_bft(sequences, extra=0, padding_value=0.0): + batch_size = len(sequences) + leading_dims = sequences[0].shape[:-1] + max_t = max([s.shape[-1]+extra for s in sequences]) + + out_dims = (batch_size, ) + leading_dims + (max_t, ) + + out_tensor = sequences[0].new_full(out_dims, padding_value) + for i, tensor in enumerate(sequences): + length = tensor.shape[-1] + out_tensor[i, ..., :length] = tensor + + return out_tensor 
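A quick, illustrative sanity check of `pad_sequence_bft` defined above (not part of the patch):

```python
import torch

# Three [1, T] tensors of different lengths are padded along the last axis.
seqs = [torch.ones(1, 5), torch.ones(1, 3), torch.ones(1, 8)]
batch = pad_sequence_bft(seqs, extra=0, padding_value=0.0)
print(batch.shape)  # torch.Size([3, 1, 8])
```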
+ + +def pad_sentences(sequences, padding_value=0.0): + max_t = max([len(s) for s in sequences]) + sequences = [s+[0]*(max_t-len(s)) for s in sequences] + return sequences + + +def get_normalize_fn(part_name, eps=0.001): + stats = np.load(pathlib.Path(__file__).parents[1].joinpath(f'timit_train_stats.npz')) + mean = stats['moving_mean'][None,:,None] + variance = stats['moving_variance'][None,:,None] + def normalize(audio): + return (audio - mean) / (variance + eps) + return normalize + + +def get_dataloaders(timit_root, batch_size): + encoder = PhonemeEncoder(48) + + def get_transforms(part_name): + transforms = torchvision.transforms.Compose([ + torchaudio.transforms.MelSpectrogram(sample_rate=16000, win_length=400, hop_length=160, n_mels=80), + torch.log, + get_normalize_fn(part_name) + ]) + + return transforms + + def collate_fn(batch): + audio = [b[0][0] for b in batch] + audio_lengths = [a.shape[-1] for a in audio] + sentence = [b[1] for b in batch] + sentence_lengths = [len(s) for s in sentence] + audio = pad_sequence_bft(audio, extra=0, padding_value=0.0) + sentence = pad_sentences(sentence, padding_value=0.0) + return (audio, torch.tensor(audio_lengths)), (torch.tensor(sentence, dtype=torch.int32), torch.tensor(sentence_lengths)) + + + subsets = ['TRAIN', 'VAL', 'TEST'] + datasets = [TimitDataset(timit_root, encoder, subset=s, ignore_sa=True, transforms=get_transforms(s)) for s in subsets] + train_sampler = torch.utils.data.SubsetRandomSampler(datasets[0].get_indices_shorter_than(None)) + loaders = [torch.utils.data.DataLoader(d, batch_size=batch_size, sampler=train_sampler if not i else None, pin_memory=True, collate_fn=collate_fn) for i, d in enumerate(datasets)] + return (encoder, *loaders) + + +def set_time_limit(loader, time_limit): + db = loader.dataset + sampler = loader.sampler + sampler.indices = db.get_indices_shorter_than(time_limit) + + +if __name__ == '__main__': + import pprint + train_load, val_load, test_load = get_dataloaders('TIMIT', 3) + for (audio, lengths), sentence in train_load: + print(audio.shape, audio) + print() + print(lengths.shape, lengths) + print() + print(sentence) + break diff --git a/nasbench_asr/training/torch/trainer.py b/nasbench_asr/training/torch/trainer.py new file mode 100644 index 0000000..e7ffa5a --- /dev/null +++ b/nasbench_asr/training/torch/trainer.py @@ -0,0 +1,268 @@ +import pathlib +import collections.abc as cabc + +import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_edit_distance as ed +from ctcdecode import CTCBeamDecoder + +from ...model.torch.ops import PadConvRelu +from ...model import print_model_summary +from .timit import set_time_limit + + +class AvgMeter(): + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.n = 0 + + def update(self, a): + if not self.n: + self.avg = a + self.n = 1 + else: + self.avg = self.avg * (self.n / (self.n + 1)) + (a / (self.n + 1)) + self.n += 1 + + def get(self): + return self.avg + + +def get_loss(): + def loss(output, output_len, targets, targets_len): + output_trans = output.permute(1, 0, 2) # needed by the CTCLoss + loss = F.ctc_loss(output_trans, targets, output_len, targets_len, reduction='none', zero_infinity=True) + loss /= output_len + loss = loss.mean() + return loss + + return loss + + +class Trainer(): + def __init__(self, dataloaders, loss, gpus=None, save_dir=None, verbose=True): + #we don't use config param, it is just to have consistent api + encoder, train_load, valid_load, test_load = dataloaders + + 
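For reference, a minimal usage sketch of `get_dataloaders` defined above; the returned tuple starts with the phoneme encoder, followed by the train/val/test loaders (the dataset path and batch size below are placeholders):

```python
# Illustrative only: unpack the encoder and the three loaders.
encoder, train_load, val_load, test_load = get_dataloaders('TIMIT', batch_size=8)

(audio, audio_len), (targets, targets_len) = next(iter(train_load))
print(audio.shape, audio_len.shape)      # [batch, channels, time], [batch]
print(targets.shape, targets_len.shape)  # [batch, max_label_len], [batch]
```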
self.encoder = encoder + self.train_load = train_load + self.valid_load = valid_load + self.test_load = test_load + self.gpus = gpus + self.save_dir = pathlib.Path(save_dir) if save_dir else save_dir + if self.save_dir: + self.save_dir.mkdir(exist_ok=True) + self.verbose = verbose + + self.loss = loss + if self.gpus is not None and (not isinstance(self.gpus, cabc.Sequence) or bool(self.gpus)): + if not isinstance(gpus, cabc.Sequence): + self.gpus = [self.gpus] + + self.device = torch.device(f'cuda:{gpus[0]}') + else: + self.device = torch.device('cpu') + + self.decoder = CTCBeamDecoder(encoder.get_vocab(inc_blank=True), beam_width=12, log_probs_input=True) + + self.model = None + self._model = None + self.lr = None + self.optimizer = None + self.scheduler = None + self._best_weights = None + + def train(self, model, epochs=40, lr=0.0001, reset=False, model_name=None): + self.model = model + self._model = model + self.lr = lr + self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-07) + self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, 0.9) + + if self.verbose: + print_model_summary(self.model) + + self.model.to(device=self.device) + if len(self.gpus) > 1: + self.model = torch.nn.DataParallel(self.model, self.gpus) + + epoch = 0 + best_val = None + val_scores = [] + + this_model_save_dir = None + if self.save_dir: + this_model_save_dir = pathlib.Path(self.save_dir) + if model_name is not None: + this_model_save_dir = this_model_save_dir / str(model_name) + + this_model_save_dir.mkdir(exist_ok=True) + + latest_ckpt = this_model_save_dir / 'latest.ckpt' + best_ckpt = this_model_save_dir / 'best.ckpt' + + # TODO: is that enough to restore state? or maybe we should always start from the beginnin + if best_ckpt.exists(): + if reset: + best_ckpt.unlink() + else: + self.load(best_ckpt) + self.remember_best() + if latest_ckpt.exists(): + if reset: + latest_ckpt.unlink() + else: + self.load(latest_ckpt) + + loss_tracker = AvgMeter() + per_tracker = AvgMeter() + + warmup_limits = [1.0, 1.0, 2.0, 2.0] + warmup = 0 + + while epoch < epochs: + if warmup < len(warmup_limits): + set_time_limit(self.train_load, warmup_limits[warmup]) + else: + set_time_limit(self.train_load, None) + + loss_tracker.reset() + self.model.train() + with tqdm.tqdm(self.train_load) as pbar: + for train_input in pbar: + pbar.set_description(f'Avg. 
loss: {loss_tracker.get():.4f}') + loss, *_ = self.step(train_input, training=True) + loss_tracker.update(loss.item()) + + if self.verbose: + print(f'{"Warmup e" if warmup < len(warmup_limits) else "E"}poch {warmup+1 if warmup < len(warmup_limits) else epoch+1}: average loss: {loss_tracker.get():.4f}') + + if warmup < len(warmup_limits): + warmup += 1 + else: + loss_tracker.reset() + per_tracker.reset() + self.model.eval() + for val_input in self.valid_load: + loss, logits, logits_len = self.step(val_input, training=False) + per = self.decode(logits, logits_len, val_input) + loss_tracker.update(loss.item()) + per_tracker.update(per.item()) + + val_loss = loss_tracker.get() + val_per = per_tracker.get() + val_scores.append((val_loss, val_per)) + + if self.verbose: + print(f'Epoch {epoch+1}: average val loss: {val_loss:.4f}, average val per: {val_per:.4f}') + + is_best = False + epoch += 1 + if best_val is None or val_per < best_val: + is_best = True + best_val = val_per + + if is_best: + if self.verbose: + print(f' Best model, saving...') + self.remember_best() + + if epoch >= 5: # ignore epochs with time limits + self.scheduler.step() + + if self.save_dir: + self.save(latest_ckpt) + if is_best: + self.save(best_ckpt) + + if self.verbose: + print('Performing final test') + + self.recall_best() + loss_tracker.reset() + per_tracker.reset() + self.model.eval() + for test_input in self.test_load: + loss, logits, logits_len = self.step(test_input, training=False) + per = self.decode(logits, logits_len, test_input) + loss_tracker.update(loss.item()) + per_tracker.update(per.item()) + + test_loss = loss_tracker.get() + test_per = per_tracker.get() + + self.model = None + self._model = None + self.lr = None + self.optimizer = None + self.scheduler = None + self._best_weights = None + + return val_scores, test_loss, test_per + + def step(self, inputs, training): + (audio, audio_len), (targets, targets_len) = inputs + audio = audio.to(device=self.device) + audio_len = audio_len.to(device=self.device) + targets = targets.to(device=self.device) + targets_len = targets_len.to(device=self.device) + + if training: + self.optimizer.zero_grad() + output = self.model(audio) + output = F.log_softmax(output, dim=2) + output_len = audio_len // 4 + loss = self.loss(output, output_len, targets, targets_len) + _regu_loss = loss + 0.01 * sum(torch.norm(l.conv.weight) for l in self._model.modules() if isinstance(l, PadConvRelu)) + if training: + _regu_loss.backward() + nn.utils.clip_grad_norm_(self.model.parameters(), 5) + self.optimizer.step() + + return loss.detach(), output.detach(), output_len.detach() + + def decode(self, output, output_len, val_inputs): + _, (targets, targets_len) = val_inputs + targets = targets.to(device=self.device, dtype=torch.int) + targets_len = targets_len.to(device=self.device, dtype=torch.int) + + targets = self.encoder.fold_encoded(targets, 39) + + beams, _, _, beams_len = self.decoder.decode(output, output_len) + top_beams = beams[:,0].to(device=self.device, dtype=torch.int) + top_beams_len = beams_len[:,0].to(device=self.device, dtype=torch.int) + + top_beams = self.encoder.fold_encoded(top_beams, 39) + + blank = torch.Tensor([0]).to(device=self.device, dtype=torch.int) + sep = torch.Tensor([]).to(device=self.device, dtype=torch.int) + + per = ed.compute_wer(top_beams, targets, top_beams_len, targets_len, blank, sep) + per = per.mean() + return per + + def save(self, ckpt_name): + torch.save({ + 'model': self._model.state_dict(), + 'optim': self.optimizer.state_dict() + }, 
str(ckpt_name)) + + def load(self, ckpt_name): + state = torch.load(str(ckpt_name), map_location=self.device) + self._model.load_state_dict(state['model']) + self.optimizer.load_state_dict(state['optim']) + + def remember_best(self): + self._best_weights = self._model.state_dict() + + def recall_best(self): + self._model.load_state_dict(self._best_weights) + + +def get_trainer(*args, **kwargs): + return Trainer(*args, **kwargs) diff --git a/nasbench_asr/utils.py b/nasbench_asr/utils.py new file mode 100644 index 0000000..3177f63 --- /dev/null +++ b/nasbench_asr/utils.py @@ -0,0 +1,175 @@ +import sys +import pathlib +import importlib +import collections + + +class LazyModule(): + def __init__(self, module): + self.module = module + + def __repr__(self): + return repr(self.module) + + def __getattr__(self, name): + return getattr(self.module, name) + + +def add_module_properties(module_name, properties): + module = sys.modules[module_name] + replace = False + if isinstance(module, LazyModule): + lazy_type = type(module) + else: + lazy_type = type('LazyModule({})'.format(module_name), (LazyModule,), {}) + replace = True + + for name, prop in properties.items(): + setattr(lazy_type, name, prop) + + if replace: + sys.modules[module_name] = lazy_type(module) + + +class staticproperty(property): + def __init__(self, fget=None, fset=None, fdel=None, doc=None): + if fget is not None and not isinstance(fget, staticmethod): + raise ValueError('fget should be a staticmethod') + if fset is not None and not isinstance(fset, staticmethod): + raise ValueError('fset should be a staticmethod') + if fdel is not None and not isinstance(fdel, staticmethod): + raise ValueError('fdel should be a staticmethod') + super().__init__(fget, fset, fdel, doc) + + def __get__(self, inst, cls=None): + if inst is None: + return self + if self.fget is None: + raise AttributeError("unreadable attribute") + return self.fget.__get__(inst, cls)() # pylint: disable=no-member + + def __set__(self, inst, val): + if self.fset is None: + raise AttributeError("can't set attribute") + return self.fset.__get__(inst)(val) # pylint: disable=no-member + + def __delete__(self, inst): + if self.fdel is None: + raise AttributeError("can't delete attribute") + return self.fdel.__get__(inst)() # pylint: disable=no-member + + +# utils to work with nested collections +def recursive_iter(seq): + ''' Iterate over elements in seq recursively (returns only non-sequences) + ''' + if isinstance(seq, collections.abc.Sequence): + for e in seq: + for v in recursive_iter(e): + yield v + else: + yield seq + + +def flatten(seq): + ''' Flatten all nested sequences, returned type is type of ``seq`` + ''' + return list(recursive_iter(seq)) + + +def copy_structure(data, shape): + ''' Put data from ``data`` into nested containers like in ``shape``. + This can be seen as "unflatten" operation, i.e.: + seq == copy_structure(flatten(seq), seq) + ''' + d_it = recursive_iter(data) + + def copy_level(s): + if isinstance(s, collections.abc.Sequence): + return type(s)(copy_level(ss) for ss in s) + else: + return next(d_it) + return copy_level(shape) + + +def count(seq): + ''' Count elements in ``seq`` in a streaming manner. + ''' + ret = 0 + for _ in seq: + ret += 1 + return ret + + +def get_first_n(seq, n): + ''' Get first ``n`` elements of ``seq`` in a streaming manner. 
+ ''' + c = 0 + i = iter(seq) + while c < n: + yield next(i) + c += 1 + + +class BackendsAccessor(): + def __init__(self, parent_module_init, parent_module_name): + self.parent_module_path = pathlib.Path(parent_module_init).parent + self.parent_module_name = parent_module_name + self.backends = {} + self.available_backends = [d.name for d in self.parent_module_path.iterdir() if d.is_dir()] + + def _check_backend(self, backend): + if backend == 'tf': + try: + from nasbench_asr.quiet_tensorflow import tensorflow as _ + except ImportError as e: + raise ImportError('Tensorflow backend not available') from e + elif backend == 'torch': + try: + import torch as _ + except ImportError as e: + raise ImportError('PyTorch backend not available') from e + else: + raise ValueError(f'Unknown backend: {backend}') + + def _deduce_backend(self): + try: + self._check_backend('tf') + return 'tf' + except ImportError: + pass + + try: + self._check_backend('torch') + return 'torch' + except ImportError: + pass + + raise ImportError('Neither tensorflow nor torch package could not be imported - at least one should be available to train/create models') + + def get_backend(self, backend, set_default=False): + if backend in self.backends: + return self.backends[backend] + + is_none = False + if backend is None: + backend = self._deduce_backend() + is_none = True + else: + self._check_backend(backend) + + backend_impl = importlib.import_module(f'.{backend}', self.parent_module_name) + self.backends[backend] = backend_impl + if is_none or set_default: + self.backends[None] = backend_impl + return backend_impl + + +def make_nice_number(num): + n = str(num) + parts = (len(n)-1)//3 + 1 + if parts == 1: + return n + offset = len(n)%3 or 3 + breaks = [0] + [offset + i*3 for i in range(parts)] + [len(n)] + return ','.join(n[breaks[i]:breaks[i+1]] for i in range(parts)) diff --git a/nasbench_asr/version.py b/nasbench_asr/version.py new file mode 100644 index 0000000..85dbe46 --- /dev/null +++ b/nasbench_asr/version.py @@ -0,0 +1,36 @@ +version = '0.1.0.dev0' +repo = 'unknown' +commit = 'unknown' +has_repo = False + +try: + import git + from pathlib import Path + + try: + r = git.Repo(Path(__file__).parents[1]) + has_repo = True + + if not r.remotes: + repo = 'local' + else: + repo = r.remotes.origin.url + + commit = r.head.commit.hexsha + if r.is_dirty(): + commit += ' (dirty)' + except git.InvalidGitRepositoryError: + raise ImportError() +except ImportError: + pass + +try: + from . 
import _dist_info as info + assert not has_repo, '_dist_info should not exist when repo is in place' + assert version == info.version + repo = info.repo + commit = info.commit +except ImportError: + pass + +__all__ = ['version', 'repo', 'commit', 'has_repo'] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..5f01e6c --- /dev/null +++ b/setup.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +from setuptools import setup, find_packages +from setuptools.command.build_py import build_py + +import sys +import importlib +import importlib.util +from pathlib import Path + +package_name = 'nasbench_asr' + +version_file = Path(__file__).parent.joinpath(package_name, 'version.py') +spec = importlib.util.spec_from_file_location('{}.version'.format(package_name), version_file) +package_version = importlib.util.module_from_spec(spec) +spec.loader.exec_module(package_version) +sys.modules[spec.name] = package_version + + +class build_maybe_inplace(build_py): + def run(self): + global package_version + package_version = importlib.reload(package_version) + _dist_file = version_file.parent.joinpath('_dist_info.py') + assert not _dist_file.exists() + _dist_file.write_text('\n'.join(map(lambda attr_name: attr_name+' = '+repr(getattr(package_version, attr_name)), package_version.__all__)) + '\n') + return super().run() + + +setup(name='NasbenchASR', + version=package_version.version, + description='Library for the NasbenchASR dataset', + author='SAIC-Cambridge, On-Device Team', + author_email='on.device@samsung.com', + url='https://github.sec.samsung.net/a-mehrotra1/pytorch-asr', + download_url='https://github.sec.samsung.net/a-mehrotra1/pytorch-asr', + python_requires='>=3.6.0', + setup_requires=[ + 'git-python' + ], + install_requires=[ + 'tqdm', + 'numpy', + 'tensorflow', + 'torch==1.7.0', + 'torchaudio==0.7.0', + 'git-python', + 'networkx>=2.5', + 'ctcdecode @ git+https://github.com/parlance/ctcdecode@9a20e00f34d8f605f4a8501cc42b1a53231f1597', + 'torch-edit-distance' + ], + dependency_links=[ + ], + packages=find_packages(where='.', include=[ 'nasbench_asr', 'nasbench_asr.*' ]), + package_dir={ '': '.' 
}, + data_files=[], + cmdclass={ + 'build_py': build_maybe_inplace + } +) diff --git a/train.py b/train.py new file mode 100644 index 0000000..1c19fcb --- /dev/null +++ b/train.py @@ -0,0 +1,58 @@ +import pathlib +import argparse + +from nasbench_asr import set_default_backend, get_backend_name, set_seed, prepare_devices, get_model, get_dataloaders, get_trainer, get_loss + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('model', type=int, nargs=9) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--epochs', type=int, default=40) + parser.add_argument('--data', type=str, default='TIMIT') + parser.add_argument('--rnn', type=bool, default=True) + parser.add_argument('--exp_folder', type=str, default='results') + parser.add_argument('--exp_name', type=str, default=None) + parser.add_argument('--backend', type=str, default=None) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--dropout', type=float, default=0.2) + parser.add_argument('--gpus', type=int, nargs='+', default=[0]) + parser.add_argument('--reset', action='store_true') + parser.add_argument('--seed', type=int, default=1235) + args = parser.parse_args() + + flat_model = tuple(map(str, args.model)) + args.model = [args.model[0:2], args.model[2:5], args.model[5:9]] + + if not args.exp_name: + args.exp_name = '_'.join(flat_model) + f'_b{args.batch_size}_rnn{int(args.rnn)}' + + set_default_backend(args.backend) + set_seed(args.seed) + prepare_devices(args.gpus) + + args.backend = get_backend_name()[0] + + print(f'Using backend: {get_backend_name()}') + print(f' Model vec: {args.model}') + print(f' Training for {args.epochs} epochs') + print(f' Batch size: {args.batch_size}') + print(f' Learning rate: {args.lr}') + print(f' Dropout: {args.dropout}') + print(f' GPUs: {args.gpus}') + + results_folder = pathlib.Path(args.exp_folder) / args.backend + + first_gpu = None + if args.gpus: + first_gpu = args.gpus[0] + + dataloaders = get_dataloaders(args.data, batch_size=args.batch_size) + loss = get_loss() + model = get_model(args.model, use_rnn=args.rnn, dropout_rate=args.dropout, gpu=first_gpu) + trainer = get_trainer(dataloaders, loss, gpus=args.gpus, save_dir=results_folder, verbose=True) + trainer.train(model, epochs=args.epochs, lr=args.lr, reset=args.reset, model_name=args.exp_name) + + +if __name__ == "__main__": + main()
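For clarity, the nine positional integers accepted by `train.py` are grouped into the nested architecture vector exactly as shown in `main()` above; the example below uses the same vector that appears commented out in the TF training script:

```python
# Illustrative only: how the flat CLI model description becomes the nested one.
flat = [3, 1, 4, 1, 1, 0, 0, 0, 1]
model_vec = [flat[0:2], flat[2:5], flat[5:9]]
print(model_vec)  # [[3, 1], [4, 1, 1], [0, 0, 0, 1]]
```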