-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo.py
executable file
·60 lines (48 loc) · 2.29 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from models.vgg16_drnet import vgg16dres
from models.vgg16_drnet1 import vgg16dres1
from models.vgg19 import vgg19
from PIL import Image
from torchvision import transforms
import gradio as gr
import cv2
import numpy as np
# Checkpoint locations. These are relative paths, so they resolve against the
# working directory the demo is started from.

# Checkpoints for the vgg16dres / vgg16dres1 models:
sha_path = "pretrained_models/dr_sha.pth"
shb_path = "pretrained_models/dilres_shb.pth"
ucf_path = "pretrained_models/dr_ucf.pth"

# Checkpoints for the vgg19 models:
dm_sha_path = "pretrained_models/model_sh_A.pth"
dm_shb_path = "pretrained_models/model_sh_B.pth"

# NOTE(review): weights are expected to already be on disk. An automatic
# download via gdown was sketched but is disabled (gdown is not imported and
# `model_path` is not defined in this file):
# url = "https://drive.google.com/uc?id=1nnIHPaV9RGqK8JHL645zmRvkNrahD9ru"
# gdown.download(url, model_path, quiet=False)
def load_return_model(path, model, map_location=None):
    """Load a checkpoint's state dict into ``model`` and put it in eval mode.

    Args:
        path: Path of a file saved with ``torch.save`` that holds a state dict
            compatible with ``model``.
        model: The ``torch.nn.Module`` to populate; it is modified in place.
        map_location: Where ``torch.load`` should place the loaded tensors.
            When omitted, the module-level ``device`` is used, which matches
            the original behaviour for existing callers.

    Returns:
        The same ``model`` object, now in evaluation mode.
    """
    if map_location is None:
        # Looked up at call time: ``device`` is a module-level global that is
        # assigned after this function is defined.
        map_location = device
    # torch.load unpickles the file -- only point this at trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Prefer the GPU, but fall back to the CPU so the demo still starts on
# machines without CUDA. (torch device strings are e.g. 'cuda' or 'cpu';
# 'gpu' is not a valid one.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One ready-to-use model per checkpoint. The keys are the names offered in
# the UI dropdown, and predict() looks models up by them. Each network is
# moved to ``device`` before its weights are loaded.
models = {
    # vgg16dres / vgg16dres1 are project models; the map_location keyword is
    # forwarded to their constructors exactly as before.
    'sha': load_return_model(sha_path, vgg16dres(map_location=device).to(device)),
    'ucf': load_return_model(ucf_path, vgg16dres1(map_location=device).to(device)),
    'shb': load_return_model(shb_path, vgg16dres1(map_location=device).to(device)),
    # vgg19 models loaded from the dm_* checkpoints.
    'dm_shb': load_return_model(dm_shb_path, vgg19().to(device)),
    'dm_sha': load_return_model(dm_sha_path, vgg19().to(device)),
}
def _density_to_heatmap(density):
    """Render a 2-D density map as an RGB ``uint8`` JET-colormap image."""
    # Min-max normalise to [0, 1]; the epsilon keeps a constant map (e.g. all
    # zeros) from dividing by zero, which then renders as a uniform image.
    scaled = (density - density.min()) / (density.max() - density.min() + 1e-5)
    heatmap = cv2.applyColorMap((scaled * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV returns BGR; convert so the UI receives RGB.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)


def predict(inp, model):
    """Estimate the crowd count in an image and visualise its density map.

    Args:
        inp: Image array from the Gradio image input. It is cast to uint8 and
            converted to 3-channel RGB before inference.
        model: Key into ``models`` selecting which checkpoint to run.

    Returns:
        A ``(heatmap, count)`` tuple:
        - ``heatmap`` is the predicted density map rendered with the JET
          colormap as an RGB ``uint8`` array.
        - ``count`` is the summed density, rounded to the nearest integer.
    """
    # Letting Pillow infer the mode and then converting means grayscale or
    # RGBA arrays are handled correctly. Forcing mode 'RGB' on such arrays
    # misinterprets their bytes.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = transforms.ToTensor()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs, _ = models[model](batch)
    count = torch.sum(outputs).item()
    # Assumes outputs is (N, C, H, W) -- TODO confirm; we visualise the first
    # channel of the single image in the batch.
    density = outputs[0, 0].cpu().numpy()
    # round() rather than int(): int() truncates, biasing every count down.
    return _density_to_heatmap(density), round(count)
# Build and launch the Gradio front end. Inputs are a crowd image and a
# checkpoint name; outputs are the rendered density map and the predicted
# count.
title = "Distribution Matching for Crowd Counting via Dilated Residual Network"

inputs = [
    gr.inputs.Image(label="Image of Crowd"),
    # predict() uses the selected choice as a key into ``models``.
    gr.inputs.Dropdown(
        choices=['sha', 'shb', 'ucf', 'dm_sha', 'dm_shb'],
        label='Trained Dataset',
    ),
]

outputs = [
    gr.outputs.Image(label="Predicted Density Map"),
    gr.outputs.Label(label="Predicted Count"),
]

gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=[],
    allow_flagging=False,
    live=False,
    allow_screenshot=False,
).launch(share=True)  # share=True asks Gradio for a public share link