-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added MetricConfusionMatrixBase for adding custom confusion matrix based metrics * Added ConfusionMatrixBasedMetric Enum to get specific metrics such as tp,fp,fn,tn,precision,sensitivity,specificity,recall,ppv,npv,accuracy,f1score * Added confusion matrix common metrics (TruePositives, TrueNegatives, FalsePositives, FalseNegatives) * Added MetricMethod enum to pass to MetricBase, now you can define whether your metric is based on MEAN, SUM or LAST of all batches * StatsPrint callback now support "print_confusion_matrix" and "print_confusion_matrix_normalized" arguments in case MetricConfusionMatrixBase metric is found * Added confusion matrix tests and example * Some custom layers renames (breaking changes in this part)
- Loading branch information
Showing
26 changed files
with
993 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import os | ||
import torch.optim as optim | ||
import torch.nn as nn | ||
|
||
from lpd.trainer import Trainer | ||
from lpd.callbacks import SchedulerStep, StatsPrint, ModelCheckPoint, LossOptimizerHandler, CallbackMonitor | ||
from lpd.extensions.custom_schedulers import DoNothingToLR | ||
from lpd.enums import Phase, State, MonitorType, StatsType, MonitorMode | ||
from lpd.metrics import TruePositives, FalsePositives, TrueNegatives, FalseNegatives | ||
import lpd.utils.torch_utils as tu | ||
import lpd.utils.general_utils as gu | ||
import examples.utils as eu | ||
|
||
gu.seed_all(42) # BECAUSE ITS THE ANSWER TO LIFE AND THE UNIVERSE | ||
|
||
def get_parameters(): | ||
# N is batch size; D_in is input dimension; | ||
# H is hidden dimension; D_out is output dimension. | ||
N, D_in, H, D_out, num_classes = 128, 100, 100, 3,3 | ||
num_epochs = 5 | ||
data_loader = eu.examples_data_generator(N, D_in, D_out, category_out=True) | ||
data_loader_steps = 100 | ||
return N, D_in, H, D_out, num_epochs, num_classes, data_loader, data_loader_steps | ||
|
||
|
||
def get_trainer_base(D_in, H, D_out, num_classes): | ||
device = tu.get_gpu_device_if_available() | ||
|
||
model = eu.get_basic_model(D_in, H, D_out).to(device) | ||
|
||
loss_func = nn.CrossEntropyLoss().to(device) | ||
|
||
optimizer = optim.Adam(model.parameters(), lr=1e-4) | ||
|
||
scheduler = DoNothingToLR() #CAN ALSO USE scheduler=None, BUT DoNothingToLR IS MORE EXPLICIT | ||
|
||
labels = ['Cat', 'Dog', 'Bird'] | ||
metric_name_to_func = { | ||
"TP":TruePositives(num_classes, labels=labels, threshold = 0), | ||
"FP":FalsePositives(num_classes, labels=labels, threshold = 0), | ||
"TN":TrueNegatives(num_classes, labels=labels, threshold = 0), | ||
"FN":FalseNegatives(num_classes, labels=labels, threshold = 0) | ||
} | ||
|
||
return device, model, loss_func, optimizer, scheduler, metric_name_to_func | ||
|
||
|
||
def get_trainer(N, D_in, H, D_out, num_epochs, num_classes, data_loader, data_loader_steps): | ||
device, model, loss_func, optimizer, scheduler, metric_name_to_func = get_trainer_base(D_in, H, D_out, num_classes) | ||
|
||
callbacks = [ | ||
LossOptimizerHandler(), | ||
StatsPrint(print_confusion_matrix=True) | ||
] | ||
|
||
trainer = Trainer(model=model, | ||
device=device, | ||
loss_func=loss_func, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
metric_name_to_func=metric_name_to_func, | ||
train_data_loader=data_loader, | ||
val_data_loader=data_loader, | ||
train_steps=data_loader_steps, | ||
val_steps=data_loader_steps, | ||
callbacks=callbacks, | ||
name='Confusion-Matrix-Example') | ||
return trainer | ||
|
||
|
||
def run(): | ||
N, D_in, H, D_out, num_epochs, num_classes, data_loader, data_loader_steps = get_parameters() | ||
|
||
current_trainer = get_trainer(N, D_in, H, D_out, num_epochs, num_classes, data_loader, data_loader_steps) | ||
|
||
current_trainer.train(num_epochs) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.