Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ml extension support #1521

Merged
merged 3 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import java.nio.ByteBuffer;

/**
* An {@link NDArray} that stores an arbitrary Java object.
*
* <p>This class is mainly for use in extensions and hybrid engines. Despite it's name, it will
* often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL
* predictor API.
*/
public class PassthroughNDArray extends NDArrayAdapter {

private Object object;

/**
* Constructs a {@link PassthroughNDArray} storing an object.
*
* @param object the object to store
*/
public PassthroughNDArray(Object object) {
super(null, null, null, null, null);
this.object = object;
}

/**
* Returns the object stored.
*
* @return the object stored
*/
public Object getObject() {
return object;
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
throw new UnsupportedOperationException("Operation not supported for FastText");
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
throw new UnsupportedOperationException("Operation not supported for FastText");
}

/** {@inheritDoc} */
@Override
public void detach() {}
}
209 changes: 209 additions & 0 deletions api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;

/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */
public final class PassthroughNDManager implements NDManager {

private static final String UNSUPPORTED = "Not supported by PassthroughNDManager";
public static final PassthroughNDManager INSTANCE = new PassthroughNDManager();

private PassthroughNDManager() {}

@Override
public Device defaultDevice() {
return Device.cpu();
}

@Override
public ByteBuffer allocateDirect(int capacity) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray from(NDArray array) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray create(String[] data, Charset charset, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray create(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDList load(Path path) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void setName(String name) {}

@Override
public String getName() {
return "PassthroughNDManager";
}

@Override
public NDArray zeros(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray ones(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray full(Shape shape, float value, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray arange(float start, float stop, float step, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray eye(int rows, int cols, int k, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray linspace(float start, float stop, int num, boolean endpoint) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomMultinomial(int n, NDArray pValues) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public boolean isOpen() {
return true;
}

@Override
public NDManager getParentManager() {
return this;
}

@Override
public NDManager newSubManager() {
return this;
}

@Override
public NDManager newSubManager(Device device) {
return this;
}

@Override
public Device getDevice() {
return Device.cpu();
}

@Override
public void attachInternal(String resourceId, AutoCloseable resource) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void tempAttachInternal(
NDManager originalManager, String resourceId, NDResource resource) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void detachInternal(String resourceId) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void invoke(
String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public Engine getEngine() {
return null;
}

@Override
public void close() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.ndarray.NDList;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

/**
* A translator that stores and removes data from a {@link PassthroughNDArray}.
*
* @param <I> translator input type
* @param <O> translator output type
*/
public class PassthroughTranslator<I, O> implements NoBatchifyTranslator<I, O> {

@Override
public NDList processInput(TranslatorContext ctx, I input) throws Exception {
return new NDList(new PassthroughNDArray(input));
}

@Override
@SuppressWarnings("unchecked")
public O processOutput(TranslatorContext ctx, NDList list) {
PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow();
return (O) wrapper.getObject();
}
}
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/util/passthrough/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains passthrough DJL classes for use in extensions and hybrid engines. */
package ai.djl.util.passthrough;
6 changes: 4 additions & 2 deletions extensions/fasttext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
This module contains the NLP support with fastText implementation.

fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor.
The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html)
class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).
Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html).
This produces a special block which can perform inference on its own or by using a model and predictor.
Pre-trained FastText models can also be loaded by using the standard DJL criteria.
You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).

Current implementation has the following limitations:

Expand Down
Loading