-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathget_feature.py
50 lines (41 loc) · 1.35 KB
/
get_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pandas as pd
import torch, torchvision
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import ResNet as ResNet
from torch.utils.data import Dataset
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
trnsfrms_val = transforms.Compose(
[
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize(mean = mean, std = std)
]
)
class roi_dataset(Dataset):
def __init__(self, img_csv,
):
super().__init__()
self.transform = trnsfrms_val
self.images_lst = img_csv
def __len__(self):
return len(self.images_lst)
def __getitem__(self, idx):
path = self.images_lst.filename[idx]
image = Image.open(path).convert('RGB')
image = self.transform(image)
return image
img_csv=pd.read_csv(r'./test_list.csv')
test_datat=roi_dataset(img_csv)
database_loader = torch.utils.data.DataLoader(test_datat, batch_size=1, shuffle=False)
model = ResNet.resnet50(num_classes=128,mlp=False, two_branch=False, normlinear=True)
pretext_model = torch.load(r'./best_ckpt.pth')
model.fc = nn.Identity()
model.load_state_dict(pretext_model, strict=True)
model.eval()
with torch.no_grad():
for batch in database_loader:
features = model(batch)
features = features.cpu().numpy()