Skip to content

Commit

Permalink
Avoid load default engine if not being used (#1136)
Browse files Browse the repository at this point in the history
Change-Id: I9c14b470306d9c1c7226c50952e439e2842ad01c
  • Loading branch information
frankfliu authored Aug 2, 2021
1 parent 14bf240 commit 8183faf
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 7 deletions.
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ private static synchronized String initEngine() {
*/
public abstract int getRank();

/**
* Returns the default Engine name.
*
* @return the default Engine name
*/
public static String getDefaultEngineName() {
return DEFAULT_ENGINE;
}

/**
* Returns the default Engine.
*
Expand Down
3 changes: 1 addition & 2 deletions api/src/main/java/ai/djl/modality/cv/util/NDImageUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
*/
package ai.djl.modality.cv.util;

import ai.djl.engine.Engine;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
Expand Down Expand Up @@ -109,7 +108,7 @@ public static NDArray normalize(NDArray input, float mean, float std) {
*/
public static NDArray normalize(NDArray input, float[] mean, float[] std) {
boolean chw = isCHW(input.getShape());
boolean tf = "TensorFlow".equals(Engine.getInstance().getEngineName());
boolean tf = "TensorFlow".equals(input.getManager().getEngine().getEngineName());
if ((chw && tf) || (!chw && !tf)) {
throw new IllegalArgumentException(
"normalize requires CHW format. TensorFlow requires HWC");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
// Otherwise check the modelzoo supported engine and grab a random engine in the list.
// Otherwise if none of them is specified or model zoo is null, go to default engine.
if (engine == null && modelZoo != null) {
String defaultEngine = Engine.getInstance().getEngineName();
String defaultEngine = Engine.getDefaultEngineName();
for (String supportedEngine : modelZoo.getSupportedEngines()) {
if (supportedEngine.equals(defaultEngine)) {
engine = supportedEngine;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,9 @@ public void onTrainingBegin(Trainer trainer) {
logger.info("Training on: {}.", devicesMsg);

long init = System.nanoTime();
String engineName = Engine.getInstance().getEngineName();
String version = Engine.getInstance().getVersion();
Engine engine = trainer.getManager().getEngine();
String engineName = engine.getEngineName();
String version = engine.getVersion();
long loaded = System.nanoTime();
logger.info(
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public class Arguments {
if (cmd.hasOption("engine")) {
engine = cmd.getOptionValue("engine");
} else {
engine = Engine.getInstance().getEngineName();
engine = Engine.getDefaultEngineName();
}

if (cmd.hasOption("duration")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ public String getGroupId() {
/** {@inheritDoc} */
@Override
public Set<String> getSupportedEngines() {
return Collections.singleton(Engine.getInstance().getEngineName());
return Collections.singleton(Engine.getDefaultEngineName());
}
}

0 comments on commit 8183faf

Please sign in to comment.