Skip to content

Commit

Permalink
Reafactor SymbolBlock with AbstractSymbolBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Jan 7, 2021
1 parent 21d8050 commit 4598dba
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 487 deletions.
135 changes: 135 additions & 0 deletions api/src/main/java/ai/djl/nn/AbstractSymbolBlock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.nn;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;

/** {@code AbstractSymbolBlock} is an abstract implementation of {@link SymbolBlock}. */
public abstract class AbstractSymbolBlock implements SymbolBlock, AutoCloseable {

/** {@inheritDoc} */
@Override
public abstract NDList forward(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params);

/** {@inheritDoc} */
@Override
public void removeLastBlock() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer, String paramName) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public boolean isInitialized() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void cast(DataType dataType) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void clear() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public ParameterList getParameters() {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void saveParameters(DataOutputStream os) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is) {
throw new UnsupportedOperationException("not supported");
}

/** {@inheritDoc} */
@Override
public abstract void close();
}
6 changes: 4 additions & 2 deletions api/src/main/java/ai/djl/util/NativeResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import java.util.concurrent.atomic.AtomicReference;

/**
* {@code NativeResource} is an internal class for {@link AutoCloseable} blocks of memory created in
* the different engines.
* {@code Resource} is an internal class for {@link AutoCloseable} blocks of memory created in the
* different engines.
*
* @param <T> the resource that could map to a native pointer or java object
*/
public abstract class NativeResource<T> implements AutoCloseable {

Expand Down
113 changes: 6 additions & 107 deletions dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrSymbolBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@
import ai.djl.dlr.jni.JniUtils;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.concurrent.atomic.AtomicReference;

/**
* {@code DlrSymbolBlock} is the DLR implementation of {@link SymbolBlock}.
*
* <p>You can create a {@code DlrSymbolBlock} using {@link ai.djl.Model#load(java.nio.file.Path,
* String)}.
*/
public class DlrSymbolBlock extends NativeResource<Long> implements SymbolBlock {
public class DlrSymbolBlock extends AbstractSymbolBlock {

private AtomicReference<Long> handle;
/**
* Constructs a {@code DlrSymbolBlock}.
*
Expand All @@ -45,13 +40,7 @@ public class DlrSymbolBlock extends NativeResource<Long> implements SymbolBlock
* @param handle the handle for native DLR model
*/
public DlrSymbolBlock(long handle) {
super(handle);
}

/** {@inheritDoc} */
@Override
public void removeLastBlock() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
this.handle = new AtomicReference<>(handle);
}

/** {@inheritDoc} */
Expand All @@ -61,7 +50,7 @@ public NDList forward(
NDList inputs,
boolean training,
PairList<String, Object> params) {
long modelHandle = getHandle();
long modelHandle = handle.get();
NDManager manager = inputs.head().getManager();
// TODO maybe verify the number of inputs
// currently we assume the order of the input NDList is the same
Expand All @@ -73,96 +62,6 @@ public NDList forward(
return JniUtils.getDlrOutputs(modelHandle, manager);
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void setInitializer(Initializer initializer, String paramName) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public boolean isInitialized() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void cast(DataType dataType) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void clear() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeOutput() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public BlockList getChildren() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public ParameterList getDirectParameters() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public ParameterList getParameters() {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape getParameterShape(String name, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void saveParameters(DataOutputStream os) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void loadParameters(NDManager manager, DataInputStream is) {
throw new UnsupportedOperationException("not supported for DlrSymbolBlock");
}

/** {@inheritDoc} */
@Override
public void close() {
Expand Down
Loading

0 comments on commit 4598dba

Please sign in to comment.