-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
53 lines (41 loc) · 1.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Author: Daiwei (David) Lu
# Make predictions on test images
import torch
from skimage import io, transform
from torch.utils.data import Dataset
from torchvision import transforms
from model import Net
from utils import TestFile, Rescale, ToTensor, Normalize, show_dot
import argparse
# Default location of the trained weights loaded by ``test_model``.
MODEL_PATH = './model.pth'


def test_model(path, model_path=MODEL_PATH, device=None):
    """Run the trained ``Net`` on the image at ``path`` and print its output.

    The image is loaded through ``TestFile`` and preprocessed with
    ``Rescale(256)``, ``ToTensor()`` and ``Normalize()``. For each batch the
    loader yields, the first two values of the first output row are printed
    as ``'{:.4f} {:.4f}'``.

    Args:
        path: Path of the image file to predict on; passed to ``TestFile``.
        model_path: Path of the saved ``state_dict``. Defaults to
            ``MODEL_PATH``.
        device: ``torch.device`` or device string (e.g. ``"cuda"``,
            ``"cpu"``). ``None`` selects CUDA when it is available and
            falls back to the CPU otherwise.
    """
    if device is None:
        # Fall back to CPU so the script also works without a CUDA device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    model = Net().to(device)
    # map_location remaps the checkpoint's tensors onto `device`, so weights
    # saved from a GPU can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    dataset = TestFile(path,
                       transform=transforms.Compose([
                           Rescale(256),
                           ToTensor(),
                           Normalize(),
                       ]))
    # Default DataLoader settings: batch_size=1, no shuffling.
    loader = torch.utils.data.DataLoader(dataset)

    with torch.no_grad():
        for sample in loader:
            image = sample['image'].float().to(device)
            # NOTE(review): assumes the model output is (batch, >=2) with the
            # predicted coordinates in the first two columns — confirm.
            coordinates = model(image).cpu().numpy()
            print('{:.4f} {:.4f}'.format(coordinates[0][0], coordinates[0][1]))
def main():
    """Command-line entry point: read an image path and print its prediction.

    Expects a single positional argument, the path of the image file to
    run the model on, and forwards it to ``test_model``.
    """
    cli = argparse.ArgumentParser(description='Test Prediction')
    cli.add_argument(
        'path',
        metavar='P',
        type=str,
        help='path of file for prediction',
    )
    options = cli.parse_args()
    test_model(options.path)
# Only run the CLI when executed as a script, not when imported.
if __name__ == '__main__':
    main()