process_flickr.py
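(forked from yangjianxin1/ClipCap-Chinese)

From the tab-split parsing in main() below, each line of the caption file is expected to hold an image ID and a caption separated by a single tab. A hypothetical example, where <TAB> stands for a literal tab character and the ID and text are illustrative, not taken from the real dataset:

    1000092795<TAB>Two young men are working on a bicycle outdoors.
    1000092795<TAB>A man repairs a bike while another man watches.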
import torch
import skimage.io as io
import clip
from PIL import Image
import pickle
import argparse
from tqdm import trange
from os.path import join
from loguru import logger


def main(args):
    # use the GPU if one is available, otherwise fall back to CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    clip_model, preprocess = clip.load(args.clip_model_path, device=device, jit=False)

    with open(args.caption_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    logger.info('number of caption lines: {}'.format(len(lines)))

    image_id2embed = {}  # mapping from image_id to its CLIP image embedding
    caption_list = []
    for i in trange(len(lines)):
        line = lines[i].strip()
        image_id, caption = line.split('\t')
        if image_id not in image_id2embed:
            # load the image that this caption describes
            file = join(args.image_path, '{}.jpg'.format(image_id))
            image = io.imread(file)
            # encode the image into a CLIP embedding (each image is encoded only once)
            image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
            with torch.no_grad():
                clip_embed = clip_model.encode_image(image).cpu()
            image_id2embed[image_id] = clip_embed
        caption_list.append((image_id, caption))

    # persist the (image_id, caption) pairs together with the embedding table
    with open(args.output_path, 'wb') as f:
        pickle.dump([caption_list, image_id2embed], f)
    logger.info('num of image embeddings: {}'.format(len(image_id2embed)))
    logger.info('num of captions: {}'.format(len(caption_list)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_model_path', default="pretrain_models/ViT-B-32.pt")
    parser.add_argument('--caption_path', default="datasets/flickr_caption.txt")
    parser.add_argument('--image_path', default="datasets/flickr30k-images")
    parser.add_argument('--output_path', default="datasets/clip_caption.pkl")
    args = parser.parse_args()
    main(args)
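
For reference, a minimal sketch of how the resulting pickle could be read back downstream. This inspection snippet is not part of the original repo; the two-element structure follows directly from the pickle.dump call above, and the embedding shape assumes the default ViT-B/32 checkpoint:

    import pickle

    # run the preprocessing first, e.g.:
    #   python process_flickr.py --output_path datasets/clip_caption.pkl
    with open('datasets/clip_caption.pkl', 'rb') as f:
        caption_list, image_id2embed = pickle.load(f)

    # caption_list holds (image_id, caption) tuples; image_id2embed maps
    # image_id -> CLIP image embedding of shape (1, 512) for ViT-B/32
    image_id, caption = caption_list[0]
    print(image_id, caption)
    print(image_id2embed[image_id].shape)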