Skip to content

Commit

Permalink
Final PyTorch 1.8 release - freeze requirements - remove class_dict.k…
Browse files Browse the repository at this point in the history
…eys() from kws20 (#324)

* Final PyTorch 1.8 release - freeze requirements, update README
* Remove .keys() from kws20.py
* Remove noise_type argument from kws20 get_datasets()
  • Loading branch information
rotx-eva authored Jul 5, 2024
1 parent ba6c02b commit be7d422
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# ADI MAX78000/MAX78002 Model Training and Synthesis

June 25, 2024
July 2, 2024

**Note: This branch is compatible with PyTorch 1.8. Please go to the “pytorch-2” branch for PyTorch 2.3 compatibility.**
**Note: This archived branch is compatible with PyTorch 1.8. Please go to the “develop” branch for PyTorch 2 compatibility.**

ADI’s MAX78000/MAX78002 project is comprised of five repositories:

Expand Down
Binary file modified README.pdf
Binary file not shown.
12 changes: 4 additions & 8 deletions datasets/kws20.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,15 +487,13 @@ def __filter_dtype(self):
def __filter_classes(self):
initial_new_class_label = len(self.class_dict)
new_class_label = initial_new_class_label
self.new_class_dict = {}
for c in self.classes:
if c not in self.class_dict:
if c == '_unknown_':
continue
raise ValueError(f'Class {c} not found in data')
num_elems = (self.targets == self.class_dict[c]).cpu().sum()
print(f'Class {c} (# {self.class_dict[c]}): {num_elems} elements')
self.new_class_dict[c] = new_class_label
self.targets[(self.targets == self.class_dict[c])] = new_class_label
new_class_label += 1

Expand All @@ -504,10 +502,6 @@ def __filter_classes(self):
self.targets[(self.targets < initial_new_class_label)] = new_class_label
self.targets -= initial_new_class_label

self.new_class_dict = {c: self.new_class_dict[c] - initial_new_class_label
for c in self.new_class_dict.keys()}
self.new_class_dict['_unknown_'] = len(self.new_class_dict)

def __filter_librispeech(self):

print('Filtering out librispeech elements...')
Expand Down Expand Up @@ -1002,7 +996,6 @@ def KWS_35_get_unquantized_datasets(data, load_train=True, load_test=True):

def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True,
apply_prob=0.8, snr_range=(-5, 10),
noise_type=MSnoise.class_dict.keys(),
desired_probs=None):
"""
Returns the KWS dataset mixed with MSnoise dataset. Only training set will be mixed
Expand All @@ -1016,6 +1009,8 @@ def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True,
noise_type --> All noise types in the noise dataset.
"""

noise_type = MSnoise.class_dict

if len(snr_range) > 1:
snr_range = range(snr_range[0], snr_range[1])
else:
Expand Down Expand Up @@ -1058,14 +1053,15 @@ def KWS_12_benchmark_get_datasets(data, load_train=True, load_test=True):

def MixedKWS_20_get_datasets_10dB(data, load_train=True, load_test=True,
apply_prob=1, snr_range=tuple([10]),
noise_type=MSnoise.class_dict.keys(),
desired_probs=None):
"""
Returns the mixed KWS dataset with MSnoise dataset under 10 dB SNR using signalmixer
data loader. All of the training and test data will be augmented with
additional noise.
"""

noise_type = MSnoise.class_dict

if len(snr_range) > 1:
snr_range = range(snr_range[0], snr_range[1])
else:
Expand Down
32 changes: 16 additions & 16 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
torch==1.8.1
torchaudio==0.8.1
torchvision==0.9.1
GitPython>=3.1.18
Pillow>=7
PyYAML>=5.1.1
albumentations>=1.3.0
GitPython==3.1.43
Pillow==10.4.0
PyYAML==6.0.1
albumentations==1.3.1
faiss-cpu==1.7.4
batch-face>=1.4.0
h5py>=3.7.0
batch-face==1.4.0
h5py==3.11.0
kornia==0.6.8
librosa>=0.7.2
numba<0.50.0
numpy>=1.22,<1.23
opencv-python>=4.4.0
protobuf>=3.20.1,<4.0
librosa==0.9.2
numba==0.49.1
numpy==1.22.4
opencv-python==4.10.0.84
protobuf==3.20.3
pycocotools==2.0.7
pyffmpeg==2.0
pytorch-metric-learning==2.3.0
pytube>=12.1.3
scipy>=1.3.0
shap>=0.34.0
tensorboard>=2.9.0,<2.10.0
tk>=0.1.0
pytube==15.0.0
scipy==1.10.1
shap==0.44.1
tensorboard==2.9.0
tk==0.1.0
torchmetrics==0.6.0
-e distiller

0 comments on commit be7d422

Please sign in to comment.