Skip to content
This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Improve RNN cell abstraction. #86

Merged
merged 1 commit into from
Apr 13, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions Sources/DeepLearning/Layer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1318,26 +1318,44 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
}

/// An input to a recurrent neural network.
public struct RNNInput<TimeStepInput: Differentiable, State: Differentiable>: Differentiable {
public struct RNNCellInput<Input: Differentiable, State: Differentiable>: Differentiable {
/// The input at the current time step.
public var timeStepInput: TimeStepInput
public var input: Input
/// The previous state.
public var previousState: State
public var state: State

@differentiable
public init(timeStepInput: TimeStepInput, previousState: State) {
self.timeStepInput = timeStepInput
self.previousState = previousState
public init(input: Input, state: State) {
self.input = input
self.state = state
}
}

/// An output to a recurrent neural network.
public struct RNNCellOutput<Output: Differentiable, State: Differentiable>: Differentiable {
/// The output at the current time step.
public var output: Output
/// The current state.
public var state: State

@differentiable
public init(output: Output, state: State) {
self.output = output
self.state = state
}
}

/// A recurrent neural network cell.
public protocol RNNCell: Layer where Input == RNNInput<TimeStepInput, State> {
public protocol RNNCell: Layer where Input == RNNCellInput<TimeStepInput, State>,
Output == RNNCellOutput<TimeStepOutput, State> {
/// The input at a time step.
associatedtype TimeStepInput: Differentiable
/// The output at a time step.
associatedtype TimeStepOutput: Differentiable
/// The state that may be preserved across time steps.
typealias State = Output
associatedtype State: Differentiable
/// The zero state.
@differentiable
var zeroState: State { get }
}

Expand All @@ -1352,8 +1370,11 @@ public extension RNNCell {
/// phase.
/// - Returns: The output.
@differentiable
func applied(to timeStepInput: TimeStepInput, previous: State, in context: Context) -> State {
return applied(to: Input(timeStepInput: timeStepInput, previousState: previous),
in: context)
func applied(
to input: TimeStepInput,
state: State,
in context: Context
) -> RNNCellOutput<TimeStepOutput, State> {
return applied(to: RNNCellInput(input: input, state: state), in: context)
}
}