Skip to content

Commit

Permalink
[Paddle] add remove pass option (#1141)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lanking authored Aug 4, 2021
1 parent 27cce9b commit ecc076b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
if (System.getenv().containsKey("PADDLE_ENABLE_MKLDNN")) {
JniUtils.enableMKLDNN(config);
}
if (options != null && options.containsKey("removePass")) {
String[] values = ((String) options.get("removePass")).split(",");
for (String value : values) {
JniUtils.removePass(config, value);
}
}
paddlePredictor = new PaddlePredictor(JniUtils.createPredictor(config));
JniUtils.deleteConfig(config);
setBlock(new PpSymbolBlock(paddlePredictor));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ public static void enableMKLDNN(long config) {
PaddleLibrary.LIB.analysisConfigEnableMKLDNN(config);
}

public static void removePass(long config, String pass) {
PaddleLibrary.LIB.analysisConfigRemovePass(config, pass);
}

public static void useFeedFetchOp(long config) {
PaddleLibrary.LIB.useFeedFetchOp(config);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ private PaddleLibrary() {}

native void analysisConfigEnableMKLDNN(long handle);

native void analysisConfigRemovePass(long handle, String pass);

native void useFeedFetchOp(long handle);

native void deleteAnalysisConfig(long handle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfig
config_ptr->EnableMKLDNN();
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_analysisConfigRemovePass(
JNIEnv* env, jobject jthis, jlong jhandle, jstring jpass) {
auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
config_ptr->pass_builder()->DeletePass(djl::utils::jni::GetStringFromJString(env, jpass));
}

JNIEXPORT void JNICALL Java_ai_djl_paddlepaddle_jni_PaddleLibrary_deleteAnalysisConfig(
JNIEnv* env, jobject jthis, jlong jhandle) {
const auto* config_ptr = reinterpret_cast<paddle::AnalysisConfig*>(jhandle);
Expand Down

0 comments on commit ecc076b

Please sign in to comment.