Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reafactor SymbolBlock with AbstractSymbolBlock #491

Merged
merged 1 commit into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
public abstract class AbstractSymbolBlock extends AbstractBlock implements SymbolBlock {

/**
* Builds an empty block with the given version for parameter serialization.
*
* @param version the version to use for parameter serialization.
*/
public AbstractSymbolBlock(byte version) {
super(version);
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
throw new UnsupportedOperationException("not implement!");
}
}
8 changes: 6 additions & 2 deletions api/src/main/java/ai/djl/nn/SymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
public interface SymbolBlock extends Block {

/** Removes the last block in the symbolic graph. */
void removeLastBlock();
default void removeLastBlock() {
throw new UnsupportedOperationException("not supported");
}

/**
* Returns a {@link PairList} of output names and shapes stored in model file.
*
* @return the {@link PairList} of output names, and shapes
*/
PairList<String, Shape> describeOutput();
default PairList<String, Shape> describeOutput() {
throw new UnsupportedOperationException("not supported");
}
}
2 changes: 2 additions & 0 deletions api/src/main/java/ai/djl/util/NativeResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
/**
* {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in
* the different engines.
*
* @param <T> the resource that could map to a native pointer or java object
*/
public abstract class NativeResource<T> implements AutoCloseable {

Expand Down
115 changes: 8 additions & 107 deletions dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,22 @@
import ai.djl.dlr.jni.JniUtils;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.concurrent.atomic.AtomicReference;

/**
* {@code DlrSymbolBlock} is the DLR implementation of {@link SymbolBlock}.
*
* <p>You can create a {@code DlrSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path,
* String)}.
*/
public class DlrSymbolBlock extends NativeResource<Long> implements SymbolBlock {
public class DlrSymbolBlock extends AbstractSymbolBlock implements AutoCloseable {

private static final byte VERSION = 1;
private AtomicReference<Long> handle;
/**
* Constructs a {@code DlrSymbolBlock}.
*
Expand All @@ -45,13 +41,8 @@ public class DlrSymbolBlock extends NativeResource<Long> implements SymbolBlock
* @param handle the handle for native DLR model
*/
public DlrSymbolBlock(long handle) {
super(handle);
}

/** {@inheritDoc} */
@Override
public void removeLastBlock() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
super(VERSION);
this.handle = new AtomicReference<>(handle);
}

/** {@inheritDoc} */
Expand All @@ -61,7 +52,7 @@ public NDList forward(
NDList inputs,
boolean training,
PairList<String, Object> params) {
long modelHandle = getHandle();
long modelHandle = handle.get();
NDManager manager = inputs.head().getManager();
// TODO maybe verify the number of inputs
// currently we assume the order of the input NDList is the same
Expand All @@ -73,96 +64,6 @@ public NDList forward(
return JniUtils.getDlrOutputs(modelHandle, manager);
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer, String paramName) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public boolean isInitialized() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void cast(DataType dataType) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void clear() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public ParameterList getParameters() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void saveParameters(DataOutputStream os) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.SymbolBlock;
Expand All @@ -43,7 +43,7 @@
* <p>You can create a {@code MxSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path,
* String)}.
*/
public class MxSymbolBlock extends AbstractBlock implements SymbolBlock {
public class MxSymbolBlock extends AbstractSymbolBlock {

private static final byte VERSION = 2;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxSequence;
Expand All @@ -33,8 +31,6 @@
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.SequenceInfo;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -47,10 +43,10 @@
* <p>You can create a {@code OrtSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path,
* String)}.
*/
public class OrtSymbolBlock implements SymbolBlock, AutoCloseable {
public class OrtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable {

private static final byte VERSION = 1;
private OrtSession session;

/**
* Constructs a {@code OrtSymbolBlock}.
*
Expand All @@ -60,6 +56,7 @@ public class OrtSymbolBlock implements SymbolBlock, AutoCloseable {
* @param session the {@link OrtSession} contains the model information
*/
public OrtSymbolBlock(OrtSession session) {
super(VERSION);
this.session = session;
}

Expand Down Expand Up @@ -170,96 +167,6 @@ private NDArray seq2Nd(OnnxSequence seq, NDManager manager) {
}
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer, String paramName) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public boolean isInitialized() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void cast(DataType dataType) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void clear() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public ParameterList getParameters() {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void saveParameters(DataOutputStream os) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is) {
throw new UnsupportedOperationException("ONNX Runtime not supported");
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
Loading