Experiments with PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

goamegah/PyTorch-SimCLR

PyTorch SimCLR experiments

Figure: An illustration of SimCLR (credit: google-research SimCLR repository).

Data Constraints

We assume that only 100 labeled samples are available to train a deep learning architecture for image classification. The dataset is MNIST (handwritten digits), so the labeled training set consists of 100 images and their 100 labels.
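As a rough illustration of this constraint, the sketch below draws a 100-sample labeled subset from a list of labels. The helper name `sample_labeled_subset`, the class-balanced split (10 images per digit), and the seed are assumptions made for the example; the repository may select its 100 samples differently.

```python
import random
from collections import defaultdict


def sample_labeled_subset(labels, per_class=10, seed=0):
    """Return sorted indices of a class-balanced subset (hypothetical helper).

    Assumption: 10 samples for each of the 10 digits gives the 100 labeled
    images; the source does not say whether the subset is balanced.
    """
    rng = random.Random(seed)  # local RNG so the choice is reproducible

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Draw `per_class` indices without replacement from every class.
    chosen = []
    for label in sorted(by_class):
        chosen.extend(rng.sample(by_class[label], per_class))
    return sorted(chosen)


if __name__ == "__main__":
    # Stand-in for the 60,000 MNIST training labels.
    fake_labels = [i % 10 for i in range(60000)]
    subset = sample_labeled_subset(fake_labels)
    print(len(subset))
```

With real data, the returned indices could be passed to `torch.utils.data.Subset` to wrap the MNIST training set.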

How to Run Models

To run a model, use the command that corresponds to it below.

LeNet5

$ python run_lenet.py --mode train --train-epochs 100 

Let's break down the available flags:

  • -m, --mode: mode in which to run the model (train or eval)
  • -data: path where the dataset is stored or downloaded
  • -dn, --dataset-name: dataset to use (default: MNIST)
  • -a, --arch: architecture used as the baseline
  • -b, --batch-size: training batch size
  • -eval-batch-size: batch size used in eval mode
  • --lr, --learning-rate: learning rate
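The flags above map naturally onto an `argparse` parser. The following is a minimal sketch of how such a command line could be declared, not the repository's actual `run_lenet.py`. Every default except the MNIST dataset name is a placeholder chosen for illustration, and `--train-epochs` is included only because the command shown above uses it.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented flags (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="Supervised baseline (sketch)")
    parser.add_argument("-m", "--mode", choices=["train", "eval"], default="train",
                        help="mode in which to run the model")
    # Single-dash long options are written exactly as documented.
    parser.add_argument("-data", default="./datasets",  # placeholder default
                        help="path where the dataset is stored or downloaded")
    parser.add_argument("-dn", "--dataset-name", default="MNIST",
                        help="dataset to use")
    parser.add_argument("-a", "--arch", default="lenet5",  # placeholder default
                        help="architecture used as the baseline")
    parser.add_argument("-b", "--batch-size", type=int, default=32,  # placeholder
                        help="training batch size")
    parser.add_argument("-eval-batch-size", type=int, default=256,  # placeholder
                        help="batch size used in eval mode")
    parser.add_argument("--lr", "--learning-rate", dest="learning_rate",
                        type=float, default=1e-3,  # placeholder default
                        help="learning rate")
    parser.add_argument("--train-epochs", type=int, default=100,
                        help="number of training epochs")
    return parser


if __name__ == "__main__":
    print(build_parser().parse_args())
```

Parsing `["--mode", "train", "--train-epochs", "100"]` with this parser yields a namespace whose `mode` is `"train"` and whose `train_epochs` is the integer `100`.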

ResNet-18

$ python run_resnet.py --mode train --train-epochs 100 

Let's break down the available flags:

  • -m, --mode: mode in which to run the model (train or eval)
  • -data: path where the dataset is stored or downloaded
  • -dn, --dataset-name: dataset to use (default: MNIST)
  • -a, --arch: architecture used as the baseline
  • -b, --batch-size: training batch size
  • -eval-batch-size: batch size used in eval mode
  • --lr, --learning-rate: learning rate

SimCLR-Resnet18

$ python run.py --mode train --train-mode finetune --train-epochs 10

Let's break down the available flags:

  • -m, --mode: mode in which to run the model (train or eval)
  • -tm, --train-mode: type of training (finetune or pretrain)
    • pretrain: train the contrastive layer
    • finetune: train the classifier layer while keeping the pretrained backbone (ResNet-18) frozen
  • -j, --workers: number of data loading workers
  • -te, --train-epochs: total number of training epochs
  • -ee, --eval-epochs: total number of evaluation epochs
  • -wd, --weight-decay: weight decay (default: 1e-4)
  • -s, --seed: random seed
  • --out-dim: output dimension of the projection head
  • --temperature: temperature of the contrastive loss
  • --data: path where the dataset is stored or downloaded
  • -dn, --dataset-name: dataset to use (default: MNIST)
  • -a, --arch: architecture used as the baseline
  • -b, --batch-size: training batch size
  • --eval-batch-size: batch size used in eval mode
  • -lr, --learning-rate: learning rate
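To make the role of `--temperature` concrete, here is a small, dependency-free sketch of the NT-Xent contrastive loss described in the SimCLR paper. It is written in plain Python for readability and is not the repository's implementation, which presumably uses batched PyTorch tensors. The function name `nt_xent_loss`, the convention that `z[i]` and `z[i + N]` are two views of the same image, and the default temperature of 0.5 are assumptions of this example.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def nt_xent_loss(z, temperature=0.5):
    """NT-Xent loss over 2N embeddings (illustrative sketch).

    Assumed layout: z[i] and z[i + N] are the two augmented views of image i.
    The default temperature is a placeholder, not the repository's default.
    """
    total = len(z)
    assert total % 2 == 0, "expected two views per image"
    n = total // 2

    # L2-normalize so that dot products are cosine similarities.
    normed = []
    for v in z:
        norm = math.sqrt(_dot(v, v))
        normed.append([x / norm for x in v])

    loss = 0.0
    for i in range(total):
        pos = (i + n) % total  # index of the other view of the same image
        # Temperature-scaled similarities to every embedding except itself.
        logits = [_dot(normed[i], normed[k]) / temperature
                  for k in range(total) if k != i]
        pos_logit = _dot(normed[i], normed[pos]) / temperature
        # Cross-entropy with the positive pair as the target (stable log-sum-exp).
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
        loss += log_sum_exp - pos_logit
    return loss / total


if __name__ == "__main__":
    aligned = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
    print(nt_xent_loss(aligned, temperature=0.5))
```

A lower temperature sharpens the softmax over similarities, so hard negatives contribute more to the loss. The dimension of each embedding in `z` corresponds to what `--out-dim` controls.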

Feature Evaluation

Feature evaluation follows the linear evaluation protocol.

First, we learn features with SimCLR on MNIST in an unsupervised manner (without labels). Then, we train a linear classifier on top of the frozen SimCLR features. The linear model is trained on features extracted from the MNIST train set and evaluated on the MNIST test set.
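The protocol can be sketched as follows: the encoder is treated as a fixed feature extractor, and only a linear (softmax) classifier is fitted on its outputs. The example below is a toy, plain-Python version trained with stochastic gradient descent on hand-written 2-D features; the function names, learning rate, and epoch count are illustrative assumptions, and in practice the features would come from the frozen ResNet-18 backbone.

```python
import math


def train_linear_probe(features, labels, num_classes, lr=0.5, epochs=200):
    """Fit a softmax linear classifier on fixed features (illustrative sketch)."""
    dim = len(features[0])
    weights = [[0.0] * dim for _ in range(num_classes)]
    biases = [0.0] * num_classes
    for _ in range(epochs):
        for x, y in zip(features, labels):
            logits = [sum(w * xi for w, xi in zip(weights[c], x)) + biases[c]
                      for c in range(num_classes)]
            m = max(logits)
            exps = [math.exp(l - m) for l in logits]
            norm = sum(exps)
            # Gradient of cross-entropy w.r.t. the logits: softmax minus one-hot.
            for c in range(num_classes):
                grad = exps[c] / norm - (1.0 if c == y else 0.0)
                for j in range(dim):
                    weights[c][j] -= lr * grad * x[j]
                biases[c] -= lr * grad
    return weights, biases


def predict(weights, biases, x):
    """Return the class with the highest linear score."""
    scores = [sum(w * xi for w, xi in zip(row, x)) + b
              for row, b in zip(weights, biases)]
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(weights, biases, features, labels):
    hits = sum(predict(weights, biases, x) == y for x, y in zip(features, labels))
    return hits / len(labels)


if __name__ == "__main__":
    # Stand-ins for frozen-encoder outputs on a train and a test split.
    train_x = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
    train_y = [0, 0, 1, 1]
    w, b = train_linear_probe(train_x, train_y, num_classes=2)
    print(accuracy(w, b, [[0.8, 0.2], [0.2, 0.8]], [0, 1]))
```

Because the encoder weights never change, the resulting accuracy measures how linearly separable the learned representation is, rather than how well the whole network can be fitted to the labels.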

See the Colab notebook to reproduce these results.

Method               Architecture  Accuracy (%)
Supervised baseline  LeNet-5       73.73
Supervised baseline  ResNet-18     73.26
SimCLR               ResNet-18     93.84

Models are trained for 100 epochs.

Tools and Libraries Used

  • numpy >= 1.24.3 (The fundamental package for scientific computing with Python)
  • scipy >= 1.10.1 (Additional functions for NumPy)
  • pandas >= 2.0.2 (A data frame library)
  • matplotlib >= 3.7.1 (A plotting library)
  • jupyterlab >= 4.0 (An application for running Jupyter notebooks)
  • ipywidgets >= 8.0.6 (Fixes progress bar issues in Jupyter Lab)
  • scikit-learn >= 1.2.2 (A general machine learning library)
  • watermark >= 2.4.2 (An IPython/Jupyter extension for printing package information)
  • torch >= 2.0.1 (The PyTorch deep learning library)
  • torchvision >= 0.15.2 (PyTorch utilities for computer vision)
  • torchmetrics >= 0.11.4 (Metrics for PyTorch)
  • wandb >= 0.17.9 (Experiment tracking and model monitoring)

[OPTIONAL PACKAGES]

  • TensorboardX
  • wandb
  • boto3

To install these requirements most conveniently, you can use the requirements.txt file:

pip install -r requirements.txt

After the installation completes, check that all packages are installed and up to date:

python python_environment_check.py
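For readers curious about what such a check involves, the sketch below compares installed package versions against a few of the minimums listed above using the standard library's `importlib.metadata`. It is an assumption-laden illustration, not the contents of `python_environment_check.py`; the helper names and the simplified version parsing (first three numeric components only) are choices made for this example.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile

# A few minimum versions copied from the list above.
MIN_VERSIONS = {"numpy": "1.24.3", "torch": "2.0.1", "torchvision": "0.15.2"}


def as_tuple(version_string):
    """Turn '2.0.1+cu118' into (2, 0, 1); simplified, ignores pre-release tags."""
    parts = []
    for piece in version_string.split("+")[0].split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad so that '4.0' compares equal to '4.0.0'
        parts.append(0)
    return tuple(parts)


def check(min_versions):
    """Map each package to 'ok', 'outdated', or 'missing'."""
    report = {}
    for name, minimum in min_versions.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if as_tuple(installed) >= as_tuple(minimum) else "outdated"
    return report


if __name__ == "__main__":
    for package, status in check(MIN_VERSIONS).items():
        print(f"{package}: {status}")
```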

Additional Installation (Optional)

Fast data loading feedback on Tensorboard (Source: tensorflow/tensorboard#4784)

$ pip uninstall -y tensorboard tb-nightly
$ pip install tb-nightly  # must have at least tb-nightly==2.5.0a20210316
