[WIP] RNN - Implement fast Gated Recurrent Unit (GRU) #231

Merged (5 commits) on May 15, 2018
7 changes: 4 additions & 3 deletions src/nn_primitives/nn_primitives.nim
@@ -19,7 +19,8 @@ import ./nnp_activation,
./nnp_softmax_cross_entropy,
./nnp_maxpooling,
./nnp_softmax,
- ./nnp_numerical_gradient
+ ./nnp_numerical_gradient,
+ ./recurrent/nnp_gru

export nnp_activation,
nnp_convolution,
@@ -28,12 +29,12 @@ export nnp_activation,
nnp_softmax_cross_entropy,
nnp_maxpooling,
nnp_softmax,
- nnp_numerical_gradient
+ nnp_numerical_gradient,
+ nnp_gru

import private/p_nnp_types
export Size2D


when defined(cudnn) or defined(nimdoc) or defined(nimsuggest):
import ./backend/cudnn,
./nnp_conv2d_cudnn
2 changes: 1 addition & 1 deletion src/nn_primitives/nnp_activation.nim
@@ -72,7 +72,7 @@ proc relu_backward*[T](gradient: Tensor[T], cached_tensor: Tensor[T]): Tensor[T]

proc tanh_backward*[T](gradient: Tensor[T], cached_tensor: Tensor[T]): Tensor[T]{.noInit.}=
result = map2_inline(cached_tensor, gradient):
- y - y * (x * x)
+ y * (1 - x * x)

# ####################################################################################################
# Documentation
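For reference, the tanh gradient used by tanh_backward above can be sanity-checked numerically. A minimal NumPy sketch (illustrative only, not part of the library; cached is the saved tanh output, grad the incoming gradient):

import numpy as np

x = np.array([0.5, -1.2, 2.0])
cached = np.tanh(x)                       # what the forward pass caches
grad = np.array([1.0, 1.0, 1.0])          # incoming gradient
analytic = grad * (1 - cached * cached)   # same formula as tanh_backward
eps = 1e-6
numeric = (np.tanh(x + eps) - np.tanh(x - eps)) / (2 * eps)
print(np.allclose(analytic, grad * numeric))  # True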
1 change: 0 additions & 1 deletion src/nn_primitives/nnp_linear.nim
@@ -39,7 +39,6 @@ proc linear*[T](input, weight: Tensor[T], output: var Tensor[T]) {.inline.} =
proc linear_backward*[T](
input,
weight,
- bias,
gradOutput: Tensor[T],
gradInput,
gradWeight,
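A note on the linear_backward signature above: the bias value itself is never needed to compute the gradients of a linear layer, since for y = x·Wᵀ + b the bias gradient is simply the output gradient summed over the batch. A minimal NumPy sketch (illustrative names only):

import numpy as np

gradOutput = np.array([[0.1, 0.2],
                       [0.3, 0.4],
                       [0.5, 0.6]])               # [batch, out_features]
gradBias = gradOutput.sum(axis=0, keepdims=True)  # [1, out_features], here ~ [[0.9, 1.2]]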
174 changes: 174 additions & 0 deletions src/nn_primitives/recurrent/nnp_gru.nim
@@ -0,0 +1,174 @@
# Copyright (c) 2018 the Arraymancer contributors
# Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0).
# This file may not be copied, modified, or distributed except according to those terms.

import
../../tensor/tensor,
../private/p_activation, ../nnp_linear,
../nnp_activation

# For compatibility with CuDNN and to allow loading CPU/CUDA weights interchangeably,
# we use the following equations,
#
# - h is hidden state at t-1, h' at t
# - input == x, hidden == h
# - n = h~ (the candidate hidden state)
# - r is the reset gate
# - z is the update gate
# - h', the final output, is a linear interpolation
#
# r = σ(Wr * x + bWr + Ur * h + bUr)
# z = σ(Wz * x + bWz + Uz * h + bUz)
# n = tanh(W * x + bW + r .* (U * h + bU ))
# h' = (1 - z) .* n + z .* h
#
# These differ from the original paper for n and h':
# - the pointwise multiplication by r happens after the matrix multiplication
# - the linear interpolation has its terms switched

# TODO: after the 2 "linear" calls in forward prop and before the linear
# in backprop, everything is elementwise;
# we could use a giant loop-fusion to avoid intermediate tensors
#
# Note that the CPU prefetcher might not work as well, because
# between the use of U3h.data[i] and U3h.data[i+1]
# there will be a lot of intermediate computation.
#
# Also see here for a counterargument: https://software.intel.com/en-us/forums/intel-moderncode-for-parallel-architectures/topic/635075
# Intel CPUs' prefetchers can maintain 32 streams

proc gru_cell_inference*[T: SomeReal](
input, hidden,
W3, U3,
bW3, bU3: Tensor[T],
next_hidden: var Tensor[T]) =
## Input:
## - input tensor of shape [batch_size, features]
## - hidden state of shape [batch_size, hidden_size]
## - weight of input W3 [3 * hidden_size, features]
## - weight of hidden U3 [3 * hidden_size, hidden_size]
## - biases of input and hidden state [1, 3 * hidden_size]
##
## Output:
## - y == h'(t): The next hidden state of the GRU Cell.
## (GRU output and next hidden state are the same)
##
## This is an optimized version for use when backpropagation is not needed.

let
H = hidden.shape[1]
# Slices
sr = (0 ..< H)|1
sz = (H ..< 2*H)|1
srz = (0 ..< 2*H)|1
s = (2*H ..< 3*H)|1
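# Note: the 3*H columns of W3x/U3h below are laid out gate-by-gate as [r | z | n]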


# Step 1 - U*h and W*x - Resulting shape [batch_size, 3*H]
var W3x, U3h: Tensor[T] # TODO, pass those as parameter to allow buffer reuse

linear(input, W3, bW3, W3x)
linear(hidden, U3, bU3, U3h)

# Step 2 - Computing reset (r) and update (z) gate
var W2ru = W3x[_, srz] # shape [batch_size, 2*H] - we reuse the previous buffer
apply2_inline(W2ru, U3h[_, srz]):
sigmoid(x + y)

# Step 3 - Computing candidate hidden state ñ
var n = W3x[_, s] # shape [batch_size, H] - we reuse the previous buffer
apply3_inline(n, W2ru[_, sr], U3h[_, s]):
tanh(x + y * z)

# Step 4 - Compute the next hidden state
next_hidden = map3_inline(W3x[_, sz], n, hidden):
(1 - x) * y + x * z

proc gru_cell_forward*[T: SomeReal](
input, hidden,
W3, U3,
bW3, bU3: Tensor[T],
r, z, n, Uh,
next_hidden: var Tensor[T]
) =
## Input:
## - input tensor of shape [batch_size, features]
## - hidden state of shape [batch_size, hidden_size]
## - weight of input W3 [3 * hidden_size, features]
## - weight of hidden U3 [3 * hidden_size, hidden_size]
## - biases of input and hidden state [1, 3 * hidden_size]
##
## Output:
## - r, z, n, Uh: intermediate tensors saved for backpropagation.
## of size [batch_size, hidden_size]
## - y == h'(t): The next hidden state of the GRU Cell.
## (GRU output and next hidden state are the same)
##

let
H = hidden.shape[1]
# Slices
sr = (0 ..< H)|1
sz = (H ..< 2*H)|1
s = (2*H ..< 3*H)|1

# Step 1 - U*h and W*x - Resulting shape [batch_size, 3*H]
var W3x, U3h: Tensor[T] # TODO, pass those as parameter to allow buffer reuse

linear(input, W3, bW3, W3x)
linear(hidden, U3, bU3, U3h)

# Saving for backprop
Uh = U3h[_, s].clone()

# Step 2 - Computing reset (r) and update (z) gate
apply3_inline(r, W3x[_, sr], U3h[_, sr]):
sigmoid(y + z)

apply3_inline(z, W3x[_, sz], U3h[_, sz]):
sigmoid(y + z)

# Step 3 - Computing candidate hidden state ñ
n = map3_inline(W3x[_, s], r, U3h[_, s]):
tanh(x + y * z)

# Step 4 - Compute the next hidden state
next_hidden = map3_inline(z, n, hidden):
(1 - x) * y + x * z

proc gru_cell_backward*[T: SomeReal](
dx, dh, dW3, dU3, # input and weights gradients
dbW3, dbU3: var Tensor[T], # bias gradient
dnext: Tensor[T], # gradient flowing back from the next hidden state
x, h, W3, U3: Tensor[T], # input parameters saved from forward
r, z, n, Uh: Tensor[T] # Intermediate tensors saved from forward
) =
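## Computes the gradients dx, dh, dW3, dU3, dbW3 and dbU3 of a GRU cell,
## given the gradient flowing back from the next hidden state (dnext)
## and the tensors r, z, n, Uh cached by `gru_cell_forward`.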

# Backprop of step 4 - z part
let dz = (h - n) .* dnext
let dn = (1.0 .- z) .* dnext

# Backprop of step 3.
let dWx = tanh_backward(dn, n)
let dr = Uh .* dWx
let dUh = r .* dWx

# Backprop of step 2 - update gate z
let dWzx = sigmoid_backward(dz, z)
let dUzh = dWzx

# Backprop of step 2 - reset gate r
let dWrx = sigmoid_backward(dr, r)
let dUrh = dWrx

# Concat
let dW3x = concat(dWrx, dWzx, dWx, axis = 1)
let dU3h = concat(dUrh, dUzh, dUh, axis = 1)

# Backprop of step 1
linear_backward(x, W3, dW3x, dx, dW3, dbW3)
linear_backward(h, U3, dU3h, dh, dU3, dbU3)

# Backprop of step 4 - h part
apply3_inline(dh, dnext, z):
x + y * z
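To make the slicing and in-place buffer reuse above easier to follow, here is a minimal NumPy sketch of the same forward equations (an illustration under the shape conventions of the doc comments, not part of the library: W3 is [3*H, features], U3 is [3*H, H], biases are [1, 3*H], gates laid out as [r | z | n]):

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_cell(x, h, W3, U3, bW3, bU3):
    H = h.shape[1]
    W3x = x @ W3.T + bW3          # [batch, 3*H]
    U3h = h @ U3.T + bU3          # [batch, 3*H]
    r = sigmoid(W3x[:, :H]     + U3h[:, :H])
    z = sigmoid(W3x[:, H:2*H]  + U3h[:, H:2*H])
    n = np.tanh(W3x[:, 2*H:]   + r * U3h[:, 2*H:])
    return (1 - z) * n + z * h    # h', the next hidden state

# Shape check with random data
rng = np.random.default_rng(0)
batch, features, H = 3, 4, 2
out = gru_cell(rng.normal(size=(batch, features)), rng.normal(size=(batch, H)),
               rng.normal(size=(3*H, features)), rng.normal(size=(3*H, H)),
               rng.normal(size=(1, 3*H)), rng.normal(size=(1, 3*H)))
print(out.shape)  # (3, 2)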
53 changes: 53 additions & 0 deletions tests/nn_primitives/references/test_nnp_gru.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

batch_size = 3
features = 4
hidden_size = 2 # weights and bias have 3x2 = 6 size

x = torch.tensor([[ 0.1, 0.2, 0.3, 0.4],
[-0.1, -0.2, -0.3, -0.4],
[ 0.5, 0.6, 0.7, 0.8]])

hidden = torch.tensor([
[ -1.0, -1.0],
[ -1.0, -1.0],
[ -1.0, -1.0]])

w_input = torch.tensor([
[0.9, 0.8, 0.7, 0.6],
[0.8, 0.7, 0.6, 0.5],
[0.7, 0.6, 0.5, 0.4],
[0.6, 0.5, 0.4, 0.3],
[0.5, 0.4, 0.3, 0.2],
[0.4, 0.3, 0.2, 0.1]])

w_recur = torch.tensor([
[-0.3, -0.1],
[-0.2, 0.0],
[-0.3, -0.1],
[-0.2, 0.0],
[-0.3, -0.1],
[-0.2, 0.0],
])

b_input = torch.tensor([
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
])

b_recur = torch.tensor([
[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6],
])

test_gru = nn.GRUCell(4, 2)

test_gru.weight_ih.data = w_input
test_gru.weight_hh.data = w_recur
test_gru.bias_ih.data = b_input
test_gru.bias_hh.data = b_recur

print(test_gru(x, hidden))

# tensor([[-0.5317, -0.4753],
# [-0.3930, -0.3210],
# [-0.7325, -0.6430]])
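Since PyTorch stacks the GRUCell gates in (reset, update, new) order, matching the slice layout on the Nim side, the printed output can also be reproduced by hand from the same weights. A small cross-check that could be appended to the script above (a sketch, not part of the original test):

# Manual re-computation with the tensors defined above
W3x = x @ w_input.t() + b_input        # [batch, 3*hidden_size]
U3h = hidden @ w_recur.t() + b_recur
Wr, Wz, Wn = W3x.chunk(3, dim=1)       # gate order: reset, update, new
Ur, Uz, Un = U3h.chunk(3, dim=1)

r = torch.sigmoid(Wr + Ur)
z = torch.sigmoid(Wz + Uz)
n = torch.tanh(Wn + r * Un)
h_next = (1 - z) * n + z * hidden
print(h_next)                          # should match the GRUCell output above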