Skip to content

Commit

Permalink
Remove sklearn test
Browse files Browse the repository at this point in the history
  • Loading branch information
kisonho committed Apr 22, 2024
1 parent 28f83be commit df5ae18
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tests/test_0102.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def test_dice_metric(self) -> None:

def test_f1_score(self) -> None:
from torchmanager import metrics
from sklearn.metrics import f1_score

# initialize
f1_score_fn = metrics.F1()
Expand All @@ -34,14 +33,18 @@ def test_f1_score(self) -> None:
# test f1 score
f1 = f1_score_fn(y, y_test)
self.assertGreaterEqual(float(f1_score_fn.result), 0, f"F1 Score value must be non-negative, got {f1}.")

'''
# Test with sklearn.metrics.f1_score
from sklearn.metrics import f1_score
y = y.argmax(1) == 1
y_test = y_test == 1
f1_sklearn = f1_score(y_test.flatten(), y.flatten(), average='binary')
self.assertAlmostEqual(float(f1_score_fn.result), f1_sklearn, places=2)
'''

def test_miou(self) -> None:
from torchmanager import metrics
from sklearn.metrics import jaccard_score

# initialize
miou_score_fn = metrics.MeanIoU()
Expand All @@ -51,8 +54,13 @@ def test_miou(self) -> None:
# test miou score
miou = miou_score_fn(y, y_test)
self.assertGreaterEqual(float(miou_score_fn.result), 0, f"Mean IoU value must be non-negative, got {miou}.")

'''
# Test with sklearn.metrics.jaccard_score
from sklearn.metrics import jaccard_score
miou_jaccard = jaccard_score(y_test.flatten(), (y > 0).flatten(), average='weighted')
self.assertAlmostEqual(float(miou_score_fn.result), miou_jaccard, places=2)
'''

def test_random(self) -> None:
from torchmanager_core.random import freeze_seed, unfreeze_seed
Expand Down

0 comments on commit df5ae18

Please sign in to comment.