Skip to content

Commit

Permalink
Refactor ml extension support (#1521)
Browse files Browse the repository at this point in the history
* Refactor ml extension support

This makes a few changes, specifically to fasttext, but it can apply to other
DJL ml wrappers as well. First, it creates several passthrough utility classes
with the goal that ml models can be loaded through the model zoo and run through
the predictor.

Next, it modifies the base of the ml construct to better support engines that
have multiple applications. Each application now manifests as a type of
SymbolBlock rather than a model. Then, the single model class can run any of
them. The model is still created for each engine because it contains general
loading functionality that can determine which block should be used to load the
given target.

It also has to update the fasttext JNI to support this. First, it fixes the
modelType to actually return different results. Then, it modifies the
predictProba method to use an ArrayList instead of an Array. The main difference
is because it is possible to pass a topk of -1 in order to load all elements.
However, this doesn't work with the previous array setup.

* Add TrainFastText utility to aggregate links to training functions
  • Loading branch information
zachgk authored Mar 10, 2022
1 parent 09192df commit d9fb99f
Show file tree
Hide file tree
Showing 17 changed files with 762 additions and 107 deletions.
64 changes: 64 additions & 0 deletions api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrayAdapter;
import java.nio.ByteBuffer;

/**
* An {@link NDArray} that stores an arbitrary Java object.
*
* <p>This class is mainly for use in extensions and hybrid engines. Despite it's name, it will
* often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL
* predictor API.
*/
public class PassthroughNDArray extends NDArrayAdapter {

private Object object;

/**
* Constructs a {@link PassthroughNDArray} storing an object.
*
* @param object the object to store
*/
public PassthroughNDArray(Object object) {
super(null, null, null, null, null);
this.object = object;
}

/**
* Returns the object stored.
*
* @return the object stored
*/
public Object getObject() {
return object;
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
throw new UnsupportedOperationException("Operation not supported for FastText");
}

/** {@inheritDoc} */
@Override
public void intern(NDArray replaced) {
throw new UnsupportedOperationException("Operation not supported for FastText");
}

/** {@inheritDoc} */
@Override
public void detach() {}
}
209 changes: 209 additions & 0 deletions api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Path;

/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */
public final class PassthroughNDManager implements NDManager {

private static final String UNSUPPORTED = "Not supported by PassthroughNDManager";
public static final PassthroughNDManager INSTANCE = new PassthroughNDManager();

private PassthroughNDManager() {}

@Override
public Device defaultDevice() {
return Device.cpu();
}

@Override
public ByteBuffer allocateDirect(int capacity) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray from(NDArray array) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray create(String[] data, Charset charset, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray create(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDList load(Path path) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void setName(String name) {}

@Override
public String getName() {
return "PassthroughNDManager";
}

@Override
public NDArray zeros(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray ones(Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray full(Shape shape, float value, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray arange(float start, float stop, float step, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray eye(int rows, int cols, int k, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray linspace(float start, float stop, int num, boolean endpoint) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomMultinomial(int n, NDArray pValues) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public boolean isOpen() {
return true;
}

@Override
public NDManager getParentManager() {
return this;
}

@Override
public NDManager newSubManager() {
return this;
}

@Override
public NDManager newSubManager(Device device) {
return this;
}

@Override
public Device getDevice() {
return Device.cpu();
}

@Override
public void attachInternal(String resourceId, AutoCloseable resource) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void tempAttachInternal(
NDManager originalManager, String resourceId, NDResource resource) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void detachInternal(String resourceId) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public void invoke(
String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public NDList invoke(String operation, NDList src, PairList<String, ?> params) {
throw new UnsupportedOperationException(UNSUPPORTED);
}

@Override
public Engine getEngine() {
return null;
}

@Override
public void close() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.util.passthrough;

import ai.djl.ndarray.NDList;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

/**
* A translator that stores and removes data from a {@link PassthroughNDArray}.
*
* @param <I> translator input type
* @param <O> translator output type
*/
public class PassthroughTranslator<I, O> implements NoBatchifyTranslator<I, O> {

@Override
public NDList processInput(TranslatorContext ctx, I input) throws Exception {
return new NDList(new PassthroughNDArray(input));
}

@Override
@SuppressWarnings("unchecked")
public O processOutput(TranslatorContext ctx, NDList list) {
PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow();
return (O) wrapper.getObject();
}
}
15 changes: 15 additions & 0 deletions api/src/main/java/ai/djl/util/passthrough/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains passthrough DJL classes for use in extensions and hybrid engines. */
package ai.djl.util.passthrough;
6 changes: 4 additions & 2 deletions extensions/fasttext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
This module contains the NLP support with fastText implementation.

fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor.
The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html)
class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).
Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html).
This produces a special block which can perform inference on its own or by using a model and predictor.
Pre-trained FastText models can also be loaded by using the standard DJL criteria.
You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).

Current implementation has the following limitations:

Expand Down
Loading

0 comments on commit d9fb99f

Please sign in to comment.