Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Commit

Permalink
Improve RNN cell abstraction. (#86)
Browse files Browse the repository at this point in the history
The existing `RNNCell` protocol and `RNNInput` type are not flexible in that each time step has to take both the previous output and the hidden state. This PR lifts that restriction.

* Rename `RNNInput` to `RNNCellInput` so that it's more accurate.
* Add a new `RNNCellOutput` generic structure type that stores an output and a state.
* Add associated type `State` in `RNNCell`.
* Make the `Output` type of `RNNCell` be `RNNCellOutput<TimeStepOutput, State>`.

Thanks @superbobry for the suggestions.
  • Loading branch information
rxwei authored Apr 13, 2019
1 parent 1944e47 commit 861d1f5
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions Sources/DeepLearning/Layer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1318,26 +1318,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
}

/// An input to a recurrent neural network.
public struct RNNInput<TimeStepInput: Differentiable, State: Differentiable>: Differentiable {
public struct RNNCellInput<Input: Differentiable, State: Differentiable>: Differentiable {
/// The input at the current time step.
public var timeStepInput: TimeStepInput
public var input: Input
/// The previous state.
public var previousState: State
public var state: State

@differentiable
public init(timeStepInput: TimeStepInput, previousState: State) {
self.timeStepInput = timeStepInput
self.previousState = previousState
public init(input: Input, state: State) {
self.input = input
self.state = state
}
}

/// An output of a recurrent neural network cell, pairing the output at a time step with a state.
public struct RNNCellOutput<Output: Differentiable, State: Differentiable>: Differentiable {
/// The output at the current time step.
public var output: Output
/// The current state.
public var state: State

/// Creates a cell output from the given time-step output and state.
///
/// - Parameters:
///   - output: The output at the current time step.
///   - state: The current state.
@differentiable
public init(output: Output, state: State) {
self.output = output
self.state = state
}
}

/// A recurrent neural network cell.
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>,
Output == RNNCellOutput<TimeStepOutput, State> {
/// The input at a time step.
associatedtype TimeStepInput: Differentiable
/// The output at a time step.
associatedtype TimeStepOutput: Differentiable
/// The state that may be preserved across time steps.
typealias State = Output
associatedtype State: Differentiable
/// The zero state.
@differentiable
var zeroState: State { get }
}

Expand All @@ -1352,8 +1370,11 @@ public extension RNNCell {
/// phase.
/// - Returns: The output.
@differentiable
func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State {
return applied(to: Input(timeStepInput: timeStepInput, previousState: previous),
in: context)
func applied(
to input: TimeStepInput,
state: State,
in context: Context
) -> RNNCellOutput<TimeStepOutput, State> {
return applied(to: RNNCellInput(input: input, state: state), in: context)
}
}

0 comments on commit 861d1f5

Please sign in to comment.