Skip to content

Commit

Permalink
Separate AbstractSymbolBlock from AbstractBlock
Browse files Browse the repository at this point in the history
Right now, the AbstractSymbolBlock inherits from AbstractBlock. However, the
AbstractBlock parameters system and getParameters implies that it is always
possible to get the parameters for a block. While this should be true for a
correctly written DJL block, it is not always true to symbol blocks.

So, this change separates them to ensure that trying to get parameters when it
is not possible returns the exception reflecting that the operation is
unsupported. The shared functionality between AbstractSymbolBlock and
AbstractBlock was moved to a common base class, AbstractBaseBlock.
  • Loading branch information
zachgk committed Apr 7, 2022
1 parent 1d62d6d commit f3ddb85
Show file tree
Hide file tree
Showing 13 changed files with 536 additions and 375 deletions.
417 changes: 417 additions & 0 deletions api/src/main/java/ai/djl/nn/AbstractBaseBlock.java

Large diffs are not rendered by default.

378 changes: 5 additions & 373 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ai.djl.ndarray.types.Shape;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
public abstract class AbstractSymbolBlock extends AbstractBlock implements SymbolBlock {
public abstract class AbstractSymbolBlock extends AbstractBaseBlock implements SymbolBlock {

/** Constructs a new {@code AbstractSymbolBlock} instance. */
public AbstractSymbolBlock() {}
Expand All @@ -34,4 +34,10 @@ public AbstractSymbolBlock(byte version) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("not implement!");
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
return new BlockList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.dlr.jni.JniUtils;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -75,4 +76,10 @@ public void close() {
JniUtils.deleteDlrModel(pointer);
}
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -105,6 +106,12 @@ void setTreeLimit(int treeLimit) {
this.treeLimit = treeLimit;
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** The mode of inference for OptionMask. */
public enum Mode {
DEFAULT(0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand All @@ -31,6 +32,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -53,6 +55,7 @@ public class MxSymbolBlock extends AbstractSymbolBlock {
private CachedOp op;
private Symbol symbol;
private List<Parameter> mxNetParams; // includes input data
private Map<String, Parameter> parameters;
private Map<String, Shape> paramShapes;
private Shape[] outputShapes;
private PairList<String, Shape> inputDescriptions;
Expand Down Expand Up @@ -94,9 +97,10 @@ public void setInputNames(List<String> inputNames) {
// now that we know which of the parameters are just input placeholders and which
// are trainable, add them properly so they are correctly handled
Set<String> nameLookup = new HashSet<>(inputNames);
parameters = new LinkedHashMap<>(mxNetParams.size());
for (Parameter mxNetParameter : mxNetParams) {
if (!nameLookup.contains(mxNetParameter.getName())) {
addParameter(mxNetParameter);
parameters.put(mxNetParameter.getName(), mxNetParameter);
}
}
}
Expand Down Expand Up @@ -156,6 +160,11 @@ public PairList<String, Shape> describeInput() {
return inputDescriptions;
}

@Override
public ParameterList getDirectParameters() {
return new ParameterList(parameters);
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -174,4 +175,10 @@ public void close() {
}
}
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.paddlepaddle.jni.JniUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -73,6 +74,12 @@ private PpNDArray[] getInputs(PpNDManager sub, NDList inputs) {
return inputArray;
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.IValueUtils;
Expand All @@ -27,6 +29,8 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -49,6 +53,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable
private PairList<String, Shape> inputDescriptions;
private PairList<String, Shape> outputDescriptions;
private boolean first;
private Map<String, Parameter> parameters;

/**
* Constructs a {@code PtSymbolBlock}.
Expand Down Expand Up @@ -146,6 +151,42 @@ public PairList<String, Shape> describeInput() {
return inputDescriptions;
}

@Override
public ParameterList getDirectParameters() {
if (parameters == null) {
NDList params = JniUtils.moduleGetParams(this, manager);
parameters = new LinkedHashMap<>(params.size());
for (NDArray param : params) {
parameters.put(
param.getName(),
Parameter.builder()
.setName(param.getName())
.setType(inferType(param.getName()))
.optArray(param)
.build());
}
}
// Defensive copy
return new ParameterList(parameters);
}

private static Parameter.Type inferType(String name) {
if (name.contains("bias")) {
return Parameter.Type.BIAS;
} else if (name.contains("gamma")) {
return Parameter.Type.GAMMA;
} else if (name.contains("beta")) {
return Parameter.Type.BETA;
} else if (name.contains("moving_mean") || name.contains("running_mean")) {
return Parameter.Type.RUNNING_MEAN;
} else if (name.contains("moving_var") || name.contains("running_var")) {
return Parameter.Type.RUNNING_VAR;
} else if (name.contains("weight")) {
return Parameter.Type.WEIGHT;
}
return Parameter.Type.OTHER;
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -192,6 +193,12 @@ public final PairList<String, Shape> describeInput() {
return inputDescriptions;
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** {@inheritDoc} */
@Override
public final PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.tensorrt.jni.JniUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -66,4 +67,10 @@ TrtSession createSession(TrtNDManager manager) {
long session = JniUtils.createSession(handle.get());
return new TrtSession(manager, handle.get(), session);
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -65,4 +66,10 @@ protected NDList forwardInternal(
public void close() {
interpreter.close();
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.fasttext.jni.FtWrapper;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import java.nio.file.Path;

/**
Expand Down Expand Up @@ -55,6 +56,12 @@ public float[] embedWord(String word) {
return fta.getWordVector(word);
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

@Override
public void close() {
fta.unloadModel();
Expand Down

0 comments on commit f3ddb85

Please sign in to comment.