-
Notifications
You must be signed in to change notification settings - Fork 0
/
GRU.py
42 lines (35 loc) · 1.25 KB
/
GRU.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
class GRUModel(nn.Module):
    """Stacked GRU that maps a feature sequence to a per-sample output vector.

    Pipeline: BatchNorm1d over the input features -> multi-layer GRU
    (``batch_first=True``) -> hidden state at the last time step -> Linear
    projection -> BatchNorm1d over the outputs (only when ``output_size != 1``).

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        output_size: Number of outputs per sample. When 1, the trailing output
            dimension is squeezed away and no output normalization is applied.
        num_layers: Number of stacked GRU layers.
        dropout: Passed straight through to ``nn.GRU``.
    """

    def __init__(self, input_size=6, hidden_size=64, output_size=10, num_layers=2, dropout=0.0):
        super().__init__()
        self.input_size = input_size
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(hidden_size, output_size)
        # Applied to a (N, input_size, T) view, so each input feature is a channel.
        self.norm1 = nn.BatchNorm1d(input_size)
        self.output_size = output_size
        if output_size != 1:
            self.norm2 = nn.BatchNorm1d(output_size)
        self.name = 'GRU'

    def forward(self, x):
        """Run the model on a batch of sequences.

        Args:
            x: Tensor of shape (N, T, input_size) -- batch, time, features.

        Returns:
            Tensor of shape (N, output_size), or (N,) when ``output_size == 1``.
        """
        # BatchNorm1d normalizes dim 1, so move features there: (N, F, T).
        out = self.norm1(x.permute(0, 2, 1))
        # Back to (N, T, F) for the batch_first GRU.
        out = out.permute(0, 2, 1)
        out, _ = self.rnn(out)
        # Keep only the last time step: (N, hidden_size).
        out = out[:, -1, :]
        # Squeeze only the output dim. A bare .squeeze() would also drop the
        # batch dim when N == 1, feeding 1-D input to norm2 (which raises).
        out = self.fc_out(out).squeeze(-1)
        if self.output_size != 1:
            out = self.norm2(out)
        return out