diff --git a/api/src/main/java/ai/djl/modality/nlp/Decoder.java b/api/src/main/java/ai/djl/modality/nlp/Decoder.java index df887a5d316..72a6ead095f 100644 --- a/api/src/main/java/ai/djl/modality/nlp/Decoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/Decoder.java @@ -46,13 +46,6 @@ public Decoder(byte version, Block block) { this.block = addChildBlock("Block", block); } - /** - * Sets the state of the encoder as the initial state of the decoder. - * - * @param encoderStates the states of the encoder - */ - public abstract void initState(NDList encoderStates); - /** {@inheritDoc} */ @Override protected NDList forwardInternal( diff --git a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java index f5b47a23b41..8866b9de4cd 100644 --- a/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java +++ b/api/src/main/java/ai/djl/modality/nlp/EncoderDecoder.java @@ -89,7 +89,8 @@ public NDList forward( NDList labels, PairList params) { NDList encoderOutputs = encoder.forward(parameterStore, data, true, params); - decoder.initState(encoder.getStates(encoderOutputs)); + // add hidden states & cell states to decoder inputs + labels.addAll(encoder.getStates(encoderOutputs)); return decoder.forward(parameterStore, labels, true, params); } diff --git a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java index ba8cd6eaa80..4d7823adc47 100644 --- a/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java +++ b/api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Activation; +import ai.djl.nn.recurrent.RNN; import ai.djl.util.PairList; import java.util.List; @@ -313,67 +314,90 @@ NDList batchNorm( boolean training); /** - * Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, - * with both multi-layer and bidirectional support. - * - * @param inputs the inputs to the recurrent operation. Must include input data, parameter - * vector of all trainable parameters concatenated, initial hidden state of the RNN. For - * LSTM, it must include initial cell state. If useSequenceLength is true, it must also - * include vector of valid sequence lengths for each element in the batch - * @param mode the type of RNN to compute - * @param stateSize the sizes of the state for each layer - * @param dropRate the drop rate of the dropout on the outputs of each RNN layer, except the - * last layer - * @param numStackedLayers the number of stacked layers - * @param useSequenceLength if set to true, this layer takes in an extra input parameter - * sequence_length to specify variable length sequence. - * @param useBidirectional whether to use bidirectional recurrent layers - * @param stateOutputs whether to include the state in the output - * @param additional additional parameters + * Applies RNN operation to input data. + * + * @param input the inputs to the recurrent operation. + * @param state the hidden state to the recurrent operation. + * @param params all params (weights and biases) for the recurrent operation + * @param hasBiases If false, then the recurrent operation does not use bias weights b_ih and + * b_hh + * @param numLayers the number of recurrent layers. 
+     * @param activation the activation function to use
+     * @param dropRate If non-zero, introduces a Dropout layer on the outputs of each RNN layer
+     *     except the last layer, with dropout probability equal to dropRate
+     * @param training true to apply dropout (training mode)
+     * @param bidirectional If true, becomes a bidirectional RNN
+     * @param batchFirst If true, then the input and output NDArray are provided as (batch, seq,
+     *     feature)
      * @return the output of the operation
      */
     NDList rnn(
-            NDList inputs,
-            String mode,
-            long stateSize,
-            float dropRate,
-            int numStackedLayers,
-            boolean useSequenceLength,
-            boolean useBidirectional,
-            boolean stateOutputs,
-            PairList additional);
+            NDArray input,
+            NDArray state,
+            NDList params,
+            boolean hasBiases,
+            int numLayers,
+            RNN.Activation activation,
+            double dropRate,
+            boolean training,
+            boolean bidirectional,
+            boolean batchFirst);

     /**
-     * Applies LSTM recurrent layers to input data.
-     *
-     * @param inputs the inputs to the recurrent operation. Must include input data, parameter
-     *     vector of all trainable parameters concatenated, initial hidden state of the RNN and
-     *     initial cell state. If useSequenceLength is true, it must also include vector of valid
-     *     sequence lengths for each element in the batch
-     * @param stateSize the sizes of the state for each layer
-     * @param dropRate the drop rate of the dropout on the outputs of each RNN layer, except the
-     *     last layer
-     * @param numStackedLayers the number of stacked layers
-     * @param useSequenceLength if set to true, this layer takes in an extra input parameter
-     *     sequence_length to specify variable length sequence.
-     * @param useBidirectional whether to use bidirectional recurrent layers
-     * @param stateOutputs whether to include the state in the output
-     * @param lstmStateClipMin the minimum clip value of LSTM states
-     * @param lstmStateClipMax the maximum clip value of LSTM states
-     * @param additional additional parameters
+     * Applies GRU operation to input data.
+     *
+     * @param input the input to the GRU operation.
+     * @param state the hidden state to the GRU operation.
+     * @param params all params (weights and biases) for the GRU operation
+     * @param hasBiases If false, then the recurrent operation does not use bias weights b_ih and
+     *     b_hh
+     * @param numLayers the number of recurrent layers.
+     * @param dropRate If non-zero, introduces a Dropout layer on the outputs of each GRU layer
+     *     except the last layer, with dropout probability equal to dropRate
+     * @param training true to apply dropout (training mode)
+     * @param bidirectional If true, becomes a bidirectional GRU
+     * @param batchFirst If true, then the input and output NDArray are provided as (batch, seq,
+     *     feature)
+     * @return the output of the operation
+     */
+    NDList gru(
+            NDArray input,
+            NDArray state,
+            NDList params,
+            boolean hasBiases,
+            int numLayers,
+            double dropRate,
+            boolean training,
+            boolean bidirectional,
+            boolean batchFirst);
+
+    /**
+     * Applies LSTM operation to input data.
+     *
+     * @param input the input to the LSTM operation.
+     * @param states the hidden state and cell state to the LSTM operation.
+     * @param params all params (weights and biases) for the LSTM operation
+     * @param hasBiases If false, then the recurrent operation does not use bias weights b_ih and
+     *     b_hh
+     * @param numLayers the number of recurrent layers.
+ * @param dropRate If non-zero, introduces a Dropout layer on the outputs of each LSTM layer + * except the last layer, with dropout probability equal to dropout + * @param training apply dropout if is true + * @param bidirectional If true, becomes a bidirectional LSTM + * @param batchFirst If true, then the input and output NDArray are provided as (batch, seq, + * feature) * @return the output of the operation */ NDList lstm( - NDList inputs, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - double lstmStateClipMin, - double lstmStateClipMax, - PairList additional); + NDArray input, + NDList states, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst); //////////////////////////////////////// // Image and CV diff --git a/api/src/main/java/ai/djl/ndarray/types/Shape.java b/api/src/main/java/ai/djl/ndarray/types/Shape.java index d75a534dedd..502517f1384 100644 --- a/api/src/main/java/ai/djl/ndarray/types/Shape.java +++ b/api/src/main/java/ai/djl/ndarray/types/Shape.java @@ -303,6 +303,20 @@ public long head() { return shape[0]; } + /** + * Returns the tail index of the shape. + * + * @return the tail index of the shape + * @throws IndexOutOfBoundsException Thrown if the shape is empty + */ + public long tail() { + // scalar case + if (shape.length == 0) { + throw new IndexOutOfBoundsException("can't get value from scalar shape."); + } + return shape[shape.length - 1]; + } + /** * Returns the number of trailing ones in the array shape. * diff --git a/api/src/main/java/ai/djl/nn/recurrent/GRU.java b/api/src/main/java/ai/djl/nn/recurrent/GRU.java index 6403a31e505..28291362539 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/GRU.java +++ b/api/src/main/java/ai/djl/nn/recurrent/GRU.java @@ -12,7 +12,15 @@ */ package ai.djl.nn.recurrent; +import ai.djl.Device; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.internal.NDArrayEx; +import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; import ai.djl.util.Preconditions; /** @@ -33,10 +41,52 @@ public class GRU extends RecurrentBlock { GRU(Builder builder) { super(builder); - mode = "gru"; gates = 3; } + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArrayEx ex = inputs.head().getNDArrayInternal(); + Device device = inputs.head().getDevice(); + NDList gruParams = new NDList(); + for (Parameter parameter : parameters.values()) { + gruParams.add(parameterStore.getValue(parameter, device, training)); + } + + NDArray input = inputs.head(); + if (inputs.size() == 1) { + int batchIndex = batchFirst ? 0 : 1; + inputs.add( + input.getManager() + .zeros( + new Shape( + (long) numLayers * getNumDirections(), + input.size(batchIndex), + stateSize))); + } + NDList outputs = + ex.gru( + input, + inputs.get(1), + gruParams, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); + if (returnState) { + return outputs; + } + outputs.stream().skip(1).forEach(NDArray::close); + return new NDList(outputs.get(0)); + } + /** * Creates a builder to build a {@link GRU}. 
* @@ -62,8 +112,7 @@ protected Builder self() { */ public GRU build() { Preconditions.checkArgument( - stateSize > 0 && numStackedLayers > 0, - "Must set stateSize and numStackedLayers"); + stateSize > 0 && numLayers > 0, "Must set stateSize and numStackedLayers"); return new GRU(this); } } diff --git a/api/src/main/java/ai/djl/nn/recurrent/LSTM.java b/api/src/main/java/ai/djl/nn/recurrent/LSTM.java index 900f22dbfce..71fa793cacc 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/LSTM.java +++ b/api/src/main/java/ai/djl/nn/recurrent/LSTM.java @@ -14,11 +14,8 @@ import ai.djl.Device; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.ndarray.internal.NDArrayEx; -import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.nn.Parameter; @@ -43,11 +40,6 @@ */ public class LSTM extends RecurrentBlock { - private boolean clipLstmState; - private double lstmStateClipMin; - private double lstmStateClipMax; - private NDArray beginStateCell; - /** * Creates an LSTM block. * @@ -55,11 +47,7 @@ public class LSTM extends RecurrentBlock { */ LSTM(Builder builder) { super(builder); - mode = "lstm"; gates = 4; - clipLstmState = builder.clipLstmState; - lstmStateClipMin = builder.lstmStateClipMin; - lstmStateClipMax = builder.lstmStateClipMax; } /** {@inheritDoc} */ @@ -69,93 +57,42 @@ protected NDList forwardInternal( NDList inputs, boolean training, PairList params) { - inputs = opInputs(parameterStore, inputs, training); NDArrayEx ex = inputs.head().getNDArrayInternal(); - - NDList output; - if (clipLstmState) { - output = - ex.lstm( - inputs, - stateSize, - dropRate, - numStackedLayers, - useSequenceLength, - isBidirectional(), - true, - lstmStateClipMin, - lstmStateClipMax, - params); - } else { - output = - ex.rnn( - inputs, - mode, - stateSize, - dropRate, - numStackedLayers, - useSequenceLength, - isBidirectional(), - true, - params); - } - - NDList result = new NDList(output.head().transpose(1, 0, 2)); - if (stateOutputs) { - result.add(output.get(1)); - result.add(output.get(2)); + Device device = inputs.head().getDevice(); + NDList rnnParams = new NDList(); + for (Parameter parameter : parameters.values()) { + rnnParams.add(parameterStore.getValue(parameter, device, training)); } - resetBeginStates(); - return result; - } - - /** {@inheritDoc} */ - @Override - public void setBeginStates(NDList beginStates) { - this.beginState = beginStates.get(0); - this.beginStateCell = beginStates.get(1); - } - - /** {@inheritDoc} */ - @Override - protected void resetBeginStates() { - beginState = null; - beginStateCell = null; - } - - /** {@inheritDoc} */ - @Override - protected NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean training) { - validateInputSize(inputs); - long batchSize = inputs.head().getShape().get(0); - inputs = updateInputLayoutToTNC(inputs); - NDArray head = inputs.head(); - NDManager manager = head.getManager(); - Device device = head.getDevice(); - NDList result = new NDList(head); - try (NDList parameterList = new NDList()) { - for (Parameter parameter : parameters.values()) { - NDArray array = parameterStore.getValue(parameter, device, training).flatten(); - array.attach(manager); - parameterList.add(array); - } - NDArray array = NDArrays.concat(parameterList); - result.add(array); - } - Shape stateShape = new Shape((long) numStackedLayers * numDirections, batchSize, stateSize); - if (beginState != null) { - 
result.add(beginState); - result.add(beginStateCell); - } else { - // TODO manager creates the NDArray with the wrong device - result.add(manager.zeros(stateShape, DataType.FLOAT32, device)); - result.add(manager.zeros(stateShape, DataType.FLOAT32, device)); + NDArray input = inputs.head(); + if (inputs.size() == 1) { + int batchIndex = batchFirst ? 0 : 1; + Shape stateShape = + new Shape( + (long) numLayers * getNumDirections(), + input.size(batchIndex), + stateSize); + // hidden state + inputs.add(input.getManager().zeros(stateShape)); + // cell + inputs.add(input.getManager().zeros(stateShape)); } - if (useSequenceLength) { - result.add(inputs.get(1)); + NDList outputs = + ex.lstm( + input, + new NDList(inputs.get(1), inputs.get(2)), + rnnParams, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); + if (returnState) { + return outputs; } - return result; + outputs.stream().skip(1).forEach(NDArray::close); + return new NDList(outputs.get(0)); } /** @@ -176,20 +113,6 @@ protected Builder self() { return this; } - /** - * Sets the minimum and maximum clip value of LSTM states. - * - * @param lstmStateClipMin the minimum clip value of LSTM states - * @param lstmStateClipMax the maximum clip value of LSTM states - * @return this Builder - */ - public Builder optLstmStateClipMin(float lstmStateClipMin, float lstmStateClipMax) { - this.lstmStateClipMin = lstmStateClipMin; - this.lstmStateClipMax = lstmStateClipMax; - this.clipLstmState = true; - return self(); - } - /** * Builds a {@link LSTM} block. * @@ -197,8 +120,7 @@ public Builder optLstmStateClipMin(float lstmStateClipMin, float lstmStateClipMa */ public LSTM build() { Preconditions.checkArgument( - stateSize > 0 && numStackedLayers > 0, - "Must set stateSize and numStackedLayers"); + stateSize > 0 && numLayers > 0, "Must set stateSize and numStackedLayers"); return new LSTM(this); } } diff --git a/api/src/main/java/ai/djl/nn/recurrent/RNN.java b/api/src/main/java/ai/djl/nn/recurrent/RNN.java index c0fd684fb20..8f88aa7911b 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RNN.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RNN.java @@ -12,7 +12,15 @@ */ package ai.djl.nn.recurrent; +import ai.djl.Device; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.internal.NDArrayEx; +import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; +import ai.djl.nn.Parameter; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; import ai.djl.util.Preconditions; /** @@ -32,6 +40,7 @@ */ public class RNN extends RecurrentBlock { + private Activation activation; /** * Creates a vanilla RNN block. * @@ -39,10 +48,54 @@ public class RNN extends RecurrentBlock { */ RNN(Builder builder) { super(builder); - mode = builder.activation == Activation.RELU ? "rnn_relu" : "rnn_tanh"; + activation = builder.activation; gates = 1; } + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDArrayEx ex = inputs.head().getNDArrayInternal(); + Device device = inputs.head().getDevice(); + NDList rnnParams = new NDList(); + for (Parameter parameter : parameters.values()) { + rnnParams.add(parameterStore.getValue(parameter, device, training)); + } + + NDArray input = inputs.head(); + if (inputs.size() == 1) { + int batchIndex = batchFirst ? 
0 : 1; + inputs.add( + input.getManager() + .zeros( + new Shape( + (long) numLayers * getNumDirections(), + input.size(batchIndex), + stateSize))); + } + NDList outputs = + ex.rnn( + input, + inputs.get(1), + rnnParams, + hasBiases, + numLayers, + activation, + dropRate, + training, + bidirectional, + batchFirst); + if (returnState) { + return outputs; + } + outputs.stream().skip(1).forEach(NDArray::close); + return new NDList(outputs.get(0)); + } + /** * Creates a builder to build a {@link RNN}. * @@ -79,8 +132,7 @@ public Builder setActivation(RNN.Activation activation) { */ public RNN build() { Preconditions.checkArgument( - stateSize > 0 && numStackedLayers > 0, - "Must set stateSize and numStackedLayers"); + stateSize > 0 && numLayers > 0, "Must set stateSize and numLayers"); return new RNN(this); } } diff --git a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java index e102faedb06..ab99442f180 100644 --- a/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java +++ b/api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java @@ -12,21 +12,14 @@ */ package ai.djl.nn.recurrent; -import ai.djl.Device; import ai.djl.MalformedModelException; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDArrays; -import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.LayoutType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractBlock; import ai.djl.nn.Block; import ai.djl.nn.Parameter; import ai.djl.nn.ParameterType; -import ai.djl.training.ParameterStore; -import ai.djl.util.PairList; import java.io.DataInputStream; import java.io.IOException; @@ -52,13 +45,12 @@ public abstract class RecurrentBlock extends AbstractBlock { protected long stateSize; protected float dropRate; - protected int numStackedLayers; - protected String mode; - protected boolean useSequenceLength; - protected int numDirections = 1; + protected int numLayers; protected int gates; - protected boolean stateOutputs; - protected NDArray beginState; + protected boolean batchFirst; + protected boolean hasBiases; + protected boolean bidirectional; + protected boolean returnState; /** * Creates a {@code RecurrentBlock} object. 
@@ -69,22 +61,21 @@ public RecurrentBlock(BaseBuilder builder) { super(VERSION); stateSize = builder.stateSize; dropRate = builder.dropRate; - numStackedLayers = builder.numStackedLayers; - useSequenceLength = builder.useSequenceLength; - stateOutputs = builder.stateOutputs; - if (builder.useBidirectional) { - numDirections = 2; - } + numLayers = builder.numLayers; + batchFirst = builder.batchFirst; + hasBiases = builder.hasBiases; + bidirectional = builder.bidirectional; + returnState = builder.returnState; ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS}; String[] directions = {"l"}; - if (builder.useBidirectional) { + if (builder.bidirectional) { directions = new String[] {"l", "r"}; } String[] gateStrings = {"i2h", "h2h"}; for (ParameterType parameterType : parameterTypes) { - for (int i = 0; i < numStackedLayers; i++) { + for (int i = 0; i < numLayers; i++) { for (String direction : directions) { for (String gateString : gateStrings) { String name = @@ -96,86 +87,24 @@ public RecurrentBlock(BaseBuilder builder) { } } - protected void validateInputSize(NDList inputs) { - int numberofInputsRequired = 1; - if (useSequenceLength) { - numberofInputsRequired = 2; - } - if (inputs.size() != numberofInputsRequired) { - throw new IllegalArgumentException( - "Invalid number of inputs for RNN. Size of input NDList must be " - + numberofInputsRequired - + " when useSequenceLength is " - + useSequenceLength); - } - } - - /** - * Sets the parameter that indicates whether the output must include the hidden states. - * - * @param stateOutputs whether the output must include the hidden states. - */ - public final void setStateOutputs(boolean stateOutputs) { - this.stateOutputs = stateOutputs; - } - - /** {@inheritDoc} */ - @Override - protected NDList forwardInternal( - ParameterStore parameterStore, - NDList inputs, - boolean training, - PairList params) { - inputs = opInputs(parameterStore, inputs, training); - NDArrayEx ex = inputs.head().getNDArrayInternal(); - NDList output = - ex.rnn( - inputs, - mode, - stateSize, - dropRate, - numStackedLayers, - useSequenceLength, - isBidirectional(), - true, - params); - - NDList result = new NDList(output.head().transpose(1, 0, 2)); - if (stateOutputs) { - result.add(output.get(1)); - } - resetBeginStates(); - return result; - } - - /** - * Sets the initial {@link NDArray} value for the hidden states. - * - * @param beginStates the {@link NDArray} value for the hidden states - */ - public void setBeginStates(NDList beginStates) { - this.beginState = beginStates.get(0); - } - - protected void resetBeginStates() { - beginState = null; - } - /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) { - // Input shape at this point is NTC. Output Shape should be NTS Shape inputShape = inputs[0]; - long nShape = inputShape.get(0); - long tShape = inputShape.get(1); - Shape nonStateOutputShape = new Shape(nShape, tShape, stateSize * numDirections); - if (stateOutputs) { + Shape outputShape = + new Shape(inputShape.get(0), inputShape.get(1), stateSize * getNumDirections()); + if (!returnState) { return new Shape[] { - nonStateOutputShape, - new Shape((long) numStackedLayers * numDirections, nShape, stateSize) + outputShape, }; } - return new Shape[] {nonStateOutputShape}; + return new Shape[] { + outputShape, + new Shape( + (long) numLayers * getNumDirections(), + inputShape.get((batchFirst) ? 
0 : 1), + stateSize) + }; } /** {@inheritDoc} */ @@ -193,7 +122,7 @@ public Shape getParameterShape(String name, Shape[] inputShapes) { Shape shape = inputShapes[0]; long inputs = shape.get(2); if (layer > 0) { - inputs = stateSize * numDirections; + inputs = stateSize * getNumDirections(); } if (name.contains("BIAS")) { return new Shape(gates * stateSize); @@ -218,42 +147,8 @@ public void loadMetadata(byte version, DataInputStream is) } } - protected boolean isBidirectional() { - return numDirections == 2; - } - - protected NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean training) { - validateInputSize(inputs); - long batchSize = inputs.head().getShape().get(0); - inputs = updateInputLayoutToTNC(inputs); - NDArray head = inputs.head(); - NDManager manager = head.getManager(); - Device device = head.getDevice(); - - NDList result = new NDList(head); - try (NDList parameterList = new NDList()) { - for (Parameter parameter : parameters.values()) { - NDArray array = parameterStore.getValue(parameter, device, training).flatten(); - array.attach(manager); - parameterList.add(array); - } - NDArray array = NDArrays.concat(parameterList); - result.add(array); - } - Shape stateShape = new Shape((long) numStackedLayers * numDirections, batchSize, stateSize); - if (beginState != null) { - result.add(beginState); - } else { - result.add(manager.zeros(stateShape)); - } - if (useSequenceLength) { - result.add(inputs.get(1)); - } - return result; - } - - protected NDList updateInputLayoutToTNC(NDList inputs) { - return new NDList(inputs.singletonOrThrow().transpose(1, 0, 2)); + protected int getNumDirections() { + return bidirectional ? 2 : 1; } /** The Builder to construct a {@link RecurrentBlock} type of {@link ai.djl.nn.Block}. */ @@ -262,13 +157,12 @@ public abstract static class BaseBuilder { protected float dropRate; protected long stateSize; - protected int numStackedLayers; - protected double lstmStateClipMin; - protected double lstmStateClipMax; - protected boolean clipLstmState; - protected boolean useSequenceLength; - protected boolean useBidirectional; - protected boolean stateOutputs; + protected int numLayers; + // set it true by default for usability + protected boolean batchFirst = true; + protected boolean hasBiases = true; + protected boolean bidirectional; + protected boolean returnState; protected RNN.Activation activation; /** @@ -297,45 +191,57 @@ public T setStateSize(int stateSize) { /** * Sets the Required number of stacked layers. * - * @param numStackedLayers the number of stacked layers + * @param numLayers the number of stacked layers * @return this Builder */ - public T setNumStackedLayers(int numStackedLayers) { - this.numStackedLayers = numStackedLayers; + public T setNumLayers(int numLayers) { + this.numLayers = numLayers; return self(); } /** - * Sets the optional parameter that indicates whether to include an extra input parameter - * sequence_length to specify variable length sequence. + * Sets the optional parameter that indicates whether to use bidirectional recurrent layers. * - * @param useSequenceLength whether to use sequence length + * @param useBidirectional whether to use bidirectional recurrent layers * @return this Builder */ - public T setSequenceLength(boolean useSequenceLength) { - this.useSequenceLength = useSequenceLength; + public T optBidirectional(boolean useBidirectional) { + this.bidirectional = useBidirectional; return self(); } /** - * Sets the optional parameter that indicates whether to use bidirectional recurrent layers. 
+ * Sets the optional batchFirst flag that indicates whether the input is batch major or not. + * The default value is true. * - * @param useBidirectional whether to use bidirectional recurrent layers + * @param batchFirst whether the input is batch major or not + * @return this Builder + */ + public T optBatchFirst(boolean batchFirst) { + this.batchFirst = batchFirst; + return self(); + } + + /** + * Sets the optional biases flag that indicates whether to use biases or not. + * + * @param hasBiases whether to use biases or not * @return this Builder */ - public T optBidrectional(boolean useBidirectional) { - this.useBidirectional = useBidirectional; + public T optHasBiases(boolean hasBiases) { + this.hasBiases = hasBiases; return self(); } /** - * Sets the optional parameter that indicates whether to have the states as symbol outputs. + * Sets the optional flag that indicates whether to return state or not. This is typically + * useful when you use RecurrentBlock in Sequential block. The default value is false. * - * @param stateOutputs whether to have the states as symbol output + * @param returnState whether to return state or not * @return this Builder */ - public T optStateOutput(boolean stateOutputs) { - this.stateOutputs = stateOutputs; + public T optReturnState(boolean returnState) { + this.returnState = returnState; return self(); } diff --git a/api/src/main/native/djl/utils.h b/api/src/main/native/djl/utils.h index ff412dad73b..edce9e54a9c 100644 --- a/api/src/main/native/djl/utils.h +++ b/api/src/main/native/djl/utils.h @@ -62,7 +62,7 @@ inline jlongArray GetPtrArrayFromContainer(JNIEnv* env, T1 list) { return jarray; } -inline std::vector GetVecFromJLongArray(JNIEnv* env, jlongArray jarray) { +inline std::vector GetVecFromJLongArray(JNIEnv* env, jlongArray jarray) { jlong* jarr = env->GetLongArrayElements(jarray, JNI_FALSE); jsize length = env->GetArrayLength(jarray); std::vector vec(jarr, jarr + length); @@ -86,7 +86,7 @@ inline std::vector GetVecFromJFloatArray(JNIEnv* env, jfloatArray jarray) return std::move(vec); } -inline std::vector GetVecFromJStringArray(JNIEnv* env, jobjectArray array) { +inline std::vector GetVecFromJStringArray(JNIEnv* env, jobjectArray array) { std::vector vec; jsize len = env->GetArrayLength(array); vec.reserve(len); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java index 0bd481ed75b..894fc818a0d 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainMnistWithLSTM.java @@ -93,7 +93,12 @@ private static Block getLSTMModel() { return input.reshape(new Shape(batchSize, time, channel)); }); block.add( - new LSTM.Builder().setStateSize(64).setNumStackedLayers(1).optDropRate(0).build()); + new LSTM.Builder() + .setStateSize(64) + .setNumLayers(1) + .optDropRate(0) + .optReturnState(false) + .build()); block.add(BatchNorm.builder().optEpsilon(1e-5f).optMomentum(0.9f).build()); block.add(Blocks.batchFlattenBlock()); block.add(Linear.builder().setUnits(10).build()); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java index df9b693baab..60926df8364 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSentimentAnalysis.java @@ -147,10 +147,9 @@ private static 
Block getModel() { return new SequentialBlock() .add( LSTM.builder() - .setNumStackedLayers(2) + .setNumLayers(2) .setStateSize(100) - .setSequenceLength(false) - .optBidrectional(true) + .optBidirectional(true) .build()) .add( x -> { diff --git a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java index 414b08eca16..ddece45b821 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainSeq2Seq.java @@ -122,16 +122,20 @@ private static Block getSeq2SeqModel( sourceEmbedding, new LSTM.Builder() .setStateSize(32) - .setNumStackedLayers(2) + .setNumLayers(2) .optDropRate(0) + .optBatchFirst(true) + .optReturnState(true) .build()); SimpleTextDecoder simpleTextDecoder = new SimpleTextDecoder( targetEmbedding, new LSTM.Builder() .setStateSize(32) - .setNumStackedLayers(2) + .setNumLayers(2) .optDropRate(0) + .optBatchFirst(true) + .optReturnState(false) .build(), vocabSize); return new EncoderDecoder(simpleTextEncoder, simpleTextDecoder); diff --git a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java index 2baf42d405b..2c64703212a 100644 --- a/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/modality/nlp/SimpleTextEncoderTest.java @@ -44,9 +44,10 @@ public void testEncoder() { new SimpleTextEncoder( trainableTextEmbedding, LSTM.builder() - .setNumStackedLayers(2) - .setSequenceLength(false) + .setNumLayers(2) .setStateSize(16) + .optBatchFirst(true) + .optReturnState(true) .build()); try (NDManager manager = NDManager.newBaseManager(TestUtils.getDevices()[0])) { encoder.setInitializer(new XavierInitializer()); @@ -55,7 +56,7 @@ public void testEncoder() { encoder.forward( new ParameterStore(manager, false), new NDList(manager.zeros(new Shape(4, 7))), - true); + false); Assert.assertEquals(output.head().getShape(), new Shape(4, 7, 16)); Assert.assertEquals(output.size(), 3); Assert.assertEquals(output.get(1).getShape(), new Shape(2, 4, 16)); diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index 77f6d8ea117..df34fb7d777 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -433,6 +433,7 @@ public void testConv3d() throws IOException, MalformedModelException { } } + @SuppressWarnings("try") @Test public void testRNNTanh() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); @@ -443,34 +444,43 @@ public void testRNNTanh() throws IOException, MalformedModelException { Block block = RNN.builder() .setStateSize(4) - .setNumStackedLayers(1) + .setNumLayers(1) .setActivation(RNN.Activation.TANH) - .optStateOutput(true) + .optBatchFirst(true) + .optReturnState(true) .build(); try (Model model = Model.newInstance("model", config.getDevices()[0])) { model.setBlock(block); try (Trainer trainer = model.newTrainer(config)) { - Shape inputShape = new Shape(1, 2, 4); - Engine.getInstance().setRandomSeed(1234); - trainer.initialize(inputShape); - NDManager manager = trainer.getManager(); - NDArray data = - manager.create(new float[] {1, 2, 
3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDArray labels = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDList result = trainer.forward(new NDList(data)); - NDArray expected = - manager.create(new float[] {1, 1, 1, 1, 1, 1, 1, 1}, new Shape(1, 2, 4)); - Assertions.assertAlmostEquals(result.head(), expected); - Assertions.assertAlmostEquals(result.size(), 2); - NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -18); - testEncode(manager, block); + // the unused GradientCollector is for BatchNorm to know it is on training mode + try (GradientCollector collector = trainer.newGradientCollector()) { + Shape inputShape = new Shape(1, 2, 4); + Engine.getInstance().setRandomSeed(1234); + trainer.initialize(inputShape); + NDManager manager = trainer.getManager(); + NDArray data = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDArray labels = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDList result = trainer.forward(new NDList(data)); + NDArray expected = + manager.create( + new float[] {1, 1, 1, 1, 1, 1, 1, 1}, new Shape(1, 2, 4)); + Assertions.assertAlmostEquals(result.head(), expected); + Assertions.assertAlmostEquals(result.size(), 2); + NDArray lossValue = + loss.evaluate(new NDList(labels), new NDList(result.head())); + Assertions.assertAlmostEquals(lossValue.getFloat(), -18); + testEncode(manager, block); + } } } } + @SuppressWarnings("try") @Test public void testRNNRelu() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); @@ -481,35 +491,44 @@ public void testRNNRelu() throws IOException, MalformedModelException { Block block = RNN.builder() .setStateSize(4) - .setNumStackedLayers(1) + .setNumLayers(1) .setActivation(RNN.Activation.RELU) - .optStateOutput(true) + .optBatchFirst(true) + .optReturnState(true) .build(); try (Model model = Model.newInstance("model", config.getDevices()[0])) { model.setBlock(block); try (Trainer trainer = model.newTrainer(config)) { - Shape inputShape = new Shape(1, 2, 4); - Engine.getInstance().setRandomSeed(1234); - trainer.initialize(inputShape); - NDManager manager = trainer.getManager(); - NDArray data = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDArray labels = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDList result = trainer.forward(new NDList(data)); - NDArray expected = - manager.create( - new float[] {10, 10, 10, 10, 66, 66, 66, 66}, new Shape(1, 2, 4)); - Assertions.assertAlmostEquals(result.head(), expected); - Assertions.assertAlmostEquals(result.size(), 2); - NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -908); - testEncode(manager, block); + // the unused GradientCollector is for BatchNorm to know it is on training mode + try (GradientCollector collector = trainer.newGradientCollector()) { + Shape inputShape = new Shape(1, 2, 4); + Engine.getInstance().setRandomSeed(1234); + trainer.initialize(inputShape); + NDManager manager = trainer.getManager(); + NDArray data = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDArray labels = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDList result = trainer.forward(new NDList(data)); + NDArray expected = + manager.create( + 
new float[] {10, 10, 10, 10, 66, 66, 66, 66}, + new Shape(1, 2, 4)); + Assertions.assertAlmostEquals(result.head(), expected); + Assertions.assertAlmostEquals(result.size(), 2); + NDArray lossValue = + loss.evaluate(new NDList(labels), new NDList(result.head())); + Assertions.assertAlmostEquals(lossValue.getFloat(), -908); + testEncode(manager, block); + } } } } + @SuppressWarnings("try") @Test public void testLstm() throws IOException, MalformedModelException { Loss loss = new SoftmaxCrossEntropyLoss("SmCeLoss", 1, -1, false, true); @@ -518,36 +537,48 @@ public void testLstm() throws IOException, MalformedModelException { .optInitializer(Initializer.ONES) .optDevices(TestUtils.getDevices()); Block block = - LSTM.builder().setStateSize(4).setNumStackedLayers(1).optStateOutput(true).build(); + LSTM.builder() + .setStateSize(4) + .setNumLayers(1) + .optBatchFirst(true) + .optReturnState(true) + .build(); try (Model model = Model.newInstance("model", config.getDevices()[0])) { model.setBlock(block); try (Trainer trainer = model.newTrainer(config)) { - Shape inputShape = new Shape(1, 2, 4); - Engine.getInstance().setRandomSeed(1234); - trainer.initialize(inputShape); - NDManager manager = trainer.getManager(); - NDArray data = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDArray labels = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDList result = trainer.forward(new NDList(data)); - NDArray expected = - manager.create( - new float[] { - 00.7615f, 0.7615f, 0.7615f, 0.7615f, 0.964f, 0.964f, 0.964f, - 0.964f - }, - new Shape(1, 2, 4)); - Assertions.assertAlmostEquals(result.head(), expected); - Assertions.assertAlmostEquals(result.size(), 3); - NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -16.340019); - testEncode(manager, block); + // the unused GradientCollector is for BatchNorm to know it is on training mode + try (GradientCollector collector = trainer.newGradientCollector()) { + Shape inputShape = new Shape(1, 2, 4); + Engine.getInstance().setRandomSeed(1234); + trainer.initialize(inputShape); + NDManager manager = trainer.getManager(); + NDArray data = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDArray labels = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDList result = trainer.forward(new NDList(data)); + NDArray expected = + manager.create( + new float[] { + 00.7615f, 0.7615f, 0.7615f, 0.7615f, 0.964f, 0.964f, 0.964f, + 0.964f + }, + new Shape(1, 2, 4)); + Assertions.assertAlmostEquals(result.head(), expected); + Assertions.assertAlmostEquals(result.size(), 3); + NDArray lossValue = + loss.evaluate(new NDList(labels), new NDList(result.head())); + Assertions.assertAlmostEquals(lossValue.getFloat(), -16.340019); + testEncode(manager, block); + } } } } + @SuppressWarnings("try") @Test public void testGRU() throws IOException, MalformedModelException { @@ -556,38 +587,50 @@ public void testGRU() throws IOException, MalformedModelException { new DefaultTrainingConfig(loss) .optInitializer(Initializer.ONES) .optDevices(TestUtils.getDevices()); - GRU block = GRU.builder().setStateSize(4).setNumStackedLayers(1).build(); + GRU block = + GRU.builder() + .setStateSize(4) + .setNumLayers(1) + .optBatchFirst(true) + .optReturnState(false) + .build(); try (Model model = Model.newInstance("model", config.getDevices()[0])) { model.setBlock(block); try (Trainer trainer = 
model.newTrainer(config)) { - Shape inputShape = new Shape(1, 2, 4); - Engine.getInstance().setRandomSeed(1234); - trainer.initialize(inputShape); - NDManager manager = trainer.getManager(); - NDArray data = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDArray labels = - manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}).reshape(inputShape); - NDList result = trainer.forward(new NDList(data)); - NDArray expected = - manager.create( - new float[] { - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f, - 4.54187393e-05f - }, - new Shape(1, 2, 4)); - Assertions.assertAlmostEquals(result.head(), expected); - Assertions.assertAlmostEquals(result.size(), 1); - NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -8.17537307E-4); - testEncode(manager, block); + // the unused GradientCollector is for BatchNorm to know it is on training mode + try (GradientCollector collector = trainer.newGradientCollector()) { + Shape inputShape = new Shape(1, 2, 4); + Engine.getInstance().setRandomSeed(1234); + trainer.initialize(inputShape); + NDManager manager = trainer.getManager(); + NDArray data = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDArray labels = + manager.create(new float[] {1, 2, 3, 4, 5, 6, 7, 8}) + .reshape(inputShape); + NDList result = trainer.forward(new NDList(data)); + NDArray expected = + manager.create( + new float[] { + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f, + 4.54187393e-05f + }, + new Shape(1, 2, 4)); + Assertions.assertAlmostEquals(result.head(), expected); + Assertions.assertAlmostEquals(result.size(), 1); + NDArray lossValue = + loss.evaluate(new NDList(labels), new NDList(result.head())); + Assertions.assertAlmostEquals(lossValue.getFloat(), -8.17537307E-4); + testEncode(manager, block); + } } } } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java index 09f989edd91..33a6a8d0d3c 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextDecoder.java @@ -32,8 +32,6 @@ public class SimpleTextDecoder extends Decoder { private static final byte VERSION = 1; - private RecurrentBlock recurrentBlock; - /** * Contructs a new instance of {@code SimpleTextDecoder} with the given {@link RecurrentBlock}. 
* Use this constructor if you are planning to use pre-trained embeddings that don't need @@ -60,7 +58,6 @@ public SimpleTextDecoder( RecurrentBlock recurrentBlock, long vocabSize) { super(VERSION, getBlock(trainableTextEmbedding, recurrentBlock, vocabSize)); - this.recurrentBlock = recurrentBlock; } private static Block getBlock( @@ -75,12 +72,6 @@ private static Block getBlock( return sequentialBlock; } - /** {@inheritDoc} */ - @Override - public void initState(NDList encoderStates) { - recurrentBlock.setBeginStates(encoderStates); - } - /** {@inheritDoc} */ @Override protected NDList forwardInternal( diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextEncoder.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextEncoder.java index d7a05d6801a..25faa6cb795 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextEncoder.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/nlp/SimpleTextEncoder.java @@ -35,7 +35,6 @@ public class SimpleTextEncoder extends Encoder { */ public SimpleTextEncoder(RecurrentBlock recurrentBlock) { super(VERSION, recurrentBlock); - recurrentBlock.setStateOutputs(true); } /** @@ -49,7 +48,6 @@ public SimpleTextEncoder(RecurrentBlock recurrentBlock) { public SimpleTextEncoder( TrainableTextEmbedding trainableTextEmbedding, RecurrentBlock recurrentBlock) { super(VERSION, new SequentialBlock().add(trainableTextEmbedding).add(recurrentBlock)); - recurrentBlock.setStateOutputs(true); } /** {@inheritDoc} */ diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java index bd908e09f31..96e1d5b6519 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java @@ -15,6 +15,7 @@ import ai.djl.Device; import ai.djl.mxnet.jna.JnaUtils; import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.NDUtils; @@ -22,7 +23,9 @@ import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.recurrent.RNN; import ai.djl.util.PairList; +import ai.djl.util.Preconditions; import java.util.Arrays; import java.util.List; @@ -675,54 +678,184 @@ public NDList batchNorm( /** {@inheritDoc} */ @Override public NDList rnn( - NDList inputs, - String mode, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - PairList additional) { - MxOpParams params = new MxOpParams(); - params.addParam("p", dropRate); - params.addParam("state_size", stateSize); - params.addParam("num_layers", numStackedLayers); - params.addParam("use_sequence_length", useSequenceLength); - params.addParam("bidirectional", useBidirectional); - params.addParam("state_outputs", stateOutputs); - params.addParam("mode", mode); - params.addAll(additional); - return getManager().invoke("_npx_rnn", inputs, params); + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + RNN.Activation activation, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + int numParams = numLayers * ((hasBiases) ? 
4 : 2); + Preconditions.checkArgument( + params.size() == numParams, + "The size of Params is incorrect expect " + + numParams + + " parameters but got " + + params.size()); + + if (training != JnaUtils.autogradIsTraining()) { + throw new IllegalArgumentException( + "the mode of rnn in MXNet should align with the mode of GradientCollector"); + } + + if (batchFirst) { + input = input.swapAxes(0, 1); + } + + MxOpParams opParams = new MxOpParams(); + opParams.addParam("p", dropRate); + opParams.addParam("state_size", state.getShape().tail()); + opParams.addParam("num_layers", numLayers); + opParams.addParam("bidirectional", bidirectional); + opParams.addParam("state_outputs", true); + opParams.addParam("mode", activation == RNN.Activation.TANH ? "rnn_tanh" : "rnn_relu"); + + NDList inputs = new NDList(); + inputs.add(input); + + try (NDList temp = new NDList()) { + for (NDArray param : params) { + temp.add(param.flatten()); + } + NDArray tempParam = NDArrays.concat(temp); + tempParam.attach(input.getManager()); + inputs.add(tempParam); + } + + inputs.add(state); + + if (!batchFirst) { + return getManager().invoke("_npx_rnn", inputs, opParams); + } + + NDList result = getManager().invoke("_npx_rnn", inputs, opParams); + try (NDArray temp = result.head()) { + return new NDList(temp.swapAxes(0, 1), result.get(1)); + } + } + + /** {@inheritDoc} */ + @Override + public NDList gru( + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + int numParams = numLayers * ((hasBiases) ? 4 : 2); + Preconditions.checkArgument( + params.size() == numParams, + "The size of Params is incorrect expect " + + numParams + + " parameters but got " + + params.size()); + + if (training != JnaUtils.autogradIsTraining()) { + throw new IllegalArgumentException( + "the mode of gru in MXNet should align with the mode of GradientCollector"); + } + + if (batchFirst) { + input = input.swapAxes(0, 1); + } + + MxOpParams opParams = new MxOpParams(); + opParams.addParam("p", dropRate); + opParams.addParam("state_size", state.getShape().tail()); + opParams.addParam("num_layers", numLayers); + opParams.addParam("bidirectional", bidirectional); + opParams.addParam("state_outputs", true); + opParams.addParam("mode", "gru"); + + NDList inputs = new NDList(); + inputs.add(input); + + try (NDList temp = new NDList()) { + for (NDArray param : params) { + temp.add(param.flatten()); + } + NDArray tempParam = NDArrays.concat(temp); + tempParam.attach(input.getManager()); + inputs.add(tempParam); + } + + inputs.add(state); + + if (!batchFirst) { + return getManager().invoke("_npx_rnn", inputs, opParams); + } + + NDList result = getManager().invoke("_npx_rnn", inputs, opParams); + try (NDArray temp = result.head()) { + return new NDList(temp.swapAxes(0, 1), result.get(1)); + } } /** {@inheritDoc} */ @Override public NDList lstm( - NDList inputs, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - double lstmStateClipMin, - double lstmStateClipMax, - PairList additional) { - MxOpParams params = new MxOpParams(); - params.addParam("mode", "lstm"); - params.addParam("p", dropRate); - params.addParam("state_size", stateSize); - params.addParam("num_layers", numStackedLayers); - params.addParam("use_sequence_length", useSequenceLength); - params.addParam("bidirectional", useBidirectional); - params.addParam("state_outputs", 
stateOutputs); - params.addParam("lstm_state_clip_nan", true); - params.addParam("lstm_state_clip_min", lstmStateClipMin); - params.addParam("lstm_state_clip_max", lstmStateClipMax); - params.addAll(additional); + NDArray input, + NDList states, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + int numParams = numLayers * ((hasBiases) ? 4 : 2); + Preconditions.checkArgument( + params.size() == numParams, + "The size of Params is incorrect expect " + + numParams + + " parameters but got " + + params.size()); + + if (training != JnaUtils.autogradIsTraining()) { + throw new IllegalArgumentException( + "the mode of lstm in MXNet should align with the mode of GradientCollector"); + } + + if (batchFirst) { + input = input.swapAxes(0, 1); + } + + MxOpParams opParams = new MxOpParams(); + opParams.addParam("mode", "lstm"); + opParams.addParam("p", dropRate); + opParams.addParam("state_size", states.head().getShape().tail()); + opParams.addParam("state_outputs", true); + opParams.addParam("num_layers", numLayers); + opParams.addParam("bidirectional", bidirectional); + opParams.addParam("lstm_state_clip_nan", true); + + NDList inputs = new NDList(); + inputs.add(input); + try (NDList temp = new NDList()) { + for (NDArray param : params) { + temp.add(param.flatten()); + } + NDArray tempParam = NDArrays.concat(temp); + tempParam.attach(input.getManager()); + inputs.add(tempParam); + } + inputs.addAll(states); - return getManager().invoke("_npx_rnn", inputs, params); + if (!batchFirst) { + return getManager().invoke("_npx_rnn", inputs, opParams); + } + + NDList result = getManager().invoke("_npx_rnn", inputs, opParams); + try (NDArray temp = result.head()) { + return new NDList(temp.swapAxes(0, 1), result.get(1), result.get(2)); + } } //////////////////////////////////////// diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index a1c52a8c8bc..eba7b9db16d 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.jni.JniUtils; import ai.djl.util.PairList; import java.util.List; @@ -462,32 +463,75 @@ public NDList batchNorm( /** {@inheritDoc} */ @Override public NDList rnn( - NDList inputs, - String mode, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - PairList additional) { - throw new UnsupportedOperationException("Not implemented"); + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + RNN.Activation activation, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + return JniUtils.rnn( + (PtNDArray) input, + (PtNDArray) state, + params, + hasBiases, + numLayers, + activation, + dropRate, + training, + bidirectional, + batchFirst); + } + + /** {@inheritDoc} */ + @Override + public NDList gru( + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + return JniUtils.gru( + (PtNDArray) input, + (PtNDArray) state, + 
params, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); } /** {@inheritDoc} */ @Override public NDList lstm( - NDList inputs, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - double lstmStateClipMin, - double lstmStateClipMax, - PairList additional) { - throw new UnsupportedOperationException("Not implemented"); + NDArray input, + NDList states, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + return JniUtils.lstm( + (PtNDArray) input, + states, + params, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); } /** {@inheritDoc} */ diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 19e1e4060f1..1c3dff103c5 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; +import ai.djl.nn.recurrent.RNN; import ai.djl.pytorch.engine.PtDeviceType; import ai.djl.pytorch.engine.PtNDArray; import ai.djl.pytorch.engine.PtNDManager; @@ -1108,6 +1109,103 @@ public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training PyTorchLibrary.LIB.torchNNDropout(ndArray.getHandle(), prob, training)); } + public static NDList rnn( + PtNDArray input, + PtNDArray hx, + NDList params, + boolean hasBiases, + int numLayers, + RNN.Activation activation, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + PtNDManager manager = input.getManager(); + long[] paramHandles = + params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); + long[] outputs = + PyTorchLibrary.LIB.torchNNRnn( + input.getHandle(), + hx.getHandle(), + paramHandles, + hasBiases, + numLayers, + activation.ordinal(), + dropRate, + training, + bidirectional, + batchFirst); + NDList res = new NDList(); + for (long output : outputs) { + res.add(new PtNDArray(manager, output)); + } + return res; + } + + public static NDList gru( + PtNDArray input, + PtNDArray hx, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + PtNDManager manager = input.getManager(); + long[] paramHandles = + params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); + long[] outputs = + PyTorchLibrary.LIB.torchNNGru( + input.getHandle(), + hx.getHandle(), + paramHandles, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); + NDList res = new NDList(); + for (long output : outputs) { + res.add(new PtNDArray(manager, output)); + } + return res; + } + + public static NDList lstm( + PtNDArray input, + NDList hx, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + PtNDManager manager = input.getManager(); + long[] hxHandles = + hx.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); + long[] paramHandles = + params.stream().mapToLong(array -> ((PtNDArray) array).getHandle()).toArray(); + long[] outputs = + PyTorchLibrary.LIB.torchNNLstm( + input.getHandle(), + hxHandles, + 
paramHandles, + hasBiases, + numLayers, + dropRate, + training, + bidirectional, + batchFirst); + NDList res = new NDList(); + for (long output : outputs) { + res.add(new PtNDArray(manager, output)); + } + return res; + } + public static PtNDArray avgPool( PtNDArray ndArray, Shape kernelSize, diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index dd7adf09acc..d3a89d006b0 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -388,6 +388,40 @@ native long torchNNBatchNorm( double momentum, double eps); + native long[] torchNNRnn( + long inputHandle, + long hxHandle, + long[] paramHandles, + boolean hasBiases, + int numLayers, + int activation, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst); + + native long[] torchNNGru( + long inputHandle, + long hxHandle, + long[] paramHandles, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst); + + native long[] torchNNLstm( + long inputHandle, + long[] hxHandles, + long[] paramHandles, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst); + native long torchNNAvgPool( long inputHandle, long[] kernel, diff --git a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc index a1dd102e0eb..59ca127a8cb 100644 --- a/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc +++ b/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_nn_functional.cc @@ -143,6 +143,83 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNDropout( API_END_RETURN() } +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNRnn(JNIEnv* env, jobject jthis, + jlong jinput, jlong jhx, jlongArray jparams, jboolean jhas_biases, jint jnum_layers, jint jactivation, + jdouble jdrop_rate, jboolean jtraining, jboolean jbidirectional, jboolean jbatch_first) { + API_BEGIN() + const auto* input_ptr = reinterpret_cast<torch::Tensor*>(jinput); + const auto* hx_ptr = reinterpret_cast<torch::Tensor*>(jhx); + const std::vector<torch::Tensor> params = djl::utils::jni::GetObjectVecFromJHandles<torch::Tensor>(env, jparams); + + std::tuple<torch::Tensor, torch::Tensor> outputs; + if (jactivation == 0) { + outputs = torch::rnn_relu(*input_ptr, *hx_ptr, torch::TensorList(params), jhas_biases, jnum_layers, + jdrop_rate, jtraining, jbidirectional, jbatch_first); + } else if (jactivation == 1) { + outputs = torch::rnn_tanh(*input_ptr, *hx_ptr, torch::TensorList(params), jhas_biases, jnum_layers, + jdrop_rate, jtraining, jbidirectional, jbatch_first); + } else { + env->ThrowNew(ENGINE_EXCEPTION_CLASS, "can't find activation"); + } + + // process output + jlongArray jarray = env->NewLongArray(2); + std::vector<jlong> jptrs; + jptrs.resize(2); + jptrs[0] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<0>(outputs))); + jptrs[1] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<1>(outputs))); + env->SetLongArrayRegion(jarray, 0, 2, jptrs.data()); + return jarray; + API_END_RETURN() +} + +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNGru(JNIEnv* env, jobject jthis, + jlong jinput, jlong jhx, jlongArray jparams, jboolean jhas_biases, jint jnum_layers, jdouble
jdrop_rate, + jboolean jtraining, jboolean jbidirectional, jboolean jbatch_first) { + API_BEGIN() + const auto* input_ptr = reinterpret_cast<torch::Tensor*>(jinput); + const auto* hx_ptr = reinterpret_cast<torch::Tensor*>(jhx); + const std::vector<torch::Tensor> params = djl::utils::jni::GetObjectVecFromJHandles<torch::Tensor>(env, jparams); + + std::tuple<torch::Tensor, torch::Tensor> outputs = + torch::gru(*input_ptr, *hx_ptr, torch::TensorList(params), jhas_biases, jnum_layers, + jdrop_rate, jtraining, jbidirectional, jbatch_first); + + // process output + jlongArray jarray = env->NewLongArray(2); + std::vector<jlong> jptrs; + jptrs.resize(2); + jptrs[0] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<0>(outputs))); + jptrs[1] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<1>(outputs))); + env->SetLongArrayRegion(jarray, 0, 2, jptrs.data()); + return jarray; + API_END_RETURN() +} + +JNIEXPORT jlongArray JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLstm(JNIEnv* env, jobject jthis, + jlong jinput, jlongArray jhx, jlongArray jparams, jboolean jhas_biases, jint jnum_layers, jdouble jdrop_rate, + jboolean jtraining, jboolean jbidirectional, jboolean jbatch_first) { + API_BEGIN() + const auto* input_ptr = reinterpret_cast<torch::Tensor*>(jinput); + const std::vector<torch::Tensor> hx = djl::utils::jni::GetObjectVecFromJHandles<torch::Tensor>(env, jhx); + const std::vector<torch::Tensor> params = djl::utils::jni::GetObjectVecFromJHandles<torch::Tensor>(env, jparams); + + std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> outputs = + torch::lstm(*input_ptr, torch::TensorList(hx), torch::TensorList(params), jhas_biases, jnum_layers, + jdrop_rate, jtraining, jbidirectional, jbatch_first); + + // process output + jlongArray jarray = env->NewLongArray(3); + std::vector<jlong> jptrs; + jptrs.resize(3); + jptrs[0] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<0>(outputs))); + jptrs[1] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<1>(outputs))); + jptrs[2] = reinterpret_cast<uintptr_t>(new torch::Tensor(std::get<2>(outputs))); + env->SetLongArrayRegion(jarray, 0, 3, jptrs.data()); + return jarray; + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNRelu(JNIEnv* env, jobject jthis, jlong jinput) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jinput); diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java index df7279e873f..a44ad803e5d 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayEx.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.internal.NDArrayEx; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.nn.recurrent.RNN; import ai.djl.util.PairList; import java.util.ArrayList; import java.util.List; @@ -402,31 +403,46 @@ public NDList batchNorm( /** {@inheritDoc} */ @Override public NDList rnn( - NDList inputs, - String mode, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - PairList additional) { + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + RNN.Activation activation, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** {@inheritDoc} */ + @Override + public NDList gru( + NDArray input, + NDArray state, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) {
throw new UnsupportedOperationException("Not implemented"); } /** {@inheritDoc} */ @Override public NDList lstm( - NDList inputs, - long stateSize, - float dropRate, - int numStackedLayers, - boolean useSequenceLength, - boolean useBidirectional, - boolean stateOutputs, - double lstmStateClipMin, - double lstmStateClipMax, - PairList additional) { + NDArray input, + NDList states, + NDList params, + boolean hasBiases, + int numLayers, + double dropRate, + boolean training, + boolean bidirectional, + boolean batchFirst) { throw new UnsupportedOperationException("Not implemented"); }
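Note for reviewers: below is a minimal, self-contained sketch of how the new engine-agnostic LSTM entry point might be exercised. The class name, the tensor sizes, and the PyTorch-style parameter shapes and ordering (w_ih, w_hh, b_ih, b_hh per layer) are assumptions for illustration only; the parts taken from this diff are the NDArrayEx.lstm(input, states, params, hasBiases, numLayers, dropRate, training, bidirectional, batchFirst) signature and its three-element result (output, final hidden state, final cell state). On the MXNet path the params are flattened and concatenated before being handed to _npx_rnn, so the per-tensor shapes below mainly matter for the PyTorch engine.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class LstmOpSketch {

    public static void main(String[] args) {
        // Hypothetical sizes, chosen only for illustration.
        int batch = 4;
        int seqLen = 10;
        int inputSize = 8;
        int hiddenSize = 16;
        int numLayers = 1;

        try (NDManager manager = NDManager.newBaseManager()) {
            // batchFirst = true, so the input is (batch, seq, feature).
            NDArray input = manager.randomNormal(new Shape(batch, seqLen, inputSize));

            // Initial hidden state and cell state: (numLayers * numDirections, batch, hiddenSize).
            NDList states =
                    new NDList(
                            manager.zeros(new Shape(numLayers, batch, hiddenSize)),
                            manager.zeros(new Shape(numLayers, batch, hiddenSize)));

            // Flat parameter list: with hasBiases = true this is 4 tensors per layer.
            // The PyTorch-style shapes (w_ih, w_hh, b_ih, b_hh) are an assumption here.
            NDList params =
                    new NDList(
                            manager.randomNormal(new Shape(4 * hiddenSize, inputSize)),
                            manager.randomNormal(new Shape(4 * hiddenSize, hiddenSize)),
                            manager.zeros(new Shape(4 * hiddenSize)),
                            manager.zeros(new Shape(4 * hiddenSize)));

            // hasBiases = true, dropRate = 0, training = false,
            // bidirectional = false, batchFirst = true.
            NDList result =
                    input.getNDArrayInternal()
                            .lstm(input, states, params, true, numLayers, 0, false, false, true);

            NDArray output = result.head(); // (batch, seqLen, hiddenSize)
            NDArray hn = result.get(1); // final hidden state
            NDArray cn = result.get(2); // final cell state
            System.out.println(output.getShape() + " " + hn.getShape() + " " + cn.getShape());
        }
    }
}

The rnn and gru overloads follow the same pattern with a single hidden-state tensor in place of the (h, c) pair; for rnn, the PyTorch JNI bridge dispatches on RNN.Activation.ordinal(), mapping 0 to torch::rnn_relu and 1 to torch::rnn_tanh.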