Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Fix JNI custom op code from deregistering the operator fixes #10438 #11885

Merged
merged 1 commit into from
Aug 21, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1898,31 +1898,28 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxRtcFree

// store the user defined CustomOpProp object reference with its name
std::unordered_map<std::string, jobject> globalOpPropMap;
// store how many time of the delete function was called
// for a specific CustomOpProp object
std::unordered_map<std::string, int> globalOpPropCountMap;
// store the user defined CustomOp object reference with its name
std::unordered_map<std::string, jobject> globalOpMap;
// used for thread safty when insert elements into
// or erase elements from the std::unordered_map
std::mutex mutex_opprop;
std::mutex mutex_op;

// Registers a custom operator when called
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
(JNIEnv *env, jobject obj, jstring jregName, jobject jopProp) {
const char *regName = env->GetStringUTFChars(jregName, 0);
std::string key(regName);

std::unique_lock<std::mutex> lock(mutex_opprop);
globalOpPropMap.insert({ key, env->NewGlobalRef(jopProp) });
globalOpPropCountMap.insert({ key, 0 });
lock.unlock();

// lambda function to initialize the operator and create all callbacks
auto creatorLambda = [](const char *opType, const int numKwargs,
const char **keys, const char **values, MXCallbackList *ret) {
int success = true;

// set CustomOpProp.kwargs
std::string opPropKey(opType);
if (globalOpPropMap.find(opPropKey) == globalOpPropMap.end()) {
LOG(WARNING) << "CustomOpProp: " << opPropKey << " not found";
Expand All @@ -1937,7 +1934,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister
LOG(WARNING) << "could not find CustomOpProp method init.";
success = false;
} else {
// call init
// call init and set CustomOpProp.kwargs
jclass strCls = env->FindClass("Ljava/lang/String;");
jobjectArray keysArr = env->NewObjectArray(numKwargs, strCls, NULL);
jobjectArray valuesArr = env->NewObjectArray(numKwargs, strCls, NULL);
Expand Down Expand Up @@ -2419,39 +2416,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxCustomOpRegister

// del callback
auto opPropDel = [](void *state) {
std::string key(reinterpret_cast<char *>(state));
std::unique_lock<std::mutex> lock(mutex_opprop);
int count_prop = globalOpPropCountMap.at(key);
if (count_prop < 2) {
globalOpPropCountMap[key] = ++count_prop;
return 1;
}
int success = true;
if (globalOpPropMap.find(key) == globalOpPropMap.end()) {
LOG(WARNING) << "opProp: " << key << " not found";
success = false;
} else {
JNIEnv *env;
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);
env->DeleteGlobalRef(globalOpPropMap.at(key));
_jvm->DetachCurrentThread();
for (auto it = globalOpPropMap.begin(); it != globalOpPropMap.end(); ) {
if (it->first == key) {
it = globalOpPropMap.erase(it);
} else {
++it;
}
}
for (auto it = globalOpPropCountMap.begin(); it != globalOpPropCountMap.end(); ) {
if (it->first == key) {
it = globalOpPropCountMap.erase(it);
} else {
++it;
}
}
}
lock.unlock();
return success;
/*
* This method seems to be called by the engine to clean up after multiple calls were made
* to the creator lambda. The current creator function isn't allocating a new object but is
* instead reinitializing the object which was created when register was called. This means
* that there doesn't seem to be anything to clean up here (previous efforts were actually
* deregistering the operator).
*/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you mentioned in here, have you observe any memory leaks in here if we got the counter part removed?

return 1;
};

// TODO(eric): Memory leak. Missing infertype.
Expand Down