Showing 22 changed files with 361 additions and 13 deletions.
@@ -0,0 +1,54 @@
from abc import ABC
from typing import Optional, Union, Sequence, Dict

from torch import Tensor
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torchmetrics import Accuracy, F1Score, AUROC

from ..neko_model import NekoModel


class BinaryClassifier(NekoModel, ABC):

    def __init__(self, model=None, learning_rate: float = 1e-4, distributed: bool = False):
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        self.learning_rate = learning_rate
        self.distributed = distributed
        self.loss_fn = BCEWithLogitsLoss()
        self.acc_fn = Accuracy(task="binary")
        self.f1_fn = F1Score(task="binary")
        self.auc_fn = AUROC(task="binary")

    @classmethod
    def from_module(cls, model, learning_rate: float = 1e-4, distributed=False):
        return cls(model, learning_rate, distributed)

    def forward(self, x):
        return self.model(x)

    def step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]]) -> Dict[str, Tensor]:
        x, y = batch
        y_hat = self(x).squeeze(1)
        loss = self.loss_fn(y_hat, y)
        prob = y_hat.sigmoid()
        acc = self.acc_fn(prob, y)
        f1 = self.f1_fn(prob, y)
        auc = self.auc_fn(prob, y)
        return {"loss": loss, "acc": acc, "f1": f1, "auc": auc}

    def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
                      optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None
                      ) -> Dict[str, Tensor]:
        return self.step(batch)

    def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None,
                        dataloader_idx: Optional[int] = None
                        ) -> Dict[str, Tensor]:
        return self.step(batch)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return [optimizer]
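Usage sketch (not part of the diff): the Lightning-style hooks above (save_hyperparameters, training_step, configure_optimizers) suggest NekoModel wraps a pytorch_lightning.LightningModule, so a minimal training run could look like the following; the backbone, data, and Trainer settings are made up for illustration.

# Hypothetical usage of BinaryClassifier, assuming NekoModel is a LightningModule subclass.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer

# Toy data: 256 samples with 16 features; float 0/1 labels suit BCEWithLogitsLoss
# and are also accepted by the binary torchmetrics used in step().
x = torch.randn(256, 16)
y = torch.randint(0, 2, (256,)).float()
loader = DataLoader(TensorDataset(x, y), batch_size=32)

backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))  # one logit per sample
classifier = BinaryClassifier.from_module(backbone, learning_rate=1e-4)

Trainer(max_epochs=1).fit(classifier, loader)  # each step returns loss, acc, f1 and auc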
@@ -1,5 +1,11 @@
from .round_robin_dataset import RoundRobinDataset
from .nested_dataset import NestedDataset
from .list_dataset import ListDataset
from . import sampler

__all__ = [
    "RoundRobinDataset",
    "NestedDataset",
    "ListDataset",
    "sampler"
]
@@ -0,0 +1,19 @@
from typing import List

from torch.utils.data.dataset import Dataset, T_co


class ListDataset(Dataset[T_co]):
    """
    A dataset wrapping a list of data.
    """

    def __init__(self, data: List[T_co]):
        super().__init__()
        self.data = data

    def __getitem__(self, index: int) -> T_co:
        return self.data[index]

    def __len__(self):
        return len(self.data)
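Usage sketch (not part of the diff): ListDataset is a thin wrapper, so indexing and DataLoader consumption behave exactly like the underlying list; the values below are arbitrary.

# Illustrative only: wrap an in-memory list so a DataLoader can batch it.
from torch.utils.data import DataLoader

dataset = ListDataset([10, 20, 30, 40])
assert len(dataset) == 4
assert dataset[2] == 30

for batch in DataLoader(dataset, batch_size=2):
    print(batch)  # tensor([10, 20]) then tensor([30, 40])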
@@ -0,0 +1,5 @@
from .sequential_iter_sampler import SequentialIterSampler

__all__ = [
    "SequentialIterSampler"
]
@@ -0,0 +1,26 @@
from typing import Sized

from torch.utils.data.sampler import Sampler, T_co


class SequentialIterSampler(Sampler[T_co]):
    """
    Splits large-scale data into small subsets, one per epoch.
    For example, if the dataset size is 1M and num_samples = 1000, each epoch will only use 1000 samples,
    and the next epoch will use the next 1000 samples.
    """

    def __init__(self, data_source: Sized, num_samples: int):
        super().__init__(data_source)
        self.data_source = data_source
        self.num_samples = num_samples
        self.total_size = len(data_source)
        self.current_position = 0

    def __iter__(self):
        yield from map(lambda x: x % self.total_size,
                       range(self.current_position, self.current_position + self.num_samples))
        self.current_position = (self.current_position + self.num_samples) % self.total_size

    def __len__(self):
        return self.num_samples
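Behaviour sketch (not part of the diff): successive iterations of the sampler walk through the data source in fixed-size windows and wrap around at the end, which is what gives each epoch a different slice; the numbers below are made up.

# Illustrative only: 10 items, 4 indices per epoch.
data = list(range(10))
sampler = SequentialIterSampler(data, num_samples=4)

print(list(sampler))  # [0, 1, 2, 3]
print(list(sampler))  # [4, 5, 6, 7]
print(list(sampler))  # [8, 9, 0, 1]  (wraps around)

# In practice the sampler is passed to a DataLoader, e.g.
# DataLoader(dataset, sampler=SequentialIterSampler(dataset, num_samples), batch_size=...),
# so each epoch trains on the next num_samples-sized window of a very large dataset.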
@@ -0,0 +1,58 @@
from numpy import ndarray
from torch import Tensor

from tensorneko_util.util import dispatch, Eval

from tensorneko_util.io import read


@Eval.later
def _secs_encoder():
    from resemblyzer import VoiceEncoder
    return VoiceEncoder()


@dispatch
def secs(pred: str, real: str) -> float:
    from resemblyzer import VoiceEncoder, preprocess_wav
    pred_audio = preprocess_wav(read.audio(pred).audio[0].numpy())
    real_audio = preprocess_wav(read.audio(real).audio[0].numpy())
    return _secs_compute(pred_audio, real_audio)


@dispatch
def secs(pred: Tensor, real: Tensor) -> float:
    return secs(pred.numpy(), real.numpy())


@dispatch
def secs(pred: ndarray, real: ndarray) -> float:
    from resemblyzer import VoiceEncoder, preprocess_wav
    if len(pred.shape) == 2:
        if pred.shape[0] == 1:
            pred = pred.squeeze(0)
        elif pred.shape[1] == 1:
            pred = pred.squeeze(1)
        else:
            raise ValueError("The input audio must be mono.")

    if len(real.shape) == 2:
        if real.shape[0] == 1:
            real = real.squeeze(0)
        elif real.shape[1] == 1:
            real = real.squeeze(1)
        else:
            raise ValueError("The input audio must be mono.")

    pred_audio = preprocess_wav(pred)
    real_audio = preprocess_wav(real)

    return _secs_compute(pred_audio, real_audio)


def _secs_compute(pred_audio: ndarray, real_audio: ndarray) -> float:
    encoder = _secs_encoder.value
    real_embed = encoder.embed_utterance(real_audio)
    pred_embed = encoder.embed_utterance(pred_audio)

    return float((real_embed * pred_embed).sum())
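Usage sketch (not part of the diff): secs computes the cosine similarity between Resemblyzer speaker embeddings of two utterances; it assumes resemblyzer is installed, and the file names below are placeholders.

# Illustrative only: compare a generated utterance with its ground-truth recording.
score = secs("generated.wav", "reference.wav")  # str overload loads audio via read.audio
print(f"SECS: {score:.3f}")  # dot product of L2-normalised embeddings, i.e. cosine similarity

# Other overloads: secs(pred_tensor, real_tensor) for torch Tensors, and
# secs(pred_array, real_array) for mono numpy waveforms (1-D, or 2-D with a singleton channel axis).
# No sample rate is forwarded to preprocess_wav for raw arrays, so they should already be at
# Resemblyzer's expected 16 kHz.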