-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_offline.py
85 lines (69 loc) · 2.99 KB
/
run_offline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
from PIL import Image
from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data.camera_data import CameraData
from utils.visualisation.plot import plot_results, save_results
# Configure the root logger at INFO level so the logging.info progress
# messages emitted by this script are actually shown.
logging.basicConfig(level=logging.INFO)
def parse_args(argv=None):
    """Parse the command-line options of the offline grasp-inference script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the process
            command line, exactly as before.

    Returns:
        argparse.Namespace with the attributes ``network``, ``rgb_path``,
        ``depth_path``, ``use_depth``, ``use_rgb``, ``n_grasps``, ``save``
        and ``force_cpu``.

    Raises:
        SystemExit: Raised by argparse when ``--network`` is missing or an
            option value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(description='Evaluate network')

    # A model path is mandatory: without it the script would only fail later,
    # after the images have been read, when ``None`` is handed to torch.load.
    parser.add_argument('--network', type=str, required=True,
                        help='Path to saved network to evaluate')
    parser.add_argument('--rgb_path', type=str, default='cornell/08/pcd0845r.png',
                        help='RGB Image path')
    parser.add_argument('--depth_path', type=str, default='cornell/08/pcd0845d.tiff',
                        help='Depth Image path')

    # Integer 1/0 switches (kept as ints for command-line compatibility).
    parser.add_argument('--use-depth', type=int, default=1,
                        help='Use Depth image for evaluation (1/0)')
    parser.add_argument('--use-rgb', type=int, default=1,
                        help='Use RGB image for evaluation (1/0)')
    parser.add_argument('--n-grasps', type=int, default=1,
                        help='Number of grasps to consider per image')
    parser.add_argument('--save', type=int, default=0,
                        help='Save the results')
    parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                        help='Force code to run in CPU mode')

    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: read one RGB/depth image pair, run the saved network
    # on it, then either save the predicted grasp maps or plot them.
    args = parse_args()

    # Load image
    logging.info('Loading image...')
    # The with-blocks close the image files as soon as the pixel data has been
    # copied into NumPy arrays, instead of leaving the handles open.
    with Image.open(args.rgb_path) as pic:
        rgb = np.array(pic)
    with Image.open(args.depth_path) as pic:
        # Append a trailing axis to the depth array.
        depth = np.expand_dims(np.array(pic), axis=2)

    # Get the compute device first so the checkpoint can be mapped onto it.
    device = get_device(args.force_cpu)

    # Load Network
    logging.info('Loading model...')
    # map_location puts the loaded tensors on the selected device, so a
    # checkpoint saved on a GPU still loads with --cpu (or on a CPU-only host)
    # and the network ends up on the same device as the input moved below.
    # NOTE(review): assumes get_device returns a value accepted by both
    # Tensor.to and torch.load(map_location=...), e.g. a torch.device - confirm.
    # NOTE(review): torch.load unpickles the file; only load trusted networks.
    net = torch.load(args.network, map_location=device)
    logging.info('Done')

    # Build the network input from the raw images, honouring the
    # --use-depth / --use-rgb switches.
    img_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb)
    x, depth_img, rgb_img = img_data.get_data(rgb=rgb, depth=depth)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        xc = x.to(device)
        pred = net.predict(xc)

        q_img, ang_img, width_img = post_process_output(
            pred['pos'], pred['cos'], pred['sin'], pred['width'])

        if args.save:
            save_results(
                rgb_img=img_data.get_rgb(rgb, False),
                depth_img=np.squeeze(img_data.get_depth(depth)),
                grasp_q_img=q_img,
                grasp_angle_img=ang_img,
                no_grasps=args.n_grasps,
                grasp_width_img=width_img
            )
        else:
            fig = plt.figure(figsize=(10, 10))
            plot_results(fig=fig,
                         rgb_img=img_data.get_rgb(rgb, False),
                         grasp_q_img=q_img,
                         grasp_angle_img=ang_img,
                         no_grasps=args.n_grasps,
                         grasp_width_img=width_img)
            # The plotted figure is also written to disk in this branch.
            fig.savefig('img_result.pdf')