Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate AbstractSymbolBlock from AbstractBlock #1555

Merged
merged 2 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
417 changes: 417 additions & 0 deletions api/src/main/java/ai/djl/nn/AbstractBaseBlock.java

Large diffs are not rendered by default.

378 changes: 5 additions & 373 deletions api/src/main/java/ai/djl/nn/AbstractBlock.java

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ai.djl.ndarray.types.Shape;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
public abstract class AbstractSymbolBlock extends AbstractBlock implements SymbolBlock {
public abstract class AbstractSymbolBlock extends AbstractBaseBlock implements SymbolBlock {

/** Constructs a new {@code AbstractSymbolBlock} instance. */
public AbstractSymbolBlock() {}
Expand All @@ -34,4 +34,10 @@ public AbstractSymbolBlock(byte version) {
public Shape[] getOutputShapes(Shape[] inputShapes) {
// This abstract base provides no output-shape computation, so the call always throws.
throw new UnsupportedOperationException("not implement!");
}

/**
 * {@inheritDoc}
 *
 * <p>This implementation always returns a new, empty {@link BlockList}.
 */
@Override
public BlockList getChildren() {
return new BlockList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.dlr.jni.JniUtils;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -75,4 +76,10 @@ public void close() {
JniUtils.deleteDlrModel(pointer);
}
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -105,6 +106,12 @@ void setTreeLimit(int treeLimit) {
this.treeLimit = treeLimit;
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** The mode of inference for OptionMask. */
public enum Mode {
DEFAULT(0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand All @@ -31,6 +32,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -53,6 +55,7 @@ public class MxSymbolBlock extends AbstractSymbolBlock {
private CachedOp op;
private Symbol symbol;
private List<Parameter> mxNetParams; // includes input data
private Map<String, Parameter> parameters;
private Map<String, Shape> paramShapes;
private Shape[] outputShapes;
private PairList<String, Shape> inputDescriptions;
Expand Down Expand Up @@ -94,9 +97,10 @@ public void setInputNames(List<String> inputNames) {
// now that we know which of the parameters are just input placeholders and which
// are trainable, add them properly so they are correctly handled
Set<String> nameLookup = new HashSet<>(inputNames);
parameters = new LinkedHashMap<>(mxNetParams.size());
for (Parameter mxNetParameter : mxNetParams) {
if (!nameLookup.contains(mxNetParameter.getName())) {
addParameter(mxNetParameter);
parameters.put(mxNetParameter.getName(), mxNetParameter);
}
}
}
Expand Down Expand Up @@ -156,6 +160,11 @@ public PairList<String, Shape> describeInput() {
return inputDescriptions;
}

@Override
zachgk marked this conversation as resolved.
Show resolved Hide resolved
public ParameterList getDirectParameters() {
    // `parameters` is only assigned by setInputNames(); before that call the block cannot
    // tell trainable parameters from input placeholders, so report an empty list rather
    // than handing a null map to the ParameterList constructor.
    if (parameters == null) {
        return new ParameterList();
    }
    // Wrap the map in a new ParameterList on every call.
    return new ParameterList(parameters);
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -174,4 +175,10 @@ public void close() {
}
}
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.paddlepaddle.jni.JniUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -73,6 +74,12 @@ private PpNDArray[] getInputs(PpNDManager sub, NDList inputs) {
return inputArray;
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.IValueUtils;
Expand All @@ -27,6 +29,8 @@
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -49,6 +53,7 @@ public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable
private PairList<String, Shape> inputDescriptions;
private PairList<String, Shape> outputDescriptions;
private boolean first;
private Map<String, Parameter> parameters;

/**
* Constructs a {@code PtSymbolBlock}.
Expand Down Expand Up @@ -146,6 +151,42 @@ public PairList<String, Shape> describeInput() {
return inputDescriptions;
}

@Override
zachgk marked this conversation as resolved.
Show resolved Hide resolved
public ParameterList getDirectParameters() {
    if (parameters == null) {
        // First access: fetch the module's parameter arrays and wrap each one as a Parameter.
        NDList params = JniUtils.moduleGetParams(this, manager);
        // Build into a local map and publish it only once it is complete. If creating a
        // Parameter fails part-way, the cache stays null and the next call retries instead
        // of silently returning a truncated list.
        Map<String, Parameter> loaded = new LinkedHashMap<>(params.size());
        for (NDArray param : params) {
            String name = param.getName();
            Parameter parameter =
                    Parameter.builder()
                            .setName(name)
                            .setType(inferType(name))
                            .optArray(param)
                            .build();
            loaded.put(name, parameter);
        }
        parameters = loaded;
    }
    // Return a new ParameterList rather than the cached map itself.
    return new ParameterList(parameters);
}

/**
 * Guesses the {@link Parameter.Type} of a parameter from substrings of its name.
 *
 * <p>Checks run in a fixed order and the first match wins, so a name containing both
 * {@code "bias"} and {@code "weight"} is classified as {@code BIAS}.
 *
 * @param name the parameter name to inspect
 * @return the first matching type, or {@code OTHER} when no known substring is present
 */
private static Parameter.Type inferType(String name) {
    if (name.contains("bias")) {
        return Parameter.Type.BIAS;
    }
    if (name.contains("gamma")) {
        return Parameter.Type.GAMMA;
    }
    if (name.contains("beta")) {
        return Parameter.Type.BETA;
    }
    // Both the "moving_*" and "running_*" spellings map to the running statistics types.
    boolean looksLikeMean = name.contains("moving_mean") || name.contains("running_mean");
    if (looksLikeMean) {
        return Parameter.Type.RUNNING_MEAN;
    }
    boolean looksLikeVar = name.contains("moving_var") || name.contains("running_var");
    if (looksLikeVar) {
        return Parameter.Type.RUNNING_VAR;
    }
    return name.contains("weight") ? Parameter.Type.WEIGHT : Parameter.Type.OTHER;
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -192,6 +193,12 @@ public final PairList<String, Shape> describeInput() {
return inputDescriptions;
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

/** {@inheritDoc} */
@Override
public final PairList<String, Shape> describeOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.tensorrt.jni.JniUtils;
import ai.djl.training.ParameterStore;
Expand Down Expand Up @@ -66,4 +67,10 @@ TrtSession createSession(TrtNDManager manager) {
long session = JniUtils.createSession(handle.get());
return new TrtSession(manager, handle.get(), session);
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
Expand Down Expand Up @@ -65,4 +66,10 @@ protected NDList forwardInternal(
public void close() {
interpreter.close();
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.fasttext.jni.FtWrapper;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import java.nio.file.Path;

/**
Expand Down Expand Up @@ -55,6 +56,12 @@ public float[] embedWord(String word) {
return fta.getWordVector(word);
}

/**
 * {@inheritDoc}
 *
 * @throws UnsupportedOperationException always; this block does not expose its parameters yet
 */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("Not yet supported");
}

@Override
public void close() {
fta.unloadModel();
Expand Down