-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding efficientnet spectrogram models
new file: efficientnet/__init__.py new file: efficientnet/efficientnet2d.py
- Loading branch information
Showing
2 changed files
with
46 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .efficientnet2d import EfficientNet2d |
45 changes: 45 additions & 0 deletions
45
torchsig/models/spectrogram_models/efficientnet/efficientnet2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import timm | ||
from torch.nn import Linear | ||
|
||
from torchsig.models.model_utils.model_utils_1d.conversions_to_1d import convert_2d_model_to_1d | ||
|
||
__all__ = ["EfficientNet1d"] | ||
|
||
def EfficientNet2d( | ||
input_channels: int, | ||
n_features: int, | ||
efficientnet_version: str = "b0", | ||
drop_path_rate: float = 0.2, | ||
drop_rate: float = 0.3, | ||
): | ||
"""Constructs and returns a 1d version of the EfficientNet model described in | ||
`"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_. | ||
Args: | ||
input_channels (int): | ||
Number of 1d input channels; e.g., common practice is to split complex number time-series data into 2 channels, representing the real and imaginary parts respectively | ||
n_features (int): | ||
Number of output features; should be the number of classes when used directly for classification | ||
efficientnet_version (str): | ||
Specifies the version of efficientnet to use. See the timm efficientnet documentation for details. Examples are 'b0', 'b1', and 'b4' | ||
drop_path_rate (float): | ||
Drop path rate for training | ||
drop_rate (float): | ||
Dropout rate for training | ||
""" | ||
#mdl = #convert_2d_model_to_1d( | ||
mdl = timm.create_model( | ||
"efficientnet_" + efficientnet_version, | ||
in_chans=input_channels, | ||
drop_path_rate=drop_path_rate, | ||
drop_rate=drop_rate, | ||
) | ||
#) | ||
mdl.classifier = Linear(mdl.classifier.in_features, n_features) | ||
return mdl |