-
Notifications
You must be signed in to change notification settings - Fork 54
/
extract_car_size.py
50 lines (44 loc) · 1.44 KB
/
extract_car_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import os
import pandas as pd
import torch
# Command-line interface: `-map` selects which dataset the script processes.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-map',
    type=str,
    default='i80',
    # A tuple (rather than a set) keeps the order stable in --help output and
    # in argparse's "invalid choice" error message.
    choices=('ai', 'i80', 'us101', 'lanker', 'peach'),
)
opt = parser.parse_args()

# Directory holding one '<time slot>.txt' trajectory table per time slot.
path = './traffic-data/xy-trajectories/{}/'.format(opt.map)
# Directory whose immediate sub-directories are named after the time slots;
# the resulting car-size file is also written here.
trajectories_path = './traffic-data/state-action-cost/data_{}_v0'.format(opt.map)

# Only the immediate sub-directories name time slots. Listing them directly
# (instead of walking the whole tree and splitting on '/') avoids picking up
# nested directories and works with any path separator. A missing directory
# now fails here with FileNotFoundError instead of later at save time.
time_slots = sorted(
    name for name in os.listdir(trajectories_path)
    if os.path.isdir(os.path.join(trajectories_path, name))
)
# Column names assigned, in order, to the header-less trajectory tables.
_TRAJECTORY_COLUMNS = (
    'Vehicle ID',
    'Frame ID',
    'Total Frames',
    'Global Time',
    'Local X',
    'Local Y',
    'Global X',
    'Global Y',
    'Vehicle Length',
    'Vehicle Width',
    'Vehicle Class',
    'Vehicle Velocity',
    'Vehicle Acceleration',
    'Lane Identification',
    'Preceding Vehicle',
    'Following Vehicle',
    'Spacing',
    'Headway',
)

# Load one whitespace-separated table per time slot, keyed by time-slot name.
df = dict()
for ts in time_slots:
    df[ts] = pd.read_table(
        os.path.join(path, ts + '.txt'),
        # Raw string: '\s' in a plain literal is an invalid escape sequence
        # and triggers a warning on recent Python versions.
        sep=r'\s+',
        header=None,
        names=_TRAJECTORY_COLUMNS,
    )
# Build {time slot: {vehicle id: (width, length)}} from the first row recorded
# for each vehicle, then store the mapping next to the processed data.
car_sizes = dict()
for ts in time_slots:
    d = df[ts]
    # Keep only the first row of every vehicle in a single pass, instead of
    # filtering the whole table again for each vehicle id. Rows without a
    # vehicle id can never be matched to a vehicle, so they are skipped.
    first_rows = d.dropna(subset=['Vehicle ID']).drop_duplicates(
        subset='Vehicle ID', keep='first'
    )
    vehicle_ids = first_rows['Vehicle ID'].tolist()
    # 2-D array with one (width, length) row per vehicle, aligned with
    # vehicle_ids.
    sizes = first_rows[['Vehicle Width', 'Vehicle Length']].values
    car_sizes[ts] = dict()
    for c, size in zip(vehicle_ids, sizes):
        car_sizes[ts][c] = tuple(size)
        print(c)  # progress output: one line per processed vehicle

# Same location as before, derived from trajectories_path rather than from a
# second copy of the directory template.
torch.save(car_sizes, os.path.join(trajectories_path, 'car_sizes.pth'))