From 949a611ce7a6dd14c0e16d9305d997a1e2552a2d Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Wed, 23 Dec 2015 22:28:33 +0800 Subject: [PATCH] add IO jni code --- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 11 +- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 2 +- .../native/src/main/native/jni_helper_func.h | 5 + .../main/native/ml_dmlc_mxnet_native_c_api.cc | 136 +++++++++++++++++- 4 files changed, 148 insertions(+), 6 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 014c6e7d4731..cb175c39801a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -6,7 +6,8 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer object IO { - type IterCreateFunc = (Array[String], Array[String])=>DataIter + private val logger = LoggerFactory.getLogger(classOf[DataIter]) + type IterCreateFunc = (Map[String, String])=>DataIter private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() def _initIOModule(): Map[String, IterCreateFunc] = { @@ -24,14 +25,16 @@ object IO { checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs)) val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs) val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" + logger.debug(docStr) return (name.value, creator(handle)) } def creator(handle:DataIterCreator)( - keys: Array[String], - values: Array[String]): DataIter = { + params: Map[String, String]): DataIter = { val out = new DataIterHandle - checkCall(_LIB.mxDateIterCreateIter(handle, keys, values, out)) + val keys = params.keys.toArray + val vals = params.values.toArray + checkCall(_LIB.mxDateIterCreateIter(handle, keys, vals, out)) return new MXDataIter(out) } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 35afb27ec315..34d5ff063a10 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -6,6 +6,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} class IOSuite extends FunSuite with BeforeAndAfterAll { test("create iter funcs") { val iterCreateFuncs: Map[String, IO.IterCreateFunc] = IO._initIOModule() - println(iterCreateFuncs) + println(iterCreateFuncs.keys.toList) } } diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index 0d651adf9a94..8354f417ed35 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -21,4 +21,9 @@ void setIntField(JNIEnv *env, jobject obj, jint value) { env->SetIntField(obj, refFid, value); } +void setLongField(JNIEnv *env, jobject obj, jlong value) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refFid = env->GetFieldID(refClass, "value", "J"); + env->SetLongField(obj, refFid, value); +} #endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 0e1a704ef45c..20b76f1238be 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -391,7 +391,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jo //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters - (JNIEnv * env, jobject obj, jobjectArray creators) { + (JNIEnv * env, jobject obj, jobject creators) { // Base.FunctionHandle.constructor jclass chClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jmethodID chConstructor = env->GetMethodID(chClass,"","(J)V"); @@ -413,3 +413,137 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDateIterCreateIter + (JNIEnv * env, jobject obj, jobject creator, + jobjectArray jkeys, jobjectArray jvals, jobject dataIterHandle) { + //keys and values + int paramSize = env->GetArrayLength(jkeys); + char** keys = new char*[paramSize]; + char** vals = new char*[paramSize]; + jstring jkey, jval; + for(int i=0; iGetObjectArrayElement(jkeys, i); + keys[i] = (char*)env->GetStringUTFChars(jkey, 0); + jval = (jstring) env->GetObjectArrayElement(jvals, i); + vals[i] = (char*)env->GetStringUTFChars(jval, 0); + } + + //create iter + jlong creatorPtr = getLongField(env, creator); + DataIterHandle out; + int ret = MXDataIterCreateIter((DataIterCreator)creator, + (mx_uint) paramSize, + (const char**) keys, + (const char**) vals, + &out); + jclass hClass = env->GetObjectClass(dataIterHandle); + jfieldID ptr = env->GetFieldID(hClass, "value", "J"); + env->SetLongField(dataIterHandle, ptr, (long)out); + + //release const char* + for(int i=0; iGetObjectArrayElement(jkeys, i); + env->ReleaseStringUTFChars(jkey,(const char*)keys[i]); + jval = (jstring) env->GetObjectArrayElement(jvals, i); + env->ReleaseStringUTFChars(jval,(const char*)vals[i]); + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIterInfo + (JNIEnv * env, jobject obj, jobject creator, jobject jname, + jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) { + jlong creatorPtr = getLongField(env, creator); + const char* name; + const char* description; + mx_uint numArgs; + const char** argNames; + const char** argTypeInfos; + const char** argDescs; + int ret = MXDataIterGetIterInfo((DataIterCreator)creatorPtr, + &name, + &description, + &numArgs, + &argNames, + &argTypeInfos, + &argDescs); + + jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + //set params + env->SetObjectField(jname, valueStr, env->NewStringUTF(name)); + env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description)); + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + for(int i=0; iCallObjectMethod(jargNames, listAppend, env->NewStringUTF(argNames[i])); + env->CallObjectMethod(jargTypeInfos, listAppend, env->NewStringUTF(argTypeInfos[i])); + env->CallObjectMethod(jargDescs, listAppend, env->NewStringUTF(argDescs[i])); + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterFree + (JNIEnv *env, jobject obj, jobject handle) { + jlong handlePtr = getLongField(env, handle); + int ret = MXDataIterFree((DataIterHandle) handlePtr); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterBeforeFirst + (JNIEnv *env, jobject obj, jobject handle) { + jlong handlePtr = getLongField(env, handle); + int ret = MXDataIterBeforeFirst((DataIterHandle) handlePtr); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterNext + (JNIEnv *env, jobject obj, jobject handle, jobject out) { + jlong handlePtr = getLongField(env, handle); + int cout; + int ret = MXDataIterNext((DataIterHandle)handlePtr, &cout); + setIntField(env, out, cout); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetLabel + (JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) { + jlong handlePtr = getLongField(env, handle); + NDArrayHandle out; + int ret = MXDataIterGetLabel((DataIterHandle)handlePtr, &out); + jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); + env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData + (JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) { + jlong handlePtr = getLongField(env, handle); + NDArrayHandle out; + int ret = MXDataIterGetData((DataIterHandle)handlePtr, &out); + jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); + env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex + (JNIEnv *env, jobject obj, jobject handle, jobject outIndex, jobject outSize) { + jlong handlePtr = getLongField(env, handle); + uint64_t* coutIndex; + uint64_t coutSize; + int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize); + //to do + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum + (JNIEnv *env, jobject obj, jobject handle, jobject pad) { + jlong handlePtr = getLongField(env, handle); + int cpad; + int ret = MXDataIterGetPadNum((DataIterHandle)handlePtr, &cpad); + setIntField(env, pad, cpad); + return ret; +} \ No newline at end of file