Getting the performance of (CLIP w/ RN-50) and (CLIP w/ EN-B5) #13
Hi,

This is the contrastive loss we optimize for the baselines:

```python
import torch
import torch.nn as nn
from breastclip import util
from torch.nn import functional as F

all_gather_func = util.DistAutogradAllGatherFunction(partial=False)


def all_gather(tensor):
    world_size = util.GlobalEnv.get().world_size
    if world_size > 1:
        tensor_list = all_gather_func.apply(tensor)
        all_tensor = torch.cat(tensor_list, 0)
    else:
        all_tensor = tensor
    return all_tensor


class MammoClipBaseline_contrastive(nn.Module):
    def __init__(self, label_smoothing=0.0, i2i_weight=0.0, t2t_weight=0.0, loss_ratio=1.0):
        super(MammoClipBaseline_contrastive, self).__init__()
        self.name = "contrastive"
        self.label_smoothing = label_smoothing
        self.loss_ratio = loss_ratio
        self.i2i_weight = i2i_weight
        self.t2t_weight = t2t_weight

    def forward(self, image_embeddings, text_embeddings, labels, logit_scale, is_train, **kwargs):
        world_rank = util.GlobalEnv.get().world_rank
        batch_size = labels.size(0)

        all_image_embeddings = all_gather(image_embeddings)
        all_text_embeddings = all_gather(text_embeddings)

        with torch.no_grad():
            labels = labels + (world_rank * batch_size)

        loss_i2t = 0
        loss_t2i = 0

        # I1 - T1
        logits_per_image = logit_scale * image_embeddings @ all_text_embeddings.T
        logits_per_text = logit_scale * text_embeddings @ all_image_embeddings.T
        label_smoothing = self.label_smoothing if is_train else 0.0
        loss_i2t += F.cross_entropy(logits_per_image, labels, label_smoothing=label_smoothing)
        loss_t2i += F.cross_entropy(logits_per_text, labels, label_smoothing=label_smoothing)

        if is_train:
            util.GlobalEnv.get().summary_writer.train.add_scalar(
                "loss/contrastive/steps_i2t", loss_i2t, util.GlobalEnv.get().summary_writer.global_step
            )
            util.GlobalEnv.get().summary_writer.train.add_scalar(
                "loss/contrastive/steps_t2i", loss_t2i, util.GlobalEnv.get().summary_writer.global_step
            )

        # contrastive loss: weighted combination of the two directions
        # (each cross_entropy call already reduces to a scalar, so .mean() is a no-op)
        loss = 0.75 * loss_i2t + 0.25 * loss_t2i
        return loss.mean()
```
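For readers without the distributed utilities set up, here is a minimal standalone sketch of the same weighted InfoNCE computation on a single GPU (illustrative only; the repo version above adds all-gather across ranks and TensorBoard logging, and the shapes below are toy values):

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration only; real embeddings come from the model below.
batch_size, dim = 8, 512
image_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)
text_embeddings = F.normalize(torch.randn(batch_size, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # temperature 0.07, as in the model config
labels = torch.arange(batch_size)     # the i-th image is paired with the i-th text

logits_per_image = logit_scale * image_embeddings @ text_embeddings.T
logits_per_text = logit_scale * text_embeddings @ image_embeddings.T
loss_i2t = F.cross_entropy(logits_per_image, labels)
loss_t2i = F.cross_entropy(logits_per_text, labels)
loss = 0.75 * loss_i2t + 0.25 * loss_t2i  # same 0.75 / 0.25 weighting as above
print(loss.item())
```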
Also, this is the baseline model:

```python
import logging
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

# load_image_encoder, load_text_encoder and load_projection_head are the
# encoder/projection factory functions from the breastclip codebase.

log = logging.getLogger(__name__)


class BreastClipBaseline(nn.Module):
    def __init__(self, model_config: Dict, all_loss_config: Dict, tokenizer: PreTrainedTokenizer = None):
        super().__init__()
        self.tokenizer = tokenizer
        self.image_encoder = load_image_encoder(model_config["image_encoder"])
        self.text_encoder = load_text_encoder(model_config["text_encoder"], vocab_size=tokenizer.vocab_size)
        self.text_pooling = model_config["text_encoder"]["pooling"]
        self.model_config = model_config
        self.loss_config = {k: v for k, v in all_loss_config.items()}

        self.projection = "projection_head" in model_config
        if self.projection:
            self.image_projection = load_projection_head(
                embedding_dim=self.image_encoder.out_dim, config_projection_head=model_config["projection_head"]
            )
            self.text_projection = load_projection_head(
                embedding_dim=self.text_encoder.out_dim, config_projection_head=model_config["projection_head"]
            )
        else:
            assert (
                self.image_encoder.out_dim == self.text_encoder.out_dim
            ), "Without 'projection_head', embedding_dim of the image and text encoder must be the same."

        self.temperature = model_config["temperature"] if "temperature" in model_config else None
        if self.temperature:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / self.temperature))
        else:
            self.logit_scale = torch.tensor(1, dtype=torch.float32)
            log.warning("[Breast-clip] missing temperature scaling factor")

    def encode_image(self, image):
        image_features = self.image_encoder(image)
        if self.model_config["image_encoder"]["model_type"].lower() == "cnn":
            return image_features
        else:
            # get [CLS] token for global representation (only for vision transformer)
            global_features = image_features[:, 0]
            return global_features

    def encode_image_normalized(self, image):
        img_emb = self.encode_image(image)
        img_emb = self.image_projection(img_emb) if self.projection else img_emb
        img_emb = img_emb / torch.norm(img_emb, dim=1, keepdim=True)
        return img_emb

    def encode_text(self, text_tokens):
        text_features = self.text_encoder(text_tokens)
        if self.text_pooling == "eos":
            # take features from the eot embedding (eos_token is the highest number in each sequence)
            eos_token_indices = text_tokens["attention_mask"].sum(dim=-1) - 1
            text_features = text_features[torch.arange(text_features.shape[0]), eos_token_indices]
        elif self.text_pooling == "bos":
            text_features = text_features[:, 0]
        elif self.text_pooling == "mean":
            input_mask_expanded = text_tokens["attention_mask"].unsqueeze(axis=-1).expand(text_features.size()).float()
            text_features = torch.sum(text_features * input_mask_expanded, axis=1) / torch.clamp(
                input_mask_expanded.sum(axis=1), min=1e-9)
        else:
            raise NotImplementedError("Not supported pooling method : %s", self.text_pooling)
        return text_features

    def forward(self, batch, device=None):
        device = batch["images"].device if device is None else device

        # get image and text features
        image_features_g = self.encode_image(batch["images"].to(device))
        text_features_g = self.encode_text(batch["text_tokens"].to(device))
        image_embeddings = self.image_projection(image_features_g) if self.projection else image_features_g
        text_embeddings = self.text_projection(text_features_g) if self.projection else text_features_g

        # normalize features
        image_embeddings = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

        # labels
        labels = torch.arange(image_embeddings.shape[0], device=device)

        out = {
            "image_embeddings": image_embeddings,
            "text_embeddings": text_embeddings,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }
        return out
```
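Tying the two snippets together, a hypothetical training step would look roughly like this (assuming `util.GlobalEnv` and the summary writer have already been initialized by the training script; `model`, `optimizer` and `batch` are placeholders built elsewhere):

```python
# Hypothetical wiring of the baseline model and the contrastive loss above.
criterion = MammoClipBaseline_contrastive(label_smoothing=0.0)

out = model(batch)  # batch comes from the dataset's collate_fn below
loss = criterion(
    image_embeddings=out["image_embeddings"],
    text_embeddings=out["text_embeddings"],
    labels=out["labels"],
    logit_scale=out["logit_scale"],
    is_train=True,
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```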
And this is the dataset class:

```python
import logging
import random
from pathlib import Path
from typing import Dict, List

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image

from breastclip.data.data_utils import load_transform
from torch.utils.data.dataset import Dataset

log = logging.getLogger(__name__)


class ImageTextDataset_contrastive(Dataset):
    def __init__(
            self,
            tokenizer,
            split: str,
            df: pd.DataFrame,
            dataset: str,
            data_dir: str,
            image_dir: str,
            text_max_length: int = 256,
            loss_config: Dict = None,
            transform_config: Dict = None,
            mean=0,
            std=0,
            image_encoder_type="swin",
            convirt_mode=True,
            **kwargs
    ):
        self.dataframe = df[["patient_id", "image_id", "FINDINGS", "IMPRESSION", "REPORT", "BIRADS_numeric", "fold"]]
        self.tokenizer = tokenizer
        self.root_dir = Path(data_dir)
        self.img_dir = image_dir
        self.dataset = dataset
        self.text_max_length = text_max_length
        self.loss_config = {k: v for k, v in loss_config.items()}
        self.tfms = load_transform(split=split, transform_config=transform_config)
        self.mean = mean
        self.std = std
        self.image_encoder_type = image_encoder_type
        self.convirt_mode = convirt_mode

        log.info(f"split: {split} transform")
        log.info(self.tfms)

    def __len__(self):
        return len(self.dataframe)

    def _get_img_path(self, study_id, image_id):
        if self.dataset.lower() == 'upmc':
            return self.root_dir / self.img_dir / f'Patient_{study_id}' / image_id
        else:
            return self.root_dir / self.img_dir / f'{str(study_id)}' / image_id

    def __getitem__(self, idx):
        study_id = str(self.dataframe.iloc[idx]['patient_id'])
        image_id = self.dataframe.iloc[idx]['image_id']
        img_path = self._get_img_path(study_id, image_id)

        if (
                self.image_encoder_type == "swin" or
                self.image_encoder_type == "resnet101" or
                self.image_encoder_type == "resnet152" or
                self.image_encoder_type == "tf_efficientnet_b5_ns-detect" or
                self.image_encoder_type == "tf_efficientnetv2-detect"
        ):
            img = Image.open(img_path).convert('RGB')
            img = np.array(img)
        else:
            img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)

        if self.tfms:
            augmented = self.tfms(image=img)
            img = augmented['image']
        img = img.astype('float32')
        img -= img.min()
        img /= img.max()
        img = torch.tensor((img - self.mean) / self.std, dtype=torch.float32)
        img = img.unsqueeze(0)

        content = self.dataframe.iloc[idx]['REPORT']
        if self.convirt_mode:
            content = content.replace("\n", " ")
            ls_text = content.split(".")
            if '' in ls_text:
                ls_text.remove('')
            text = random.choice(ls_text)
        else:
            text = content

        label = torch.tensor(self.dataframe.iloc[idx]['BIRADS_numeric'], dtype=torch.long)
        results = {"image": img, "text": text, "label": label}
        return results

    def collate_fn(self, instances: List):
        images = torch.stack([ins["image"] for ins in instances], dim=0)
        texts = list([ins["text"] for ins in instances])
        text_tokens = self.tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt",
                                     max_length=self.text_max_length)
        labels = torch.stack([ins["label"] for ins in instances], dim=0)

        batch = {
            "images": images,
            "texts": texts,
            "labels": labels,
            "text_tokens": text_tokens,
        }
        return batch
```
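For reference, a hedged sketch of how this dataset could be plugged into a standard DataLoader (the tokenizer name follows the config below; `train_df`, the paths and `transform_config` are placeholders, and `loss_config` is passed as an empty dict because the constructor iterates over it):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Placeholder dataframe/paths; train_df must contain the columns selected in __init__
# (patient_id, image_id, FINDINGS, IMPRESSION, REPORT, BIRADS_numeric, fold).
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
train_ds = ImageTextDataset_contrastive(
    tokenizer=tokenizer,
    split="train",
    df=train_df,
    dataset="upmc",
    data_dir="/path/to/data",
    image_dir="images_png",
    loss_config={},                       # constructor calls .items() on this, so pass a dict
    transform_config=transform_config,    # e.g. the CLAHE transform from the training config
    mean=0.3089279,
    std=0.25053555408335154,
    image_encoder_type="tf_efficientnet_b5_ns-detect",
)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,
                          num_workers=4, collate_fn=train_ds.collate_fn)
batch = next(iter(train_loader))  # {"images", "texts", "labels", "text_tokens"}
```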
This is the model config YAML file:

```yaml
name: "clip_custom"
temperature: 0.07

image_encoder:
  source: "cnn"
  name: 'tf_efficientnet_b5_ns-detect'
  pretrained: true
  model_type: 'cnn'

text_encoder:
  source: "huggingface" # one of { "huggingface" }
  name: emilyalsentzer/Bio_ClinicalBERT
  pretrained: true
  gradient_checkpointing: false
  pooling: "eos" # one of { "eos" | "bos" | "mean" }
  cache_dir: "/restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/huggingface/"
  trust_remote_code: true
  mlm_head: true

projection_head: # optional
  name: "linear" # one of { "linear" | "mlp" }
  dropout: 0.1
  proj_dim: 512
```
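A hypothetical way to consume this YAML when building the baseline standalone (the config path is a placeholder; in the repo, Hydra composes the config automatically as part of the training run):

```python
import yaml
from transformers import AutoTokenizer

# Placeholder path to the YAML above.
with open("configs/model/clip_b5_det_clinical.yaml") as f:
    model_config = yaml.safe_load(f)

tokenizer = AutoTokenizer.from_pretrained(model_config["text_encoder"]["name"])
model = BreastClipBaseline(model_config=model_config, all_loss_config={}, tokenizer=tokenizer)
```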
And this is the overall training config file:

```yaml
defaults:
  - _self_
  - data_train:
      - upmc_convirt
  - data_zs:
      - upmc_zs
  - dataloader: dataloader_b5
  - tokenizer: clinical_bert
  - transform: clahe
  - model: clip_b5_det_clinical
  - optimizer: adamw
  - scheduler: cosine_epoch15_warmup1
  - loss: breast_clip_contrastive

base:
  period: "n"
  resume_training: False
  epoch_to_start: 0
  checkpoint_to_start: ""
  train_fast: False
  fold: 0
  seed: 10
  amp: True
  mean: 0.3089279
  std: 0.25053555408335154
  image_size_h: 1520
  image_size_w: 912
  text_max_length: 256
  loss_best: contrastive
  data_frac: 1.0
  zs_prompts:
    upmc:
      - "birads category 0"
      - "birads category 1"
      - "birads category 2"
    rsna:
      - "no malignancy"
      - "malignancy"

output:
  args_path: ${hydra:run.dir}
  checkpoint: ${hydra:run.dir}/checkpoints/
  tensorboard: ${hydra:run.dir}/tensorboard/

hydra:
  run:
    dir: /restricted/projectnb/batmanlab/shawn24/PhD/Breast-CLIP/src/codebase/outputs/upmc_clip/b5_baseline_period_${base.period}
```
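Since the project uses Hydra, the training entry point consumes a composed config like the one above; a minimal hypothetical skeleton (the `config_path`/`config_name` values are placeholders, and the repo's actual train script may differ):

```python
import hydra
from omegaconf import DictConfig, OmegaConf

# Placeholder config_path/config_name; point these at wherever the YAML above lives.
@hydra.main(config_path="configs", config_name="train_b5_clip")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))
    # cfg.base.seed, cfg.base.image_size_h, cfg.model, cfg.loss, ... are available here;
    # build the dataset, BreastClipBaseline and MammoClipBaseline_contrastive from cfg
    # and run the training loop.

if __name__ == "__main__":
    main()
```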
Let me know if you have further issues. If not, let me know whether I can close the issue.
Thank you so much for your quick and detailed response! I will try your code to build the baseline!
Hello, thank you for the great work and for publishing the code!
I wonder if there is a fast solution to using your released codebase to reproduce the result of (CLIP w/ RN-50) and (CLIP w/ EN-B5). For example, for image classification, can I directly take an RN-50 or EN-B5 checkpoint (e.g., pre-trained on ImageNet) and fine-tune it with VinDr/RSNA training sets to get the performance? Thanks!