-
Notifications
You must be signed in to change notification settings - Fork 654
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor ml extension support (#1521)
* Refactor ml extension support This makes a few changes, specifically to fasttext, but it can apply to other DJL ml wrappers as well. First, it creates several passthrough utility classes with the goal that ml models can be loaded through the model zoo and run through the predictor. Next, it modifies the base of the ml construct to better support engines that have multiple applications. Each application now manifests as a type of SymbolBlock rather than a model. Then, the single model class can run any of them. The model is still created for each engine because it contains general loading functionality that can determine which block should be used to load the given target. It also has to update the fasttext JNI to support this. First, it fixes the modelType to actually return different results. Then, it modifies the predictProba method to use an ArrayList instead of an Array. The main difference is because it is possible to pass a topk of -1 in order to load all elements. However, this doesn't work with the previous array setup. * Add TrainFastText utility to aggregate links to training functions
- Loading branch information
Showing
17 changed files
with
762 additions
and
107 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.util.passthrough; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDArrayAdapter; | ||
import java.nio.ByteBuffer; | ||
|
||
/** | ||
* An {@link NDArray} that stores an arbitrary Java object. | ||
* | ||
* <p>This class is mainly for use in extensions and hybrid engines. Despite it's name, it will | ||
* often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL | ||
* predictor API. | ||
*/ | ||
public class PassthroughNDArray extends NDArrayAdapter { | ||
|
||
private Object object; | ||
|
||
/** | ||
* Constructs a {@link PassthroughNDArray} storing an object. | ||
* | ||
* @param object the object to store | ||
*/ | ||
public PassthroughNDArray(Object object) { | ||
super(null, null, null, null, null); | ||
this.object = object; | ||
} | ||
|
||
/** | ||
* Returns the object stored. | ||
* | ||
* @return the object stored | ||
*/ | ||
public Object getObject() { | ||
return object; | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public ByteBuffer toByteBuffer() { | ||
throw new UnsupportedOperationException("Operation not supported for FastText"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void intern(NDArray replaced) { | ||
throw new UnsupportedOperationException("Operation not supported for FastText"); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public void detach() {} | ||
} |
209 changes: 209 additions & 0 deletions
209
api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.util.passthrough; | ||
|
||
import ai.djl.Device; | ||
import ai.djl.engine.Engine; | ||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
import ai.djl.ndarray.NDManager; | ||
import ai.djl.ndarray.NDResource; | ||
import ai.djl.ndarray.types.DataType; | ||
import ai.djl.ndarray.types.Shape; | ||
import ai.djl.util.PairList; | ||
import java.nio.Buffer; | ||
import java.nio.ByteBuffer; | ||
import java.nio.charset.Charset; | ||
import java.nio.file.Path; | ||
|
||
/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */ | ||
public final class PassthroughNDManager implements NDManager { | ||
|
||
private static final String UNSUPPORTED = "Not supported by PassthroughNDManager"; | ||
public static final PassthroughNDManager INSTANCE = new PassthroughNDManager(); | ||
|
||
private PassthroughNDManager() {} | ||
|
||
@Override | ||
public Device defaultDevice() { | ||
return Device.cpu(); | ||
} | ||
|
||
@Override | ||
public ByteBuffer allocateDirect(int capacity) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray from(NDArray array) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray create(String[] data, Charset charset, Shape shape) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray create(Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDList load(Path path) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public void setName(String name) {} | ||
|
||
@Override | ||
public String getName() { | ||
return "PassthroughNDManager"; | ||
} | ||
|
||
@Override | ||
public NDArray zeros(Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray ones(Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray full(Shape shape, float value, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray arange(float start, float stop, float step, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray eye(int rows, int cols, int k, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray linspace(float start, float stop, int num, boolean endpoint) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray randomMultinomial(int n, NDArray pValues) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public boolean isOpen() { | ||
return true; | ||
} | ||
|
||
@Override | ||
public NDManager getParentManager() { | ||
return this; | ||
} | ||
|
||
@Override | ||
public NDManager newSubManager() { | ||
return this; | ||
} | ||
|
||
@Override | ||
public NDManager newSubManager(Device device) { | ||
return this; | ||
} | ||
|
||
@Override | ||
public Device getDevice() { | ||
return Device.cpu(); | ||
} | ||
|
||
@Override | ||
public void attachInternal(String resourceId, AutoCloseable resource) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public void tempAttachInternal( | ||
NDManager originalManager, String resourceId, NDResource resource) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public void detachInternal(String resourceId) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public void invoke( | ||
String operation, NDArray[] src, NDArray[] dest, PairList<String, ?> params) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public NDList invoke(String operation, NDList src, PairList<String, ?> params) { | ||
throw new UnsupportedOperationException(UNSUPPORTED); | ||
} | ||
|
||
@Override | ||
public Engine getEngine() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void close() {} | ||
} |
38 changes: 38 additions & 0 deletions
38
api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.util.passthrough; | ||
|
||
import ai.djl.ndarray.NDList; | ||
import ai.djl.translate.NoBatchifyTranslator; | ||
import ai.djl.translate.TranslatorContext; | ||
|
||
/** | ||
* A translator that stores and removes data from a {@link PassthroughNDArray}. | ||
* | ||
* @param <I> translator input type | ||
* @param <O> translator output type | ||
*/ | ||
public class PassthroughTranslator<I, O> implements NoBatchifyTranslator<I, O> { | ||
|
||
@Override | ||
public NDList processInput(TranslatorContext ctx, I input) throws Exception { | ||
return new NDList(new PassthroughNDArray(input)); | ||
} | ||
|
||
@Override | ||
@SuppressWarnings("unchecked") | ||
public O processOutput(TranslatorContext ctx, NDList list) { | ||
PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow(); | ||
return (O) wrapper.getObject(); | ||
} | ||
} |
15 changes: 15 additions & 0 deletions
15
api/src/main/java/ai/djl/util/passthrough/package-info.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
|
||
/** Contains passthrough DJL classes for use in extensions and hybrid engines. */ | ||
package ai.djl.util.passthrough; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.