Make RNN op generic #554

Merged · 1 commit · Feb 2, 2021
7 changes: 0 additions & 7 deletions api/src/main/java/ai/djl/modality/nlp/Decoder.java
@@ -46,13 +46,6 @@ public Decoder(byte version, Block block) {
this.block = addChildBlock("Block", block);
}

/**
* Sets the state of the encoder as the initial state of the decoder.
*
* @param encoderStates the states of the encoder
*/
public abstract void initState(NDList encoderStates);

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java
@@ -89,7 +89,8 @@ public NDList forward(
NDList labels,
PairList<String, Object> params) {
NDList encoderOutputs = encoder.forward(parameterStore, data, true, params);
decoder.initState(encoder.getStates(encoderOutputs));
// add hidden states & cell states to decoder inputs
labels.addAll(encoder.getStates(encoderOutputs));
return decoder.forward(parameterStore, labels, true, params);
}
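With initState removed, the encoder states now travel to the decoder as extra entries in its input NDList. A minimal sketch of how a decoder implementation might consume them (the variable names and the single-state assumption are illustrative, not part of this change):

    // inside a hypothetical Decoder#forwardInternal: inputs = [labels, hiddenState, ...]
    NDArray labels = inputs.head();       // the decoder input sequence
    NDArray hiddenState = inputs.get(1);  // encoder state appended by EncoderDecoder#forward
    // the child recurrent block reads inputs.get(1) as its initial hidden state,
    // so the whole list can be forwarded unchanged
    return block.forward(parameterStore, inputs, training, params);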

130 changes: 77 additions & 53 deletions api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
@@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.recurrent.RNN;
import ai.djl.util.PairList;
import java.util.List;

@@ -313,67 +314,90 @@ NDList batchNorm(
boolean training);

/**
* Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented,
* with both multi-layer and bidirectional support.
*
* @param inputs the inputs to the recurrent operation. Must include input data, parameter
* vector of all trainable parameters concatenated, initial hidden state of the RNN. For
* LSTM, it must include initial cell state. If useSequenceLength is true, it must also
* include vector of valid sequence lengths for each element in the batch
* @param mode the type of RNN to compute
* @param stateSize the sizes of the state for each layer
* @param dropRate the drop rate of the dropout on the outputs of each RNN layer, except the
* last layer
* @param numStackedLayers the number of stacked layers
* @param useSequenceLength if set to true, this layer takes in an extra input parameter
* sequence_length to specify variable length sequence.
* @param useBidirectional whether to use bidirectional recurrent layers
* @param stateOutputs whether to include the state in the output
* @param additional additional parameters
* Applies the RNN operation to the input data.
*
* @param input the input to the recurrent operation
* @param state the initial hidden state for the recurrent operation
* @param params all parameters (weights and biases) for the recurrent operation
* @param hasBiases if false, the recurrent operation does not use the bias weights b_ih and
*     b_hh
* @param numLayers the number of recurrent layers
* @param activation the activation function to use
* @param dropRate if non-zero, introduces a dropout layer on the outputs of each RNN layer
*     except the last layer, with dropout probability equal to dropRate
* @param training true to apply the dropout during training, false otherwise
* @param bidirectional if true, the RNN is bidirectional
* @param batchFirst if true, the input and output NDArrays are provided as (batch, seq,
*     feature)
* @return the output of the operation
*/
NDList rnn(
NDList inputs,
String mode,
long stateSize,
float dropRate,
int numStackedLayers,
boolean useSequenceLength,
boolean useBidirectional,
boolean stateOutputs,
PairList<String, Object> additional);
NDArray input,
NDArray state,
NDList params,
boolean hasBiases,
int numLayers,
RNN.Activation activation,
double dropRate,
boolean training,
boolean bidirectional,
boolean batchFirst);
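A sketch of how the reworked signature might be called through NDArrayEx; the sizes, the flat parameter layout (one weight_ih, weight_hh, bias_ih, bias_hh per layer and direction), and the variable names are illustrative, not prescribed by this PR:

    NDManager manager = NDManager.newBaseManager();
    int seqLength = 10, batchSize = 4, inputSize = 8, stateSize = 16, numLayers = 1;

    NDArray input = manager.randomNormal(new Shape(seqLength, batchSize, inputSize));
    NDArray state = manager.zeros(new Shape(numLayers, batchSize, stateSize));
    NDList params = new NDList(
            manager.randomNormal(new Shape(stateSize, inputSize)),  // weight_ih_l0
            manager.randomNormal(new Shape(stateSize, stateSize)),  // weight_hh_l0
            manager.zeros(new Shape(stateSize)),                    // bias_ih_l0
            manager.zeros(new Shape(stateSize)));                   // bias_hh_l0

    NDList outputs = input.getNDArrayInternal().rnn(
            input, state, params,
            true,                   // hasBiases
            numLayers,
            RNN.Activation.TANH,
            0.0,                    // dropRate
            false,                  // training
            false,                  // bidirectional
            false);                 // batchFirst: input stays (seq, batch, feature)
    NDArray output = outputs.get(0);      // (seqLength, batchSize, stateSize)
    NDArray finalState = outputs.get(1);  // (numLayers, batchSize, stateSize)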

/**
* Applies LSTM recurrent layers to input data.
*
* @param inputs the inputs to the recurrent operation. Must include input data, parameter
* vector of all trainable parameters concatenated, initial hidden state of the RNN and
* initial cell state. If useSequenceLength is true, it must also include vector of valid
* sequence lengths for each element in the batch
* @param stateSize the sizes of the state for each layer
* @param dropRate the drop rate of the dropout on the outputs of each RNN layer, except the
* last layer
* @param numStackedLayers the number of stacked layers
* @param useSequenceLength if set to true, this layer takes in an extra input parameter
* sequence_length to specify variable length sequence.
* @param useBidirectional whether to use bidirectional recurrent layers
* @param stateOutputs whether to include the state in the output
* @param lstmStateClipMin the minimum clip value of LSTM states
* @param lstmStateClipMax the maximum clip value of LSTM states
* @param additional additional parameters
* Applies the GRU operation to the input data.
*
* @param input the input to the GRU operation
* @param state the initial hidden state for the GRU operation
* @param params all parameters (weights and biases) for the GRU operation
* @param hasBiases if false, the GRU operation does not use the bias weights b_ih and
*     b_hh
* @param numLayers the number of recurrent layers
* @param dropRate if non-zero, introduces a dropout layer on the outputs of each GRU layer
*     except the last layer, with dropout probability equal to dropRate
* @param training true to apply the dropout during training, false otherwise
* @param bidirectional if true, the GRU is bidirectional
* @param batchFirst if true, the input and output NDArrays are provided as (batch, seq,
*     feature)
* @return the output of the operation
*/
NDList gru(
NDArray input,
NDArray state,
NDList params,
boolean hasBiases,
int numLayers,
double dropRate,
boolean training,
boolean bidirectional,
boolean batchFirst);

/**
* Applies the LSTM operation to the input data.
*
* @param input the input to the LSTM operation
* @param states the initial hidden state and cell state for the LSTM operation
* @param params all parameters (weights and biases) for the LSTM operation
* @param hasBiases if false, the LSTM operation does not use the bias weights b_ih and
*     b_hh
* @param numLayers the number of recurrent layers
* @param dropRate if non-zero, introduces a dropout layer on the outputs of each LSTM layer
*     except the last layer, with dropout probability equal to dropRate
* @param training true to apply the dropout during training, false otherwise
* @param bidirectional if true, the LSTM is bidirectional
* @param batchFirst if true, the input and output NDArrays are provided as (batch, seq,
*     feature)
* @return the output of the operation
*/
NDList lstm(
NDList inputs,
long stateSize,
float dropRate,
int numStackedLayers,
boolean useSequenceLength,
boolean useBidirectional,
boolean stateOutputs,
double lstmStateClipMin,
double lstmStateClipMax,
PairList<String, Object> additional);
NDArray input,
NDList states,
NDList params,
boolean hasBiases,
int numLayers,
double dropRate,
boolean training,
boolean bidirectional,
boolean batchFirst);
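The lstm variant follows the same calling pattern, except that the second argument is an NDList carrying both the hidden state and the cell state (gru matches rnn minus the activation argument). A sketch reusing the illustrative names from the rnn example above; lstmParams is assumed to be built like params, with arrays sized for the LSTM's four gates:

    NDArray hiddenState = manager.zeros(new Shape(numLayers, batchSize, stateSize));
    NDArray cellState = manager.zeros(new Shape(numLayers, batchSize, stateSize));
    NDList lstmOutputs = input.getNDArrayInternal().lstm(
            input,
            new NDList(hiddenState, cellState),  // hidden state and cell state
            lstmParams,             // weight_ih, weight_hh, bias_ih, bias_hh per layer/direction
            true,                   // hasBiases
            numLayers,
            0.0,                    // dropRate
            false,                  // training
            false,                  // bidirectional
            false);                 // batchFirst
    NDArray lstmOutput = lstmOutputs.get(0);  // (seqLength, batchSize, stateSize)
    NDArray hn = lstmOutputs.get(1);          // final hidden state
    NDArray cn = lstmOutputs.get(2);          // final cell state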

////////////////////////////////////////
// Image and CV
14 changes: 14 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
@@ -303,6 +303,20 @@ public long head() {
return shape[0];
}

/**
* Returns the size of the last dimension of the shape.
*
* @return the size of the last dimension of the shape
* @throws IndexOutOfBoundsException if the shape is empty (scalar shape)
*/
public long tail() {
// scalar case
if (shape.length == 0) {
throw new IndexOutOfBoundsException("can't get value from scalar shape.");
}
return shape[shape.length - 1];
}
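For example, a small sketch of the new accessor next to the existing head():

    Shape shape = new Shape(2, 3, 4);
    shape.head();        // 2, the size of the first dimension
    shape.tail();        // 4, the size of the last dimension
    new Shape().tail();  // throws IndexOutOfBoundsException: a scalar shape has no dimensions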

/**
* Returns the number of trailing ones in the array shape.
*
55 changes: 52 additions & 3 deletions api/src/main/java/ai/djl/nn/recurrent/GRU.java
@@ -12,7 +12,15 @@
*/
package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;

/**
@@ -33,10 +41,52 @@ public class GRU extends RecurrentBlock {

GRU(Builder builder) {
super(builder);
mode = "gru";
gates = 3;
}

/** {@inheritDoc} */
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params) {
NDArrayEx ex = inputs.head().getNDArrayInternal();
Device device = inputs.head().getDevice();
NDList gruParams = new NDList();
for (Parameter parameter : parameters.values()) {
gruParams.add(parameterStore.getValue(parameter, device, training));
}

NDArray input = inputs.head();
if (inputs.size() == 1) {
int batchIndex = batchFirst ? 0 : 1;
inputs.add(
input.getManager()
.zeros(
new Shape(
(long) numLayers * getNumDirections(),
input.size(batchIndex),
stateSize)));
}
NDList outputs =
ex.gru(
input,
inputs.get(1),
gruParams,
hasBiases,
numLayers,
dropRate,
training,
bidirectional,
batchFirst);
if (returnState) {
return outputs;
}
outputs.stream().skip(1).forEach(NDArray::close);
return new NDList(outputs.get(0));
}

/**
* Creates a builder to build a {@link GRU}.
*
Expand All @@ -62,8 +112,7 @@ protected Builder self() {
*/
public GRU build() {
Preconditions.checkArgument(
stateSize > 0 && numStackedLayers > 0,
"Must set stateSize and numStackedLayers");
stateSize > 0 && numLayers > 0, "Must set stateSize and numLayers");
return new GRU(this);
}
}
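A sketch of building and running the reworked block; the setter names mirror the renamed fields (setNumLayers rather than setNumStackedLayers) and should be checked against the final builder API, and manager is assumed to be an NDManager as in the earlier sketches:

    GRU gru = GRU.builder()
            .setStateSize(16)
            .setNumLayers(1)
            .optBatchFirst(true)
            .optReturnState(false)
            .build();
    // when only the data array is passed, forwardInternal above zero-initializes the
    // hidden state with shape (numLayers * numDirections, batchSize, stateSize)
    gru.initialize(manager, DataType.FLOAT32, new Shape(4, 10, 8)); // (batch, seq, feature)
    NDList result = gru.forward(
            new ParameterStore(manager, false),
            new NDList(manager.randomNormal(new Shape(4, 10, 8))),
            false);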