Skip to content

Commit

Permalink
add test for MultiMacenkoNormalizer torch
Browse files Browse the repository at this point in the history
  • Loading branch information
carloalbertobarbano committed Jan 13, 2025
1 parent 6c88302 commit e621b86
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,40 @@ def test_macenko_torch():
# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)

def test_multitarget_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])
target = T(target)
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
single_normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch")
single_normalizer.fit(target)

multi_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend="torch", norm_mode="avg-post")
multi_normalizer.fit([target, target, target])


# transform
result_single, _, _ = single_normalizer.normalize(I=t_to_transform, stains=True)
result_multi, _, _ = multi_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_single = result_single.numpy().astype("float32") / 255.
result_multi = result_multi.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_single.flatten(), result_multi.flatten(), decimal=2, verbose=True)


def test_reinhard_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
Expand Down

0 comments on commit e621b86

Please sign in to comment.