Skip to content

Commit

Permalink
Revision of SimCLR transforms (Lightning-Universe#857)
Browse files Browse the repository at this point in the history
Co-authored-by: otaj <ota@lightning.ai>
Co-authored-by: arnol <fokammanuel1@students.wits.ac.za>
  • Loading branch information
3 people authored and Jungwon-Lee committed Sep 18, 2022
1 parent f5be6c9 commit 2059bb0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 73 deletions.
105 changes: 32 additions & 73 deletions pl_bolts/models/self_supervised/simclr/transforms.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
import numpy as np

from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
else: # pragma: no cover
warn_missing_pkg("torchvision")

if _OPENCV_AVAILABLE:
import cv2
else: # pragma: no cover
warn_missing_pkg("cv2", pypi_name="opencv-python")


@under_review()
class SimCLRTrainDataTransform:
"""Transforms for SimCLR.
"""Transforms for SimCLR during training step of the pre-training stage.
Transform::
RandomResizedCrop(size=self.input_height)
RandomHorizontalFlip()
RandomApply([color_jitter], p=0.8)
RandomGrayscale(p=0.2)
GaussianBlur(kernel_size=int(0.1 * self.input_height))
RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5)
transforms.ToTensor()
Example::
Expand All @@ -34,7 +25,7 @@ class SimCLRTrainDataTransform:
transform = SimCLRTrainDataTransform(input_height=32)
x = sample()
(xi, xj) = transform(x)
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
"""

def __init__(
Expand Down Expand Up @@ -68,16 +59,16 @@ def __init__(
if kernel_size % 2 == 0:
kernel_size += 1

data_transforms.append(GaussianBlur(kernel_size=kernel_size, p=0.5))
data_transforms.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))

data_transforms = transforms.Compose(data_transforms)
self.data_transforms = transforms.Compose(data_transforms)

if normalize is None:
self.final_transform = transforms.ToTensor()
else:
self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])

self.train_transform = transforms.Compose([data_transforms, self.final_transform])
self.train_transform = transforms.Compose([self.data_transforms, self.final_transform])

# add online train transform of the size of global view
self.online_transform = transforms.Compose(
Expand All @@ -93,9 +84,8 @@ def __call__(self, sample):
return xi, xj, self.online_transform(sample)


@under_review()
class SimCLREvalDataTransform(SimCLRTrainDataTransform):
"""Transforms for SimCLR.
"""Transforms for SimCLR during the validation step of the pre-training stage.
Transform::
Expand All @@ -109,7 +99,7 @@ class SimCLREvalDataTransform(SimCLRTrainDataTransform):
transform = SimCLREvalDataTransform(input_height=32)
x = sample()
(xi, xj) = transform(x)
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
"""

def __init__(
Expand All @@ -129,70 +119,39 @@ def __init__(
)


@under_review()
class SimCLRFinetuneTransform:
class SimCLRFinetuneTransform(SimCLRTrainDataTransform):
"""Transforms for SimCLR during the fine-tuning stage.
Transform::
Resize(input_height + 10, interpolation=3)
transforms.CenterCrop(input_height),
transforms.ToTensor()
Example::
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform
transform = SimCLREvalDataTransform(input_height=32)
x = sample()
xk = transform(x)
"""

def __init__(
self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
) -> None:

self.jitter_strength = jitter_strength
self.input_height = input_height
self.normalize = normalize

self.color_jitter = transforms.ColorJitter(
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.2 * self.jitter_strength,
super().__init__(
normalize=normalize, input_height=input_height, gaussian_blur=None, jitter_strength=jitter_strength
)

if not eval_transform:
data_transforms = [
transforms.RandomResizedCrop(size=self.input_height),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([self.color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
]
else:
data_transforms = [
if eval_transform:
self.data_transforms = [
transforms.Resize(int(self.input_height + 0.1 * self.input_height)),
transforms.CenterCrop(self.input_height),
]

if normalize is None:
final_transform = transforms.ToTensor()
else:
final_transform = transforms.Compose([transforms.ToTensor(), normalize])

data_transforms.append(final_transform)
self.transform = transforms.Compose(data_transforms)
self.transform = transforms.Compose([self.data_transforms, self.final_transform])

def __call__(self, sample):
return self.transform(sample)


@under_review()
class GaussianBlur:
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `GaussianBlur` from `cv2` which is not installed yet.")

self.min = min
self.max = max

# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
self.p = p

def __call__(self, sample):
sample = np.array(sample)

# blur the image with a 50% chance
prob = np.random.random_sample()

if prob < self.p:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

return sample
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.imports import _IS_WINDOWS

from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
from pl_bolts.utils.stability import UnderReviewWarning

# GitHub Actions use this path to cache datasets.
Expand All @@ -27,6 +28,8 @@ def catch_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.simplefilter("ignore", UnderReviewWarning)
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_13:
warnings.filterwarnings("ignore", "FLIP_LEFT_RIGHT is deprecated", DeprecationWarning)
yield


Expand Down
Empty file.
55 changes: 55 additions & 0 deletions tests/models/self_supervised/unit/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import pytest
import torch
from PIL import Image

from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform,
SimCLRFinetuneTransform,
SimCLRTrainDataTransform,
)


@pytest.mark.parametrize(
    "transform_cls",
    [pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")],
)
def test_simclr_train_data_transform(catch_warnings, transform_cls):
    """Check that the train/eval SimCLR transforms emit three square tensor views."""
    view_size = 96

    # Build a random 32x32 RGB PIL image to push through the pipeline.
    pixels = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
    sample = Image.fromarray(pixels)

    outputs = transform_cls(input_height=view_size)(sample)

    # The transform returns a sequence of exactly three views:
    # two augmented views plus the view fed to the online evaluator.
    assert isinstance(outputs, (list, tuple))
    assert len(outputs) == 3

    for view in outputs:
        # Every view is a tensor resized to the requested square shape.
        assert torch.is_tensor(view)
        assert view.size(1) == view.size(2) == view_size


def test_simclr_finetune_transform(catch_warnings):
    """Check that the SimCLR fine-tuning transform emits one square tensor view."""
    view_size = 96

    # Build a random 32x32 RGB PIL image to push through the pipeline.
    pixels = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
    sample = Image.fromarray(pixels)

    output = SimCLRFinetuneTransform(input_height=view_size)(sample)

    # Unlike the pre-training transforms, fine-tuning yields a single view,
    # resized to the requested square shape.
    assert torch.is_tensor(output)
    assert output.size(1) == output.size(2) == view_size

0 comments on commit 2059bb0

Please sign in to comment.