Skip to content

Commit 0f91708

Browse files
andrewfayresanirudh2290
authored andcommitted
Fix JNI custom op code from deregistering the operator (apache#11885)
1 parent 8194f88 commit 0f91708

File tree

1 file changed

+11
-39
lines changed

1 file changed

+11
-39
lines changed

scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc

+11-39
Original file line numberDiff line numberDiff line change
@@ -1898,31 +1898,28 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcFree
18981898

18991899
// store the user defined CustomOpProp object reference with its name
19001900
std::unordered_map<std::string, jobject> globalOpPropMap;
1901-
// store how many time of the delete function was called
1902-
// for a specific CustomOpProp object
1903-
std::unordered_map<std::string, int> globalOpPropCountMap;
19041901
// store the user defined CustomOp object reference with its name
19051902
std::unordered_map<std::string, jobject> globalOpMap;
19061903
// used for thread safty when insert elements into
19071904
// or erase elements from the std::unordered_map
19081905
std::mutex mutex_opprop;
19091906
std::mutex mutex_op;
19101907

1908+
// Registers a custom operator when called
19111909
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
19121910
(JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {
19131911
const char *regName = env->GetStringUTFChars(jregName, 0);
19141912
std::string key(regName);
19151913

19161914
std::unique_lock<std::mutex> lock(mutex_opprop);
19171915
globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
1918-
globalOpPropCountMap.insert({ key, 0 });
19191916
lock.unlock();
19201917

1918+
// lambda function to initialize the operator and create all callbacks
19211919
auto creatorLambda = [](const char *opType, const int numKwargs,
19221920
const char **keys, const char **values, MXCallbackList *ret) {
19231921
int success = true;
19241922

1925-
// set CustomOpProp.kwargs
19261923
std::string opPropKey(opType);
19271924
if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
19281925
LOG(WARNING) << "CustomOpProp: " << opPropKey << " not found";
@@ -1937,7 +1934,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
19371934
LOG(WARNING) << "could not find CustomOpProp method init.";
19381935
success = false;
19391936
} else {
1940-
// call init
1937+
// call init and set CustomOpProp.kwargs
19411938
jclass strCls = env->FindClass("Ljava/lang/String;");
19421939
jobjectArray keysArr = env->NewObjectArray(numKwargs, strCls, NULL);
19431940
jobjectArray valuesArr = env->NewObjectArray(numKwargs, strCls, NULL);
@@ -2419,39 +2416,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
24192416

24202417
// del callback
24212418
auto opPropDel = [](void *state) {
2422-
std::string key(reinterpret_cast<char *>(state));
2423-
std::unique_lock<std::mutex> lock(mutex_opprop);
2424-
int count_prop = globalOpPropCountMap.at(key);
2425-
if (count_prop < 2) {
2426-
globalOpPropCountMap[key] = ++count_prop;
2427-
return 1;
2428-
}
2429-
int success = true;
2430-
if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
2431-
LOG(WARNING) << "opProp: " << key << " not found";
2432-
success = false;
2433-
} else {
2434-
JNIEnv *env;
2435-
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
2436-
env->DeleteGlobalRef(globalOpPropMap.at(key));
2437-
_jvm->DetachCurrentThread();
2438-
for (auto it = globalOpPropMap.begin(); it != globalOpPropMap.end(); ) {
2439-
if (it->first == key) {
2440-
it = globalOpPropMap.erase(it);
2441-
} else {
2442-
++it;
2443-
}
2444-
}
2445-
for (auto it = globalOpPropCountMap.begin(); it != globalOpPropCountMap.end(); ) {
2446-
if (it->first == key) {
2447-
it = globalOpPropCountMap.erase(it);
2448-
} else {
2449-
++it;
2450-
}
2451-
}
2452-
}
2453-
lock.unlock();
2454-
return success;
2419+
/*
2420+
* This method seems to be called by the engine to clean up after multiple calls were made
2421+
* to the creator lambda. The current creator function isn't allocating a new object but is
2422+
* instead reinitializing the object which was created when register was called. This means
2423+
* that there doesn't seem to be anything to clean up here (previous efforts were actually
2424+
* deregistering the operator).
2425+
*/
2426+
return 1;
24552427
};
24562428

24572429
// TODO(eric): Memory leak. Missing infertype.

0 commit comments

Comments
 (0)