-
Notifications
You must be signed in to change notification settings - Fork 0
/
ballJson2Arr.py
73 lines (58 loc) · 3.79 KB
/
ballJson2Arr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from glob import glob
import json
import numpy as np
import torch
import traceback
def load_game(file):
    """Read and parse one game JSON file.

    Args:
        file: path to the JSON file.

    Returns:
        The decoded JSON value (elsewhere in this module it is indexed as a
        dict with an ``'events'`` key).
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding (which is not UTF-8 on every system).
    with open(file, 'r', encoding='utf-8') as f:
        return json.load(f)
def assertSequences(input_sequences, output_sequences, input_sequence_length, output_sequence_length):
    """Check that paired per-event input/output window arrays have the expected shapes.

    Pairs are formed with ``zip``, so items beyond the shorter sequence are not
    checked. For every ``(iseq, oseq)`` pair:
      * ``iseq`` is 4-D ``(windows, input_sequence_length, 11, 3)``;
      * ``oseq`` is 3-D ``(windows, output_sequence_length, 3)``;
      * both have the same, non-zero number of windows.

    Raises:
        AssertionError: on the first violated check. The checks are explicit
            ``raise`` statements rather than ``assert`` so they still run when
            Python is started with ``-O`` (which strips ``assert``).
    """
    for pair_index, (iseq, oseq) in enumerate(zip(input_sequences, output_sequences)):
        # Check rank first so the shape indexing below cannot hit IndexError.
        if iseq.ndim != 4 or oseq.ndim != 3:
            raise AssertionError(
                f'pair {pair_index}: expected 4-D input and 3-D output, '
                f'got shapes {iseq.shape} and {oseq.shape}')
        if iseq.shape[0] != oseq.shape[0] or iseq.shape[0] == 0:
            raise AssertionError(
                f'pair {pair_index}: window counts differ or are zero '
                f'({iseq.shape[0]} vs {oseq.shape[0]})')
        if iseq.shape[1] != input_sequence_length:
            raise AssertionError(
                f'pair {pair_index}: input window length {iseq.shape[1]} '
                f'!= {input_sequence_length}')
        if oseq.shape[1] != output_sequence_length:
            raise AssertionError(
                f'pair {pair_index}: output window length {oseq.shape[1]} '
                f'!= {output_sequence_length}')
        if iseq.shape[2] != 11:
            raise AssertionError(
                f'pair {pair_index}: expected 11 entities per moment, got {iseq.shape[2]}')
        if iseq.shape[3] != 3 or oseq.shape[2] != 3:
            raise AssertionError(
                f'pair {pair_index}: expected 3 coordinates, got '
                f'{iseq.shape[3]} (input) and {oseq.shape[2]} (output)')
def _event_windows(positions, input_sequence_length, output_sequence_length):
    """Cut one event's position array into aligned input/output windows.

    Windows start every ``input_sequence_length`` rows. For a start ``s`` the
    input window is ``positions[s:s + input_sequence_length]`` and the output
    window is the ``output_sequence_length`` rows immediately after it. Only
    starts whose output window fits entirely inside ``positions`` are used, so
    both returned lists always have the same length.

    Returns:
        (inputs, outputs): two lists of slices of ``positions``.
    """
    total_sequence_length = input_sequence_length + output_sequence_length
    # One shared range of start indices keeps inputs and outputs aligned by construction.
    starts = range(0, positions.shape[0] - total_sequence_length + 1, input_sequence_length)
    inputs = [positions[s:s + input_sequence_length] for s in starts]
    outputs = [positions[s + input_sequence_length:s + total_sequence_length] for s in starts]
    return inputs, outputs


def createSequences(game, input_sequence_length, output_sequence_length):
    """Build input/target windows from every event of a loaded game.

    For each event with at least one moment, keeps only moments whose last
    element lists exactly 11 entities and takes the last three coordinates of
    each entity. The result is cut into consecutive (input, output) windows.
    Events whose (first, last) ``moment[2]`` pair was already seen are skipped
    as repeats.

    Args:
        game: parsed game dict with an ``'events'`` list; each event has a
            ``'moments'`` list (structure assumed from the indexing below).
        input_sequence_length: moments per input window; also the stride
            between consecutive windows.
        output_sequence_length: moments per output window.

    Returns:
        (input_sequences, output_sequences, player_sequences):
            ``(batch, input_sequence_length, 11, 3)`` inputs,
            ``(batch, output_sequence_length, 3)`` entity-0 (ball) targets, and
            ``(batch, output_sequence_length, 11, 3)`` all-entity targets.

    Raises:
        ValueError: if no event yields a complete input+output window.
        AssertionError: if the per-event windows have unexpected shapes
            (raised by ``assertSequences``).
    """
    seen_spans = set()
    input_sequences = []
    output_sequences = []
    player_sequences = []
    for event in game['events']:
        moments = event['moments']
        if not moments:
            continue
        # NOTE(review): moment[2] is presumably the game clock; if it restarts each
        # period, two distinct events could share a (start, end) pair and the later
        # one would be dropped as a repeat. Confirm against the data format.
        span = (moments[0][2], moments[-1][2])
        if span in seen_spans:
            # Don't include repeat data
            continue
        seen_spans.add(span)
        # Keep only moments listing all 11 entities; [..., -3:] keeps the last
        # three coordinates (xyz) of each entity.
        positions = np.array([moment[-1] for moment in moments if len(moment[-1]) == 11])[..., -3:]
        inputs, players = _event_windows(positions, input_sequence_length, output_sequence_length)
        if not inputs:
            continue
        player_windows = np.array(players)                  # (windows, horizon, ballplayer, xyz)
        input_sequences.append(np.array(inputs))            # (windows, horizon, ballplayer, xyz)
        output_sequences.append(player_windows[:, :, 0])    # (windows, horizon, xyz) of ball (entity 0)
        player_sequences.append(player_windows)
    assertSequences(input_sequences, output_sequences, input_sequence_length, output_sequence_length)
    if not input_sequences:
        # np.concatenate([]) would fail with an opaque message; say what actually happened.
        raise ValueError(
            'no event in the game is long enough to form a window of '
            f'{input_sequence_length} input + {output_sequence_length} output moments')
    return (
        np.concatenate(input_sequences, axis=0),   # (batch, horizon, ballplayer, xyz)
        np.concatenate(output_sequences, axis=0),  # (batch, horizon, xyz) of ball
        np.concatenate(player_sequences, axis=0),  # (batch, horizon, ballplayer, xyz)
    )
def file2Arr(file, input_sequence_length, output_sequence_length):
    """Convert one game JSON file into three saved torch tensors.

    Writes ``<stem>_<in>_<out>_in.pt``, ``<stem>_<in>_<out>_out.pt`` and
    ``<stem>_<in>_<out>_outPlayer.pt``, where ``<stem>`` is ``file`` without
    its ``.json`` suffix and ``<in>``/``<out>`` are the window lengths.

    If loading or windowing the file fails, prints ``Issue with <file>`` and the
    traceback, then returns ``None`` without writing anything, so a batch run
    over many files can continue.
    """
    try:
        game = load_game(file)
        input_seq, output_seq, player_seq = createSequences(game, input_sequence_length, output_sequence_length)
    except Exception:
        print(f'Issue with {file}')
        traceback.print_exc()
        return
    # Remove only a literal '.json' suffix. str.rstrip('.json') strips any trailing
    # run of the characters '.', 'j', 's', 'o', 'n' (e.g. 'season.json' -> 'sea').
    stem = file[:-len('.json')] if file.endswith('.json') else file
    prefix = f'{stem}_{input_sequence_length}_{output_sequence_length}'
    with open(prefix + '_in.pt', 'wb') as fin, \
            open(prefix + '_out.pt', 'wb') as fout, \
            open(prefix + '_outPlayer.pt', 'wb') as fplayer:
        torch.save(torch.from_numpy(input_seq), fin)
        torch.save(torch.from_numpy(output_seq), fout)
        torch.save(torch.from_numpy(player_seq), fplayer)
if __name__ == '__main__':
    # Each sample is 25 input moments followed by 15 output moments.
    input_sequence_length = 25
    output_sequence_length = 15
    for json_path in glob('./test/*.json'):
        file2Arr(json_path, input_sequence_length, output_sequence_length)