From 660d6841ecdc9bdaed0a95003948022562dbce16 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Wed, 4 Aug 2021 09:38:50 -0700 Subject: [PATCH] [Paddle] add remove pass option --- .../src/main/java/ai/djl/paddlepaddle/engine/PpModel.java | 6 ++++++ .../src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java | 4 ++++ .../main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java | 2 ++ .../ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc | 6 ++++++ 4 files changed, 18 insertions(+) diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java index feda8038bf8..7db41c4b9c9 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java @@ -77,6 +77,12 @@ public void load(Path modelPath, String prefix, Map options) throws I if (System.getenv().containsKey("PADDLE_ENABLE_MKLDNN")) { JniUtils.enableMKLDNN(config); } + if (options != null && options.containsKey("removePass")) { + String[] values = ((String) options.get("removePass")).split(","); + for (String value : values) { + JniUtils.removePass(config, value); + } + } paddlePredictor = new PaddlePredictor(JniUtils.createPredictor(config)); JniUtils.deleteConfig(config); setBlock(new PpSymbolBlock(paddlePredictor)); diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java index b1d1bd9a617..a899dfc5e96 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java @@ -84,6 +84,10 @@ public static void enableMKLDNN(long config) { PaddleLibrary.LIB.analysisConfigEnableMKLDNN(config); } + public static void removePass(long config, String pass) { + PaddleLibrary.LIB.analysisConfigRemovePass(config, pass); + } + public static void useFeedFetchOp(long config) { PaddleLibrary.LIB.useFeedFetchOp(config); } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java index 33af0079ca1..0aef846b574 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java @@ -46,6 +46,8 @@ private PaddleLibrary() {} native void analysisConfigEnableMKLDNN(long handle); + native void analysisConfigRemovePass(long handle, String pass); + native void useFeedFetchOp(long handle); native void deleteAnalysisConfig(long handle); diff --git a/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc b/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc index 477ec775ca7..c6931b216bd 100644 --- a/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc +++ b/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc @@ -57,6 +57,12 @@ JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfig config_ptr->EnableMKLDNN(); } +JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigRemovePass( + JNIEnv* env, jobject jthis, jlong jhandle, jstring jpass) { + auto* config_ptr = reinterpret_cast(jhandle); + config_ptr->pass_builder()->DeletePass(djl::utils::jni::GetStringFromJString(env, jpass)); +} + JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_deleteAnalysisConfig( JNIEnv* env, jobject jthis, jlong jhandle) { const auto* config_ptr = reinterpret_cast(jhandle);