-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathrun_realtime.py
79 lines (65 loc) · 2.82 KB
/
run_realtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
from hardware.camera import RealSenseCamera
from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data.camera_data import CameraData
from utils.visualisation.plot import save_results, plot_results
logging.basicConfig(level=logging.INFO)


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options with attributes ``network`` (str),
        ``use_depth`` (int), ``use_rgb`` (int), ``n_grasps`` (int) and
        ``force_cpu`` (bool).
    """
    cli = argparse.ArgumentParser(description='Evaluate network')

    # (positional flag names, keyword options) for each argument, in the
    # order they should appear in --help.
    option_specs = (
        (('--network',),
         dict(type=str, default='saved_data/cornell_rgbd_iou_0.96',
              help='Path to saved network to evaluate')),
        # Modalities are integer switches (1 = use, 0 = skip), both on by default.
        (('--use-depth',),
         dict(type=int, default=1,
              help='Use Depth image for evaluation (1/0)')),
        (('--use-rgb',),
         dict(type=int, default=1,
              help='Use RGB image for evaluation (1/0)')),
        (('--n-grasps',),
         dict(type=int, default=1,
              help='Number of grasps to consider per image')),
        # Exposed on the namespace as `force_cpu`, not `cpu`.
        (('--cpu',),
         dict(dest='force_cpu', action='store_true', default=False,
              help='Force code to run in CPU mode')),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Connect to Camera
    # NOTE(review): the RealSense serial number below is hard-coded; it must
    # match the attached device.
    logging.info('Connecting to camera...')
    cam = RealSenseCamera(device_id=830112070066)
    cam.connect()
    cam_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb)

    # Get the compute device before loading the model so the checkpoint can be
    # mapped onto it. Without map_location, torch.load restores tensors onto
    # the device they were saved from, so a GPU-saved model would ignore
    # `--cpu` (or fail on a CPU-only host) while the input is moved to
    # `device` below.
    device = get_device(args.force_cpu)

    # Load Network
    logging.info('Loading model...')
    # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which cannot
    # restore a fully pickled model object; pass weights_only=False there.
    net = torch.load(args.network, map_location=device)
    logging.info('Done')

    # Raw inputs and predictions of the most recent fully processed frame.
    # It stays None until a frame completes, so the `finally` clause neither
    # raises NameError (failure or Ctrl-C before the first prediction) nor
    # mixes a newer frame's images with an older frame's predictions.
    last_frame = None
    try:
        fig = plt.figure(figsize=(10, 10))
        # Loops until interrupted or an exception is raised.
        while True:
            image_bundle = cam.get_image_bundle()
            rgb = image_bundle['rgb']
            depth = image_bundle['aligned_depth']
            x, _, _ = cam_data.get_data(rgb=rgb, depth=depth)
            with torch.no_grad():
                xc = x.to(device)
                pred = net.predict(xc)

                q_img, ang_img, width_img = post_process_output(pred['pos'], pred['cos'], pred['sin'], pred['width'])
                # Snapshot inputs and outputs together so they always match.
                last_frame = (rgb, depth, q_img, ang_img, width_img)

                plot_results(fig=fig,
                             rgb_img=cam_data.get_rgb(rgb, False),
                             depth_img=np.squeeze(cam_data.get_depth(depth)),
                             grasp_q_img=q_img,
                             grasp_angle_img=ang_img,
                             no_grasps=args.n_grasps,
                             grasp_width_img=width_img)
    finally:
        if last_frame is None:
            logging.warning('No frame was processed; skipping save_results.')
        else:
            saved_rgb, saved_depth, saved_q, saved_ang, saved_width = last_frame
            save_results(
                rgb_img=cam_data.get_rgb(saved_rgb, False),
                depth_img=np.squeeze(cam_data.get_depth(saved_depth)),
                grasp_q_img=saved_q,
                grasp_angle_img=saved_ang,
                no_grasps=args.n_grasps,
                grasp_width_img=saved_width
            )