This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathotter.py
85 lines (65 loc) · 2.74 KB
/
otter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import urllib.request
from PIL import Image
import torch
from torchvision import transforms
from models.model_factory import build_image_text_model
# Download URLs of the pretrained checkpoints (ResNet50 + DeCLUTR-Sci-base
# model weights). The keys are the model names accepted by `load`.
_MODELS = {
    "InfoNCE": "https://onedrive.live.com/download?cid=CDD071074C65025E&resid=CDD071074C65025E%21872&authkey=APyQBYUHUiU2voc",
    "LS": "https://onedrive.live.com/download?cid=CDD071074C65025E&resid=CDD071074C65025E%21874&authkey=AHo06qdL59Mx39M",
    "KD": "https://onedrive.live.com/download?cid=CDD071074C65025E&resid=CDD071074C65025E%21873&authkey=AAETynOUaaHM7jQ",
    "OTTER": "https://onedrive.live.com/download?cid=CDD071074C65025E&resid=CDD071074C65025E%21871&authkey=ANIpSxwJ3x9MAao",
}
def load(name, pretrained=True):
    """Build the image-text model and optionally load pretrained weights.

    The model is always constructed with a "resnet50" image encoder and a
    "declutr-sci-base" text encoder. When ``pretrained`` is true, the
    checkpoint for ``name`` is read from ``./pretrained/{name}.pth.tar``
    (relative to the current working directory), downloading it first from
    the URL registered in ``_MODELS`` if the file does not exist yet.

    Args:
        name: Key of ``_MODELS`` selecting which checkpoint to use.
        pretrained: If true, download (when needed) and load the checkpoint's
            ``"state_dict"`` into the model; if false, return the freshly
            built model without loading any weights.

    Returns:
        A ``(model, transform)`` tuple, where ``transform`` is the image
        preprocessing pipeline produced by ``_transform()``.

    Raises:
        AssertionError: If ``name`` is not a key of ``_MODELS``.
    """
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; the exception type and message are the
    # same as before for callers that rely on them.
    if name not in _MODELS:
        raise AssertionError(f"Model name must be in {list(_MODELS)}.")

    model = build_image_text_model(
        "resnet50",
        "declutr-sci-base",
        embedding_dim=768,
        max_token_length=60,
        label_smoothing=False,
        pretrain=True,
        lock_image=False
    )

    if pretrained:
        url = _MODELS[name]
        path = f'./pretrained/{name}.pth.tar'
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs('./pretrained/', exist_ok=True)

        # Download checkpoint
        if not os.path.exists(path):
            print("Downloading model to ./pretrained")
            # Download to a temporary name and move it into place only once
            # complete, so an interrupted download never leaves a truncated
            # file at `path` that later calls would mistake for a checkpoint.
            partial_path = path + ".part"
            try:
                urllib.request.urlretrieve(url, partial_path)
                os.replace(partial_path, path)
            finally:
                # Only still present if the download or the move failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print("Downloaded")

        # map_location="cpu" lets checkpoints saved from CUDA tensors load on
        # CPU-only machines; callers move the model to their device afterwards.
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only the hard-coded URLs in _MODELS are fetched
        # here; do not point this at untrusted checkpoints.
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])

    return model, _transform()
def _convert_image_to_rgb(image):
    """Return the result of ``image.convert("RGB")``.

    Used as a step inside the preprocessing pipeline built by ``_transform``,
    so it takes and returns a single image object.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def _transform():
    """Build the image preprocessing pipeline returned alongside the model.

    The composed steps are, in order: resize with size 256 using bicubic
    interpolation, center crop to 224, conversion to RGB, conversion to a
    tensor, and per-channel normalization with fixed mean/std values.

    Returns:
        A ``transforms.Compose`` instance applying the steps above.
    """
    # Per-channel normalization statistics (three channels).
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        transforms.Resize(256, interpolation=Image.BICUBIC),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
if __name__ == "__main__":
    # Demo: score one image against a few candidate captions using the
    # pretrained "InfoNCE" checkpoint and print the softmax probabilities.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    temperature = 60

    model, preprocess = load("InfoNCE")
    model = model.to(device)

    # Preprocess the image and add a leading batch dimension of size 1.
    image = Image.open("doge.jpg")
    image = preprocess(image)
    image = image.unsqueeze(0).to(device)

    texts = ['photo of a dog', 'photo of a sofa', 'photo of a flower']

    with torch.no_grad():
        features = model.forward_features(image, texts)
        image_logits, text_logits = model.compute_logits(features)

        # Scale the logits in place by the temperature before the softmax.
        image_logits.mul_(temperature)
        probs = image_logits.softmax(dim=-1).cpu().numpy()

    # Sample output recorded by the original author:
    # Probs: [[0.92657197 0.00180788 0.07162025]]
    print("Probs:", probs)