Getting the performance of (CLIP w/ RN-50) and (CLIP w/ EN-B5) #13

Closed
yuzhimanhua opened this issue Aug 29, 2024 · 7 comments

@yuzhimanhua

Hello, thank you for the great work and for publishing the code!

I wonder if there is a fast solution to using your released codebase to reproduce the result of (CLIP w/ RN-50) and (CLIP w/ EN-B5). For example, for image classification, can I directly take an RN-50 or EN-B5 checkpoint (e.g., pre-trained on ImageNet) and fine-tune it with VinDr/RSNA training sets to get the performance? Thanks!

@shantanu-ai
Member

shantanu-ai commented Aug 29, 2024

Hi,
Thanks for taking an interest in our work. The baselines are in a messy state, but what you suggested should work. A couple of points:

  1. Fine-tuning on RSNA and VinDr with ImageNet-pretrained models performs extremely poorly, as those models have not learned breast-specific representations.
  2. In the paper, we build the baselines by optimizing an image-text contrastive loss on top of ImageNet-pretrained vision encoders, aligning the vision and text representations with our in-house UPMC dataset. You need this alignment to get good results. For pretraining the baselines we use only the image+text dataset; no image+label dataset is used.
  3. Also, use the same augmentations and the csv file mentioned in the codebase.

This is the contrastive loss we optimize for the baselines:

import torch
import torch.nn as nn
from breastclip import util
from torch.nn import functional as F

all_gather_func = util.DistAutogradAllGatherFunction(partial=False)


def all_gather(tensor):
    world_size = util.GlobalEnv.get().world_size
    if world_size > 1:
        tensor_list = all_gather_func.apply(tensor)
        all_tensor = torch.cat(tensor_list, 0)
    else:
        all_tensor = tensor
    return all_tensor


class MammoClipBaseline_contrastive(nn.Module):
    def __init__(self, label_smoothing=0.0, i2i_weight=0.0, t2t_weight=0.0, loss_ratio=1.0):
        super().__init__()
        self.name = "contrastive"
        self.label_smoothing = label_smoothing
        self.loss_ratio = loss_ratio
        self.i2i_weight = i2i_weight
        self.t2t_weight = t2t_weight

    def forward(self, image_embeddings, text_embeddings, labels, logit_scale, is_train, **kwargs):
        world_rank = util.GlobalEnv.get().world_rank
        batch_size = labels.size(0)

        all_image_embeddings = all_gather(image_embeddings)
        all_text_embeddings = all_gather(text_embeddings)

        with torch.no_grad():
            labels = labels + (world_rank * batch_size)

        loss_i2t = 0
        loss_t2i = 0

        # I1 - T1
        logits_per_image = logit_scale * image_embeddings @ all_text_embeddings.T
        logits_per_text = logit_scale * text_embeddings @ all_image_embeddings.T

        label_smoothing = self.label_smoothing if is_train else 0.0
        loss_i2t += F.cross_entropy(logits_per_image, labels, label_smoothing=label_smoothing)
        loss_t2i += F.cross_entropy(logits_per_text, labels, label_smoothing=label_smoothing)

        if is_train:
            util.GlobalEnv.get().summary_writer.train.add_scalar(
                "loss/contrastive/steps_i2t", loss_i2t, util.GlobalEnv.get().summary_writer.global_step
            )
            util.GlobalEnv.get().summary_writer.train.add_scalar(
                "loss/contrastive/steps_t2i", loss_t2i, util.GlobalEnv.get().summary_writer.global_step
            )

        # contrastive loss
        loss = 0.75 * loss_i2t + 0.25 * loss_t2i  # weighted sum; already a scalar (cross_entropy reduces by default)
        return loss.mean()
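
For a quick sanity check, here is a minimal single-GPU sketch of the same symmetric objective without the distributed all-gather and TensorBoard plumbing (the function name and default weights are illustrative, not from the codebase):

import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_embeddings, text_embeddings, logit_scale,
                          i2t_weight=0.75, t2i_weight=0.25, label_smoothing=0.0):
    # Both embedding tensors are assumed to be L2-normalized, shape (batch_size, dim).
    logits_per_image = logit_scale * image_embeddings @ text_embeddings.T
    logits_per_text = logits_per_image.T
    # Positive pairs sit on the diagonal, so the targets are simply 0..batch_size-1.
    targets = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
    loss_i2t = F.cross_entropy(logits_per_image, targets, label_smoothing=label_smoothing)
    loss_t2i = F.cross_entropy(logits_per_text, targets, label_smoothing=label_smoothing)
    return i2t_weight * loss_i2t + t2i_weight * loss_t2i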

@shantanu-ai
Member

Also, this is the baseline model:

import logging
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

# load_image_encoder / load_text_encoder / load_projection_head are helpers from the
# breastclip codebase; adjust the import path to match the repository layout.
from breastclip.model.modules import load_image_encoder, load_text_encoder, load_projection_head

log = logging.getLogger(__name__)


class BreastClipBaseline(nn.Module):
    def __init__(self, model_config: Dict, all_loss_config: Dict, tokenizer: PreTrainedTokenizer = None):
        super().__init__()
        self.tokenizer = tokenizer
        self.image_encoder = load_image_encoder(model_config["image_encoder"])
        self.text_encoder = load_text_encoder(model_config["text_encoder"], vocab_size=tokenizer.vocab_size)
        self.text_pooling = model_config["text_encoder"]["pooling"]

        self.model_config = model_config
        self.loss_config = {k: v for k, v in all_loss_config.items()}

        self.projection = "projection_head" in model_config

        if self.projection:
            self.image_projection = load_projection_head(
                embedding_dim=self.image_encoder.out_dim, config_projection_head=model_config["projection_head"]
            )
            self.text_projection = load_projection_head(
                embedding_dim=self.text_encoder.out_dim, config_projection_head=model_config["projection_head"]
            )
        else:
            assert (
                    self.image_encoder.out_dim == self.text_encoder.out_dim
            ), "Without 'projection_head', embedding_dim of the image and text encoder must be the same."

        self.temperature = model_config["temperature"] if "temperature" in model_config else None
        if self.temperature:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / self.temperature))
        else:
            self.logit_scale = torch.tensor(1, dtype=torch.float32)
            log.warning("[Breast-clip] missing temperature scaling factor")

    def encode_image(self, image):
        image_features = self.image_encoder(image)

        if self.model_config["image_encoder"]["model_type"].lower() == "cnn":
            return image_features
        else:
            # get [CLS] token for global representation (only for vision transformer)
            global_features = image_features[:, 0]
            return global_features

    def encode_image_normalized(self, image):
        img_emb = self.encode_image(image)
        img_emb = self.image_projection(img_emb) if self.projection else img_emb
        img_emb = img_emb / torch.norm(img_emb, dim=1, keepdim=True)
        return img_emb

    def encode_text(self, text_tokens):
        text_features = self.text_encoder(text_tokens)

        if self.text_pooling == "eos":
            # take features from the eot embedding (eos_token is the highest number in each sequence)
            eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
        elif self.text_pooling == "bos":
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            input_mask_expanded = text_tokens["attention_mask"].unsqueeze(axis=-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * input_mask_expanded, axis=1) / torch.clamp(
                input_mask_expanded.sum(axis=1), min=1e-9)
        else:
            raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)

        return text_features

    def forward(self, batch, device=None):
        device = batch["images"].device if device is None else device
        # get image and text features
        image_features_g = self.encode_image(batch["images"].to(device))
        text_features_g = self.encode_text(batch["text_tokens"].to(device))

        image_embeddings = self.image_projection(image_features_g) if self.projection else image_features_g
        text_embeddings = self.text_projection(text_features_g) if self.projection else text_features_g

        # normalize features
        image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

        # labels
        labels = torch.arange(image_embeddings.shape[0], device=device)

        out = {
            "image_embeddings": image_embeddings,
            "text_embeddings": text_embeddings,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }
        return out
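
In a training step, the model's output dict plugs directly into the contrastive loss above; roughly like this (a sketch only, with optimizer, scheduler, and AMP handling omitted, and `model`, `loss_fn`, and `batch` assumed to be the baseline model, the contrastive loss, and one collated batch):

# Sketch of one training step wiring the two pieces together.
out = model(batch)
loss = loss_fn(
    image_embeddings=out["image_embeddings"],
    text_embeddings=out["text_embeddings"],
    labels=out["labels"],
    logit_scale=out["logit_scale"],
    is_train=True,
)
loss.backward()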

@shantanu-ai
Member

And this is the dataset class:

import logging
import random
from pathlib import Path
from typing import Dict, List

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from breastclip.data.data_utils import load_transform
from torch.utils.data.dataset import Dataset

log = logging.getLogger(__name__)


class ImageTextDataset_contrastive(Dataset):
    def __init__(
            self,
            tokenizer,
            split: str,
            df: pd.DataFrame,
            dataset: str,
            data_dir: str,
            image_dir: str,
            text_max_length: int = 256,
            loss_config: Dict = None,
            transform_config: Dict = None,
            mean=0,
            std=0,
            image_encoder_type="swin",
            convirt_mode=True,
            **kwargs
    ):
        self.dataframe = df[["patient_id", "image_id", "FINDINGS", "IMPRESSION", "REPORT", "BIRADS_numeric", "fold"]]
        self.tokenizer = tokenizer
        self.root_dir = Path(data_dir)
        self.img_dir = image_dir
        self.dataset = dataset
        self.text_max_length = text_max_length
        self.loss_config = {k: v for k, v in loss_config.items()}
        self.tfms = load_transform(split=split, transform_config=transform_config)
        self.mean = mean
        self.std = std
        self.image_encoder_type = image_encoder_type
        self.convirt_mode = convirt_mode

        log.info(f"split: {split} transform")
        log.info(self.tfms)

    def __len__(self):
        return len(self.dataframe)

    def _get_img_path(self, study_id, image_id):
        if self.dataset.lower() == 'upmc':
            return self.root_dir / self.img_dir / f'Patient_{study_id}' / image_id
        else:
            return self.root_dir / self.img_dir / f'{str(study_id)}' / image_id

    def __getitem__(self, idx):
        study_id = str(self.dataframe.iloc[idx]['patient_id'])
        image_id = self.dataframe.iloc[idx]['image_id']
        img_path = self._get_img_path(study_id, image_id)

        if (
                self.image_encoder_type == "swin" or
                self.image_encoder_type == "resnet101" or
                self.image_encoder_type == "resnet152" or
                self.image_encoder_type == "tf_efficientnet_b5_ns-detect" or
                self.image_encoder_type == "tf_efficientnetv2-detect"
        ):
            img = Image.open(img_path).convert('RGB')
            img = np.array(img)
        else:
            img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)

        if self.tfms:
            augmented = self.tfms(image=img)
            img = augmented['image']
        img = img.astype('float32')
        img -= img.min()
        img /= img.max()
        img = torch.tensor((img - self.mean) / self.std, dtype=torch.float32)
        img = img.unsqueeze(0)

        content = self.dataframe.iloc[idx]['REPORT']
        if self.convirt_mode:
            content = content.replace("\n", " ")
            ls_text = content.split(".")
            if '' in ls_text:
                ls_text.remove('')

            text = random.choice(ls_text)
        else:
            text = content
        label = torch.tensor(self.dataframe.iloc[idx]['BIRADS_numeric'], dtype=torch.long)
        results = {"image": img, "text": text, "label": label}
        return results

    def collate_fn(self, instances: List):
        images = torch.stack([ins["image"] for ins in instances], dim=0)
        texts = list([ins["text"] for ins in instances])
        text_tokens = self.tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt",
                                     max_length=self.text_max_length)
        labels = torch.stack([ins["label"] for ins in instances], dim=0)

        batch = {
            "images": images,
            "texts": texts,
            "labels": labels,
            "text_tokens": text_tokens,
        }

        return batch
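
If it helps, a rough sketch of how this dataset would be wrapped in a DataLoader (the paths and batch size are placeholders, and `df`, `tokenizer`, `loss_config`, and `transform_config` are assumed to come from the csv file, the Hugging Face tokenizer, and the Hydra configs):

from torch.utils.data import DataLoader

train_ds = ImageTextDataset_contrastive(
    tokenizer=tokenizer,
    split="train",
    df=df,
    dataset="upmc",
    data_dir="/path/to/data",      # placeholder
    image_dir="images",            # placeholder
    text_max_length=256,
    loss_config=loss_config,
    transform_config=transform_config,
    mean=0.3089279,
    std=0.25053555408335154,
    image_encoder_type="tf_efficientnet_b5_ns-detect",
)

train_loader = DataLoader(
    train_ds,
    batch_size=8,                    # placeholder; use the dataloader config
    shuffle=True,
    num_workers=4,
    collate_fn=train_ds.collate_fn,  # needed so the reports are tokenized per batch
)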

shantanu-ai pinned this issue Aug 29, 2024
@shantanu-ai
Member

This is the model config yaml file:

name: "clip_custom"
temperature: 0.07

image_encoder:
  source: "cnn" # one of { "huggingface"}
  name: 'tf_efficientnet_b5_ns-detect'
  pretrained: true
  model_type: 'cnn'

text_encoder:
  source: "huggingface" # one of { "huggingface"}
  name: emilyalsentzer/Bio_ClinicalBERT
  pretrained: true
  gradient_checkpointing: false
  pooling: "eos" # one of { "eos" | "bos" | "mean" }
  cache_dir: "/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/huggingface/"
  trust_remote_code: true
  mlm_head: true

projection_head: # optional
  name: "linear" # one of { "linear" | "mlp" }
  dropout: 0.1
  proj_dim: 512
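
As a side note, the `temperature: 0.07` entry is what initializes the learnable logit scale in the model class above; a tiny standalone check:

import numpy as np
import torch
import torch.nn as nn

temperature = 0.07
logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
print(logit_scale.exp().item())  # ~14.29, the initial scale applied to the cosine similarities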

@shantanu-ai
Member

And this is the overall training config file:

defaults:
  - _self_
  - data_train:
      - upmc_convirt
  - data_zs:
      - upmc_zs
  - dataloader: dataloader_b5
  - tokenizer: clinical_bert
  - transform: clahe
  - model: clip_b5_det_clinical
  - optimizer: adamw
  - scheduler: cosine_epoch15_warmup1
  - loss: breast_clip_contrastive

base:
  period: "n"
  resume_training: False
  epoch_to_start: 0
  checkpoint_to_start: ""
  train_fast: False
  fold: 0
  seed: 10
  amp: True
  mean: 0.3089279
  std: 0.25053555408335154
  image_size_h: 1520
  image_size_w: 912
  text_max_length: 256
  loss_best: contrastive
  data_frac: 1.0
  zs_prompts:
    upmc:
      - "birads category 0"
      - "birads category 1"
      - "birads category 2"
    rsna:
      - "no malignancy"
      - "malignancy"
  output:
    args_path: ${hydra:run.dir}
    checkpoint: ${hydra:run.dir}/checkpoints/
    tensorboard: ${hydra:run.dir}/tensorboard/

hydra:
  run:
    dir: /restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_baseline_period_${base.period}
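
For completeness, a hypothetical Hydra entry point that composes this config (the `config_path` / `config_name` values are placeholders; the actual training script in the repo may differ):

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base=None, config_path="configs", config_name="train_b5_baseline")
def main(cfg: DictConfig) -> None:
    # Defaults (model, loss, dataloader, ...) are already resolved by Hydra here.
    print(OmegaConf.to_yaml(cfg))
    print(cfg.base.image_size_h, cfg.base.image_size_w)  # 1520 912


if __name__ == "__main__":
    main()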

@shantanu-ai
Member

Let me know if you have further issues. If not, let me know if I can close the issue.

@yuzhimanhua
Author

Thank you so much for your quick and detailed response! I will try your code to build the baseline!
