diff --git a/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java new file mode 100644 index 00000000000..5c6a5503091 --- /dev/null +++ b/api/src/main/java/ai/djl/nn/AbstractBaseBlock.java @@ -0,0 +1,417 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.nn; + +import ai.djl.MalformedModelException; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.training.ParameterStore; +import ai.djl.training.initializer.Initializer; +import ai.djl.util.Pair; +import ai.djl.util.PairList; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +/** + * This provides shared functionality for both the DJL-based {@link AbstractBlock}s and the imported + * {@link AbstractSymbolBlock}s. + */ +public abstract class AbstractBaseBlock implements Block { + + /** + * The model version of this block, used for checking if parameters are still valid during + * parameter loading. + */ + protected byte version; + + /** The shape of the input for this block, set by the initialization process. 
*/ + protected Shape[] inputShapes; + + /** List of names for the input, named inputs should be manually set in sub class. */ + protected List inputNames = Collections.emptyList(); + + /** Constructs a new {@link AbstractBaseBlock} instance. */ + public AbstractBaseBlock() { + this((byte) 1); + } + + /** + * Builds an empty block with the given version for parameter serialization. + * + * @param version the version to use for parameter serialization. + */ + public AbstractBaseBlock(byte version) { + this.version = version; + } + + /** {@inheritDoc} */ + @Override + public final NDList forward( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDManager paramsManager = parameterStore.getManager(); + if (training && !isInitialized()) { + initialize(paramsManager, DataType.FLOAT32, inputs.getShapes()); + } + return forwardInternal(parameterStore, inputs, training, params); + } + + /** {@inheritDoc} */ + @Override + public NDList forward( + ParameterStore parameterStore, + NDList data, + NDList labels, + PairList params) { + NDManager paramsManager = parameterStore.getManager(); + if (!isInitialized()) { + initialize(paramsManager, DataType.FLOAT32, data.getShapes()); + } + return forwardInternal(parameterStore, data, labels, params); + } + + /** + * A helper for {@link Block#forward(ParameterStore, NDList, boolean, PairList)} after + * initialization. + * + * @param parameterStore the parameter store + * @param inputs the input NDList + * @param training true for a training forward pass + * @param params optional parameters + * @return the output of the forward pass + */ + protected abstract NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params); + + /** + * A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after + * initialization. 
+ * + * @param parameterStore the parameter store + * @param data the input data NDList + * @param labels the input labels NDList + * @param params optional parameters + * @return the output of the forward pass + * @see #forward(ParameterStore, NDList, boolean, PairList) + */ + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList data, + NDList labels, + PairList params) { + return forwardInternal(parameterStore, data, true, params); + } + + /** {@inheritDoc} */ + @Override + public PairList describeInput() { + if (!isInitialized()) { + throw new IllegalStateException( + "Parameter of this block are not initialised," + + "please call model.newTrainer and trainer.initialize"); + } + return new PairList<>(inputNames, Arrays.asList(inputShapes)); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, Parameter.Type params) { + Predicate predicate = parameter -> parameter.getType().equals(params); + setInitializer(initializer, predicate); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, String paramName) { + Parameter parameter = + getDirectParameters() + .values() + .stream() + .filter(p -> p.getName().equals(paramName)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + "Could not find parameter " + paramName)); + parameter.setInitializer(initializer); + } + + /** {@inheritDoc} */ + @Override + public void setInitializer(Initializer initializer, Predicate predicate) { + List params = getParameters().values(); + for (Parameter param : params) { + if (predicate.test(param)) { + param.setInitializer(initializer); + } + } + } + + /** {@inheritDoc} */ + @Override + public void initialize(NDManager manager, DataType dataType, Shape... 
inputShapes) { + beforeInitialize(inputShapes); + // if parameters are initialized, skip it + if (!isInitialized()) { + // setShape for all params + prepare(inputShapes); + } + for (Parameter parameter : getDirectParameters().values()) { + parameter.initialize(manager, dataType); + } + initializeChildBlocks(manager, dataType, inputShapes); + } + + /** + * Performs any action necessary before initialization. For example, keep the input information + * or verify the layout. + * + * @param inputShapes the expected shapes of the input + */ + protected void beforeInitialize(Shape... inputShapes) { + if (inputNames.isEmpty()) { + // automatically assign input names + inputNames = new ArrayList<>(); + for (int i = 0; i < inputShapes.length; ++i) { + inputNames.add("data" + i); + } + } + this.inputShapes = inputShapes; + } + + /** + * Initializes the Child blocks of this block. You need to override this method if your subclass + * has child blocks. Used to determine the correct input shapes for child blocks based on the + * requested input shape for this block. + * + * @param manager the manager to use for initialization + * @param dataType the requested data type + * @param inputShapes the expected input shapes for this block + */ + protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... inputShapes) { + if (!getChildren().isEmpty()) { + throw new IllegalStateException( + getClass().getSimpleName() + + " has child blocks but initializeChildBlocks is not overwritten."); + } + } + + /** + * Sets the shape of {@link Parameter}s. 
+ * + * @param inputShapes the shapes of inputs + */ + protected void prepare(Shape[] inputShapes) {} + + /** {@inheritDoc} */ + @Override + public ParameterList getParameters() { + // we accumulate a list of all parameters by starting with a list of the direct parameters + ParameterList allParams = getDirectParameters(); + // then we add the parameters of child blocks + for (Pair childPair : getChildren()) { + for (Pair paramPair : childPair.getValue().getParameters()) { + // we prepend the name of the child block to the parameter name + allParams.add(childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue()); + } + } + return allParams; + } + + /** {@inheritDoc} */ + @Override + public boolean isInitialized() { + if (inputShapes == null) { + return false; + } + for (Parameter param : getParameters().values()) { + if (!param.isInitialized()) { + return false; + } + } + return true; + } + + /** {@inheritDoc} */ + @Override + public void clear() { + getParameters().forEach(param -> param.getValue().close()); + } + + /** {@inheritDoc} */ + @Override + public void cast(DataType dataType) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + /** {@inheritDoc} */ + @Override + public void saveParameters(DataOutputStream os) throws IOException { + os.write(version); + saveMetadata(os); + for (Parameter parameter : getDirectParameters().values()) { + parameter.save(os); + } + for (Block child : getChildren().values()) { + child.saveParameters(os); + } + } + + /** {@inheritDoc} */ + @Override + public void loadParameters(NDManager manager, DataInputStream is) + throws IOException, MalformedModelException { + byte loadVersion = is.readByte(); + loadMetadata(loadVersion, is); + for (Parameter parameter : getDirectParameters().values()) { + parameter.load(manager, is); + } + for (Block child : getChildren().values()) { + child.loadParameters(manager, is); + } + } + + /** + * Override this method to save additional data apart from parameter 
values. + * + *

<p>This default implementation saves the currently set input shapes. + * + * @param os the non-null output stream the parameter values and metadata are written to + * @throws IOException saving failed + */ + protected void saveMetadata(DataOutputStream os) throws IOException { + saveInputShapes(os); + } + + /** + * Overwrite this to load additional metadata with the parameter values. + * + *

<p>If you overwrite {@link AbstractBlock#saveMetadata(DataOutputStream)} or need to provide + * backward compatibility to older binary formats, you probably need to overwrite this. This + * default implementation checks if the version number fits, if not it throws an {@link + * MalformedModelException}. After that it restores the input shapes. + * + * @param loadVersion the version used for loading this metadata. + * @param is the input stream we are loading from + * @throws IOException loading failed + * @throws MalformedModelException data can be loaded but has wrong format + */ + protected void loadMetadata(byte loadVersion, DataInputStream is) + throws IOException, MalformedModelException { + if (loadVersion != version) { + throw new MalformedModelException( + "Cannot load parameters for " + + this.getClass().getCanonicalName() + + ", expected version " + + version + + ", got " + + loadVersion + + "."); + } + readInputShapes(is); + } + + protected void saveInputShapes(DataOutputStream os) throws IOException { + os.writeInt(inputShapes.length); + for (Shape shape : inputShapes) { + os.write(shape.getEncoded()); + } + } + + protected void readInputShapes(DataInputStream is) throws IOException { + int len = is.readInt(); + Shape[] shapes = new Shape[len]; + for (int i = 0; i < len; ++i) { + shapes[i] = Shape.decode(is); + } + if (inputShapes == null) { + // load inputShapes from parameter file if Block has not been initialized + inputShapes = shapes; + } + } + + /** {@inheritDoc} */ + @Override + public String toString() { + // FIXME: This is a quick hack for display in jupyter notebook. 
+ StringBuilder sb = new StringBuilder(200); + String className = getClass().getSimpleName(); + if (className.endsWith("Block")) { + className = className.substring(0, className.length() - 5); + } + sb.append(className).append('('); + if (isInitialized()) { + PairList inputShapeDescription = describeInput(); + appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); + sb.append(" -> "); + Shape[] outputShapes = + getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); + appendShape(sb, outputShapes); + } else { + sb.append("Uninitialized"); + } + sb.append(')'); + return sb.toString(); + } + + private void appendShape(StringBuilder sb, Shape[] shapes) { + boolean first = true; + for (Shape shape : shapes) { + if (first) { + first = false; + } else { + sb.append(", "); + } + long[] sh = shape.getShape(); + int length = sh.length; + if (length == 0) { + sb.append("()"); + } else { + int index = 0; + if (sh[0] == -1) { + --length; + index = 1; + } + + if (length == 0) { + sb.append("()"); + } else if (length == 1) { + sb.append(sh[index]); + } else { + sb.append('('); + for (int i = index; i < sh.length; ++i) { + if (i > index) { + sb.append(", "); + } + sb.append(sh[i]); + } + sb.append(')'); + } + } + } + } +} diff --git a/api/src/main/java/ai/djl/nn/AbstractBlock.java b/api/src/main/java/ai/djl/nn/AbstractBlock.java index 4bf3e7ae010..e3cc6cd6adc 100644 --- a/api/src/main/java/ai/djl/nn/AbstractBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractBlock.java @@ -12,25 +12,16 @@ */ package ai.djl.nn; -import ai.djl.MalformedModelException; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.training.ParameterStore; -import ai.djl.training.initializer.Initializer; import ai.djl.util.Pair; import ai.djl.util.PairList; -import java.io.DataInputStream; import java.io.DataOutputStream; -import java.io.IOException; -import java.util.ArrayList; 
-import java.util.Arrays; -import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Locale; -import java.util.function.Predicate; /** * {@code AbstractBlock} is an abstract implementation of {@link Block}. @@ -55,8 +46,8 @@ * implement the computation of your block *

  • IFF you need to save data apart from the parameter values of your block, you need to * override {@link AbstractBlock#saveMetadata(DataOutputStream)} and {@link - * AbstractBlock#loadMetadata(byte, DataInputStream)}. If you do not need to save or load any - * state other than parameters in your block, you can skip this. + * AbstractBlock#loadMetadata(byte, java.io.DataInputStream)}. If you do not need to save or + * load any state other than parameters in your block, you can skip this. * * *

    If you use {@link AbstractBlock#addParameter(Parameter)} to add parameters, you have to take @@ -68,19 +59,7 @@ // of this API know the children and parameters are always iterated over in insertion order. // LinkedHashMap provides this guarantee, Map does not. @SuppressWarnings("PMD.LooseCoupling") -public abstract class AbstractBlock implements Block { - - /** The shape of the input for this block, set by the initialization process. */ - protected Shape[] inputShapes; - - /** List of names for the input, named inputs should be manually set in sub class. */ - protected List inputNames = Collections.emptyList(); - - /** - * The model version of this block, used for checking if parameters are still valid during - * parameter loading. - */ - protected byte version; +public abstract class AbstractBlock extends AbstractBaseBlock { /** * All direct children of this Block. Keys are names of the blocks. @@ -99,9 +78,7 @@ public abstract class AbstractBlock implements Block { protected LinkedHashMap parameters = new LinkedHashMap<>(); /** Constructs a new {@code AbstractBlock} instance. */ - public AbstractBlock() { - this((byte) 1); - } + public AbstractBlock() {} /** * Builds an empty block with the given version for parameter serialization. @@ -109,70 +86,7 @@ public AbstractBlock() { * @param version the version to use for parameter serialization. 
*/ public AbstractBlock(byte version) { - this.version = version; - } - - /** {@inheritDoc} */ - @Override - public final NDList forward( - ParameterStore parameterStore, - NDList inputs, - boolean training, - PairList params) { - NDManager paramsManager = parameterStore.getManager(); - if (training && !isInitialized()) { - initialize(paramsManager, DataType.FLOAT32, inputs.getShapes()); - } - return forwardInternal(parameterStore, inputs, training, params); - } - - /** {@inheritDoc} */ - @Override - public NDList forward( - ParameterStore parameterStore, - NDList data, - NDList labels, - PairList params) { - NDManager paramsManager = parameterStore.getManager(); - if (!isInitialized()) { - initialize(paramsManager, DataType.FLOAT32, data.getShapes()); - } - return forwardInternal(parameterStore, data, labels, params); - } - - /** - * A helper for {@link Block#forward(ParameterStore, NDList, boolean, PairList)} after - * initialization. - * - * @param parameterStore the parameter store - * @param inputs the input NDList - * @param training true for a training forward pass - * @param params optional parameters - * @return the output of the forward pass - */ - protected abstract NDList forwardInternal( - ParameterStore parameterStore, - NDList inputs, - boolean training, - PairList params); - - /** - * A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after - * initialization. 
- * - * @param parameterStore the parameter store - * @param data the input data NDList - * @param labels the input labels NDList - * @param params optional parameters - * @return the output of the forward pass - * @see #forward(ParameterStore, NDList, boolean, PairList) - */ - protected NDList forwardInternal( - ParameterStore parameterStore, - NDList data, - NDList labels, - PairList params) { - return forwardInternal(parameterStore, data, true, params); + super(version); } /** @@ -215,293 +129,9 @@ public BlockList getChildren() { return defensiveCopy; } - /** {@inheritDoc} */ - @Override - public PairList describeInput() { - if (!isInitialized()) { - throw new IllegalStateException( - "Parameter of this block are not initialised," - + "please call model.newTrainer and trainer.initialize"); - } - return new PairList<>(inputNames, Arrays.asList(inputShapes)); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, Parameter.Type params) { - Predicate predicate = parameter -> parameter.getType().equals(params); - setInitializer(initializer, predicate); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, String paramName) { - Parameter parameter = parameters.get(paramName); - if (parameter == null) { - throw new IllegalArgumentException("Could not find parameter " + paramName); - } - parameter.setInitializer(initializer); - } - - /** {@inheritDoc} */ - @Override - public void setInitializer(Initializer initializer, Predicate predicate) { - List params = getParameters().values(); - for (Parameter param : params) { - if (predicate.test(param)) { - param.setInitializer(initializer); - } - } - } - - /** {@inheritDoc} */ - @Override - public void initialize(NDManager manager, DataType dataType, Shape... 
inputShapes) { - beforeInitialize(inputShapes); - // if parameters are initialized, skip it - if (!isInitialized()) { - // setShape for all params - prepare(inputShapes); - } - for (Parameter parameter : parameters.values()) { - parameter.initialize(manager, dataType); - } - initializeChildBlocks(manager, dataType, inputShapes); - } - - /** - * Performs any action necessary before initialization. For example, keep the input information - * or verify the layout. - * - * @param inputShapes the expected shapes of the input - */ - protected void beforeInitialize(Shape... inputShapes) { - if (inputNames.isEmpty()) { - // automatically assign input names - inputNames = new ArrayList<>(); - for (int i = 0; i < inputShapes.length; ++i) { - inputNames.add("data" + i); - } - } - this.inputShapes = inputShapes; - } - - /** - * Initializes the Child blocks of this block. You need to override this method if your subclass - * has child blocks. Used to determine the correct input shapes for child blocks based on the - * requested input shape for this block. - * - * @param manager the manager to use for initialization - * @param dataType the requested data type - * @param inputShapes the expected input shapes for this block - */ - protected void initializeChildBlocks( - NDManager manager, DataType dataType, Shape... 
inputShapes) { - if (!children.isEmpty()) { - throw new IllegalStateException( - getClass().getSimpleName() - + " has child blocks but initializeChildBlocks is not overwritten."); - } - } - - /** {@inheritDoc} */ - @Override - public ParameterList getParameters() { - // we accumulate a list of all parameters by starting with a list of the direct parameters - ParameterList allParams = getDirectParameters(); - // then we add the parameters of child blocks - for (Pair childPair : getChildren()) { - for (Pair paramPair : childPair.getValue().getParameters()) { - // we prepend the name of the child block to the parameter name - allParams.add(childPair.getKey() + "_" + paramPair.getKey(), paramPair.getValue()); - } - } - return allParams; - } - /** {@inheritDoc} */ @Override public ParameterList getDirectParameters() { return new ParameterList(parameters); } - - /** - * Sets the shape of {@link Parameter}s. - * - * @param inputShapes the shapes of inputs - */ - protected void prepare(Shape[] inputShapes) {} - - /** {@inheritDoc} */ - @Override - public boolean isInitialized() { - if (inputShapes == null) { - return false; - } - for (Parameter param : getParameters().values()) { - if (!param.isInitialized()) { - return false; - } - } - return true; - } - - /** {@inheritDoc} */ - @Override - public void clear() { - getParameters().forEach(param -> param.getValue().close()); - } - - /** {@inheritDoc} */ - @Override - public void cast(DataType dataType) { - throw new UnsupportedOperationException("Not implemented yet."); - } - - /** {@inheritDoc} */ - @Override - public void saveParameters(DataOutputStream os) throws IOException { - os.write(version); - saveMetadata(os); - for (Parameter parameter : parameters.values()) { - parameter.save(os); - } - for (Block child : children.values()) { - child.saveParameters(os); - } - } - - /** {@inheritDoc} */ - @Override - public void loadParameters(NDManager manager, DataInputStream is) - throws IOException, MalformedModelException { 
- byte loadVersion = is.readByte(); - loadMetadata(loadVersion, is); - for (Parameter parameter : parameters.values()) { - parameter.load(manager, is); - } - for (Block child : children.values()) { - child.loadParameters(manager, is); - } - } - - /** - * Override this method to save additional data apart from parameter values. - * - *

    This default implementation saves the currently set input shapes. - * - * @param os the non-null output stream the parameter values and metadata are written to - * @throws IOException saving failed - */ - protected void saveMetadata(DataOutputStream os) throws IOException { - saveInputShapes(os); - } - - /** - * Overwrite this to load additional metadata with the parameter values. - * - *

    If you overwrite {@link AbstractBlock#saveMetadata(DataOutputStream)} or need to provide - * backward compatibility to older binary formats, you prabably need to overwrite this. This - * default implementation checks if the version number fits, if not it throws an {@link - * MalformedModelException}. After that it restores the input shapes. - * - * @param loadVersion the version used for loading this metadata. - * @param is the input stream we are loading from - * @throws IOException loading failed - * @throws MalformedModelException data can be loaded but has wrong format - */ - protected void loadMetadata(byte loadVersion, DataInputStream is) - throws IOException, MalformedModelException { - if (loadVersion != version) { - throw new MalformedModelException( - "Cannot load parameters for " - + this.getClass().getCanonicalName() - + ", expected version " - + version - + ", got " - + loadVersion - + "."); - } - readInputShapes(is); - } - - protected void saveInputShapes(DataOutputStream os) throws IOException { - os.writeInt(inputShapes.length); - for (Shape shape : inputShapes) { - os.write(shape.getEncoded()); - } - } - - protected void readInputShapes(DataInputStream is) throws IOException { - int len = is.readInt(); - Shape[] shapes = new Shape[len]; - for (int i = 0; i < len; ++i) { - shapes[i] = Shape.decode(is); - } - if (inputShapes == null) { - // load inputShapes from parameter file if Block has not been initialized - inputShapes = shapes; - } - } - - /** {@inheritDoc} */ - @Override - public String toString() { - // FIXME: This is a quick hack for display in jupyter notebook. 
- StringBuilder sb = new StringBuilder(200); - String className = getClass().getSimpleName(); - if (className.endsWith("Block")) { - className = className.substring(0, className.length() - 5); - } - sb.append(className).append('('); - if (isInitialized()) { - PairList inputShapeDescription = describeInput(); - appendShape(sb, inputShapeDescription.values().toArray(new Shape[0])); - sb.append(" -> "); - Shape[] outputShapes = - getOutputShapes(inputShapeDescription.values().toArray(new Shape[0])); - appendShape(sb, outputShapes); - } else { - sb.append("Uninitialized"); - } - sb.append(')'); - return sb.toString(); - } - - private void appendShape(StringBuilder sb, Shape[] shapes) { - boolean first = true; - for (Shape shape : shapes) { - if (first) { - first = false; - } else { - sb.append(", "); - } - long[] sh = shape.getShape(); - int length = sh.length; - if (length == 0) { - sb.append("()"); - } else { - int index = 0; - if (sh[0] == -1) { - --length; - index = 1; - } - - if (length == 0) { - sb.append("()"); - } else if (length == 1) { - sb.append(sh[index]); - } else { - sb.append('('); - for (int i = index; i < sh.length; ++i) { - if (i > index) { - sb.append(", "); - } - sb.append(sh[i]); - } - sb.append(')'); - } - } - } - } } diff --git a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java index efcb4622276..9ffbd3f0083 100644 --- a/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java +++ b/api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java @@ -15,7 +15,7 @@ import ai.djl.ndarray.types.Shape; /** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */ -public abstract class AbstractSymbolBlock extends AbstractBlock implements SymbolBlock { +public abstract class AbstractSymbolBlock extends AbstractBaseBlock implements SymbolBlock { /** Constructs a new {@code AbstractSymbolBlock} instance. 
*/ public AbstractSymbolBlock() {} @@ -34,4 +34,10 @@ public AbstractSymbolBlock(byte version) { public Shape[] getOutputShapes(Shape[] inputShapes) { throw new UnsupportedOperationException("not implement!"); } + + /** {@inheritDoc} */ + @Override + public BlockList getChildren() { + return new BlockList(); + } } diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java index d46d87282a3..0e7dd93f59f 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java @@ -16,6 +16,7 @@ import ai.djl.dlr.jni.JniUtils; import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -75,4 +76,10 @@ public void close() { JniUtils.deleteDlrModel(pointer); } } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java index 860526ef0d3..8f0d1b51c39 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbSymbolBlock.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -105,6 +106,12 @@ void setTreeLimit(int treeLimit) { this.treeLimit = treeLimit; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet 
supported"); + } + /** The mode of inference for OptionMask. */ public enum Mode { DEFAULT(0), diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java index 4d6c5be9be0..27959169c68 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxSymbolBlock.java @@ -21,6 +21,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -31,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -53,6 +55,7 @@ public class MxSymbolBlock extends AbstractSymbolBlock { private CachedOp op; private Symbol symbol; private List mxNetParams; // includes input data + private Map parameters; private Map paramShapes; private Shape[] outputShapes; private PairList inputDescriptions; @@ -94,9 +97,10 @@ public void setInputNames(List inputNames) { // now that we know which of the parameters are just input placeholders and which // are trainable, add them properly so they are correctly handled Set nameLookup = new HashSet<>(inputNames); + parameters = new LinkedHashMap<>(mxNetParams.size()); for (Parameter mxNetParameter : mxNetParams) { if (!nameLookup.contains(mxNetParameter.getName())) { - addParameter(mxNetParameter); + parameters.put(mxNetParameter.getName(), mxNetParameter); } } } @@ -156,6 +160,12 @@ public PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + return new ParameterList(parameters); + } + /** {@inheritDoc} */ @Override public PairList describeOutput() { 
diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java index 3e2b9d76fc4..b9dad479b86 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtSymbolBlock.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -174,4 +175,10 @@ public void close() { } } } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java index 9fb7b5264af..cb75ce6d7cf 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpSymbolBlock.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.paddlepaddle.jni.JniUtils; import ai.djl.training.ParameterStore; @@ -73,6 +74,12 @@ private PpNDArray[] getInputs(PpNDManager sub, NDList inputs) { return inputArray; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + /** {@inheritDoc} */ @Override public Shape[] getOutputShapes(Shape[] inputShapes) 
{ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java index 0bde372bdf1..89463bef0ea 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtSymbolBlock.java @@ -18,6 +18,8 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.Parameter; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.pytorch.jni.IValue; import ai.djl.pytorch.jni.IValueUtils; @@ -27,6 +29,8 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +53,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable private PairList inputDescriptions; private PairList outputDescriptions; private boolean first; + private Map parameters; /** * Constructs a {@code PtSymbolBlock}. 
@@ -146,6 +151,43 @@ public PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + if (parameters == null) { + NDList params = JniUtils.moduleGetParams(this, manager); + parameters = new LinkedHashMap<>(params.size()); + for (NDArray param : params) { + parameters.put( + param.getName(), + Parameter.builder() + .setName(param.getName()) + .setType(inferType(param.getName())) + .optArray(param) + .build()); + } + } + // Defensive copy + return new ParameterList(parameters); + } + + private static Parameter.Type inferType(String name) { + if (name.contains("bias")) { + return Parameter.Type.BIAS; + } else if (name.contains("gamma")) { + return Parameter.Type.GAMMA; + } else if (name.contains("beta")) { + return Parameter.Type.BETA; + } else if (name.contains("moving_mean") || name.contains("running_mean")) { + return Parameter.Type.RUNNING_MEAN; + } else if (name.contains("moving_var") || name.contains("running_var")) { + return Parameter.Type.RUNNING_VAR; + } else if (name.contains("weight")) { + return Parameter.Type.WEIGHT; + } + return Parameter.Type.OTHER; + } + /** {@inheritDoc} */ @Override public PairList describeOutput() { diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java index 5213d2fed3a..f0819808a9c 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java @@ -19,6 +19,7 @@ import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.tensorflow.engine.javacpp.JavacppUtils; import ai.djl.training.ParameterStore; @@ -192,6 +193,12 @@ public final 
PairList describeInput() { return inputDescriptions; } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + /** {@inheritDoc} */ @Override public final PairList describeOutput() { diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java index c032515d61d..232a7fb8dbb 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtSymbolBlock.java @@ -15,6 +15,7 @@ import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.tensorrt.jni.JniUtils; import ai.djl.training.ParameterStore; @@ -66,4 +67,10 @@ TrtSession createSession(TrtNDManager manager) { long session = JniUtils.createSession(handle.get()); return new TrtSession(manager, handle.get(), session); } + + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java index 942bdb414c7..360801d72f8 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteSymbolBlock.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import ai.djl.nn.SymbolBlock; import ai.djl.training.ParameterStore; import ai.djl.util.PairList; @@ -65,4 +66,10 @@ protected NDList forwardInternal( public void close() { interpreter.close(); } + + /** {@inheritDoc} */ + @Override + 
public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java index c7ffec04a95..e450307a0b2 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java @@ -14,6 +14,7 @@ import ai.djl.fasttext.jni.FtWrapper; import ai.djl.nn.AbstractSymbolBlock; +import ai.djl.nn.ParameterList; import java.nio.file.Path; /** @@ -55,6 +56,12 @@ public float[] embedWord(String word) { return fta.getWordVector(word); } + /** {@inheritDoc} */ + @Override + public ParameterList getDirectParameters() { + throw new UnsupportedOperationException("Not yet supported"); + } + @Override public void close() { fta.unloadModel();