diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java index feda8038bf8..7db41c4b9c9 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpModel.java @@ -77,6 +77,12 @@ public void load(Path modelPath, String prefix, Map options) throws I if (System.getenv().containsKey("PADDLE_ENABLE_MKLDNN")) { JniUtils.enableMKLDNN(config); } + if (options != null && options.containsKey("removePass")) { + String[] values = ((String) options.get("removePass")).split(","); + for (String value : values) { + JniUtils.removePass(config, value); + } + } paddlePredictor = new PaddlePredictor(JniUtils.createPredictor(config)); JniUtils.deleteConfig(config); setBlock(new PpSymbolBlock(paddlePredictor)); diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java index b1d1bd9a617..a899dfc5e96 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java @@ -84,6 +84,10 @@ public static void enableMKLDNN(long config) { PaddleLibrary.LIB.analysisConfigEnableMKLDNN(config); } + public static void removePass(long config, String pass) { + PaddleLibrary.LIB.analysisConfigRemovePass(config, pass); + } + public static void useFeedFetchOp(long config) { PaddleLibrary.LIB.useFeedFetchOp(config); } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java index 33af0079ca1..0aef846b574 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/PaddleLibrary.java @@ -46,6 +46,8 @@ private PaddleLibrary() {} native void analysisConfigEnableMKLDNN(long handle); + native void analysisConfigRemovePass(long handle, String pass); + native void useFeedFetchOp(long handle); native void deleteAnalysisConfig(long handle); diff --git a/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc b/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc index 477ec775ca7..c6931b216bd 100644 --- a/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc +++ b/paddlepaddle/paddlepaddle-native/src/main/native/ai_djl_paddlepaddle_jni_PaddleLibrary_inference.cc @@ -57,6 +57,12 @@ JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfig config_ptr->EnableMKLDNN(); } +JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigRemovePass( + JNIEnv* env, jobject jthis, jlong jhandle, jstring jpass) { + auto* config_ptr = reinterpret_cast(jhandle); + config_ptr->pass_builder()->DeletePass(djl::utils::jni::GetStringFromJString(env, jpass)); +} + JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_deleteAnalysisConfig( JNIEnv* env, jobject jthis, jlong jhandle) { const auto* config_ptr = reinterpret_cast(jhandle);