python gru cell
albertbou92 committed Nov 29, 2023
1 parent 8f60daf commit d144f14
Showing 2 changed files with 46 additions and 182 deletions.
161 changes: 0 additions & 161 deletions test.py

This file was deleted.

67 changes: 46 additions & 21 deletions torchrl/modules/tensordict_module/rnn.py
@@ -2,22 +2,20 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn.modules.rnn import RNNCellBase
from tensordict import TensorDictBase, unravel_key_list

from tensordict.nn import TensorDictModuleBase as ModuleBase

from tensordict.tensordict import NO_DEFAULT
from tensordict.utils import prod

from torch import nn
from torch import nn, Tensor
from torch.nn.modules.rnn import RNNCellBase

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.objectives.value.functional import (
@@ -28,8 +26,7 @@


class PythonLSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is
fully coded in Python.
r"""A long short-term memory (LSTM) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.
.. math::
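The body of this math block is collapsed in the diff view; it presumably matches the cell update documented for ``torch.nn.LSTMCell``:

.. math::

    \begin{array}{ll}
    i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
    f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
    g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
    o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
    c' = f \odot c + i \odot g \\
    h' = o \odot \tanh(c') \\
    \end{array}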
@@ -88,23 +85,38 @@ class PythonLSTMCell(RNNCellBase):
>>> output = torch.stack(output, dim=0)
"""

def __init__(self, input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
if input.dim() not in (1, 2):
raise ValueError(f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead")
raise ValueError(
f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
)
if hx is not None:
for idx, value in enumerate(hx):
if value.dim() not in (1, 2):
raise ValueError(f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead")
raise ValueError(
f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
)
is_batched = input.dim() == 2
if not is_batched:
input = input.unsqueeze(0)

if hx is None:
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
zeros = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
hx = (zeros, zeros)
else:
hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
@@ -118,7 +130,9 @@ def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) ->
def lstm_cell(self, x, hx, cx):
x = x.view(-1, x.size(1))

gates = F.linear(x, self.weight_ih, self.bias_ih) + F.linear(hx, self.weight_hh, self.bias_hh)
gates = F.linear(x, self.weight_ih, self.bias_ih) + F.linear(
hx, self.weight_hh, self.bias_hh
)

i_gate, f_gate, g_gate, o_gate = gates.chunk(4, 1)

@@ -525,8 +539,7 @@ def _lstm(


class PythonGRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.LSTMCell but is
fully coded in Python.
r"""A gated recurrent unit (GRU) cell that performs the same operation as nn.LSTMCell but is fully coded in Python.
.. math::
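As above, the block's body is collapsed; it presumably matches the update documented for ``torch.nn.GRUCell``:

.. math::

    \begin{array}{ll}
    r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
    z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
    n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
    h' = (1 - z) \odot n + z \odot h \\
    \end{array}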
@@ -587,22 +600,34 @@ class PythonGRUCell(RNNCellBase):
... output.append(hx)
"""

def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
if input.dim() not in (1, 2):
raise ValueError(f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead")
raise ValueError(
f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
)
if hx is not None and hx.dim() not in (1, 2):
raise ValueError(f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead")
raise ValueError(
f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
)
is_batched = input.dim() == 2
if not is_batched:
input = input.unsqueeze(0)

if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
hx = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
else:
hx = hx.unsqueeze(0) if not is_batched else hx

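For reference, a minimal usage sketch of the new ``PythonGRUCell``, following the doctest pattern in its docstring (sizes are arbitrary):

    >>> gru_cell = PythonGRUCell(10, 20)
    >>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
    >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
    >>> output = []
    >>> for i in range(input.size(0)):
    ...     hx = gru_cell(input[i], hx)
    ...     output.append(hx)
    >>> output = torch.stack(output, dim=0)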
