forked from Lightning-Universe/lightning-bolts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revision of SimCLR transforms (Lightning-Universe#857)
Co-authored-by: otaj <ota@lightning.ai> Co-authored-by: arnol <fokammanuel1@students.wits.ac.za>
- Loading branch information
1 parent
f5be6c9
commit 2059bb0
Showing
4 changed files
with
90 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from PIL import Image | ||
|
||
from pl_bolts.models.self_supervised.simclr.transforms import ( | ||
SimCLREvalDataTransform, | ||
SimCLRFinetuneTransform, | ||
SimCLRTrainDataTransform, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"transform_cls", | ||
[pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")], | ||
) | ||
def test_simclr_train_data_transform(catch_warnings, transform_cls): | ||
# dummy image | ||
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8) | ||
img = Image.fromarray(img) | ||
|
||
# size of the generated views | ||
input_height = 96 | ||
transform = transform_cls(input_height=input_height) | ||
views = transform(img) | ||
|
||
# the transform must output a list or a tuple of images | ||
assert isinstance(views, (list, tuple)) | ||
|
||
# the transform must output three images | ||
# (1st view, 2nd view, online evaluation view) | ||
assert len(views) == 3 | ||
|
||
# all views are tensors | ||
assert all(torch.is_tensor(v) for v in views) | ||
|
||
# all views have expected sizes | ||
assert all(v.size(1) == v.size(2) == input_height for v in views) | ||
|
||
|
||
def test_simclr_finetune_transform(catch_warnings): | ||
# dummy image | ||
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8) | ||
img = Image.fromarray(img) | ||
|
||
# size of the generated views | ||
input_height = 96 | ||
transform = SimCLRFinetuneTransform(input_height=input_height) | ||
view = transform(img) | ||
|
||
# the view generator is a tensor | ||
assert torch.is_tensor(view) | ||
|
||
# view has expected size | ||
assert view.size(1) == view.size(2) == input_height |