ConvGRUCell.py
###################################################
# Nicolo Savioli, 2017 -- Conv-GRU pytorch v 1.0  #
###################################################
import torch
from torch import nn


class ConvGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, kernel_size):
        super(ConvGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.dropout = nn.Dropout(p=0.5)
        # A single convolution produces both the reset and update gates.
        self.ConvGates = nn.Conv2d(self.input_size + self.hidden_size,
                                   2 * self.hidden_size,
                                   self.kernel_size,
                                   padding=self.kernel_size // 2)
        # A separate convolution produces the candidate hidden state.
        self.Conv_ct = nn.Conv2d(self.input_size + self.hidden_size,
                                 self.hidden_size,
                                 self.kernel_size,
                                 padding=self.kernel_size // 2)

    def forward(self, input, hidden):
        # On the first time step, initialise the hidden state with zeros
        # on the same device and with the same dtype as the input.
        if hidden is None:
            size_h = [input.size(0), self.hidden_size] + list(input.size()[2:])
            hidden = torch.zeros(size_h, dtype=input.dtype, device=input.device)
        # Reset and update gates, computed jointly and split along channels.
        c1 = self.ConvGates(torch.cat((input, hidden), 1))
        rt, ut = c1.chunk(2, 1)
        reset_gate = self.dropout(torch.sigmoid(rt))
        update_gate = self.dropout(torch.sigmoid(ut))
        # Candidate hidden state from the input and the reset-gated hidden state.
        gated_hidden = torch.mul(reset_gate, hidden)
        p1 = self.Conv_ct(torch.cat((input, gated_hidden), 1))
        ct = torch.tanh(p1)
        # Blend the previous hidden state and the candidate with the update gate.
        next_h = torch.mul(update_gate, hidden) + (1 - update_gate) * ct
        return next_h
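

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): unrolls the cell over
# a synthetic sequence of feature maps. The tensor shapes, the hidden size of
# 32, and starting with hidden=None are assumptions chosen for illustration.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    batch, channels, height, width, steps = 2, 8, 16, 16, 5
    cell = ConvGRUCell(input_size=channels, hidden_size=32, kernel_size=3)
    sequence = torch.randn(steps, batch, channels, height, width)
    hidden = None  # the cell creates a zero state on the first call
    for t in range(steps):
        hidden = cell(sequence[t], hidden)
    print(hidden.shape)  # expected: torch.Size([2, 32, 16, 16])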