Skip to content

Commit

Permalink
add IO jni code
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqingmen committed Dec 23, 2015
1 parent f9b68fd commit 949a611
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 6 deletions.
11 changes: 7 additions & 4 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import org.slf4j.LoggerFactory
import scala.collection.mutable.ListBuffer

object IO {
type IterCreateFunc = (Array[String], Array[String])=>DataIter
private val logger = LoggerFactory.getLogger(classOf[DataIter])
type IterCreateFunc = (Map[String, String])=>DataIter
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()

def _initIOModule(): Map[String, IterCreateFunc] = {
Expand All @@ -24,14 +25,16 @@ object IO {
checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs))
val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs)
val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n"
logger.debug(docStr)
return (name.value, creator(handle))
}

def creator(handle:DataIterCreator)(
keys: Array[String],
values: Array[String]): DataIter = {
params: Map[String, String]): DataIter = {
val out = new DataIterHandle
checkCall(_LIB.mxDateIterCreateIter(handle, keys, values, out))
val keys = params.keys.toArray
val vals = params.values.toArray
checkCall(_LIB.mxDateIterCreateIter(handle, keys, vals, out))
return new MXDataIter(out)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
class IOSuite extends FunSuite with BeforeAndAfterAll {
test("create iter funcs") {
val iterCreateFuncs: Map[String, IO.IterCreateFunc] = IO._initIOModule()
println(iterCreateFuncs)
println(iterCreateFuncs.keys.toList)
}
}
5 changes: 5 additions & 0 deletions scala-package/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,9 @@ void setIntField(JNIEnv *env, jobject obj, jint value) {
env->SetIntField(obj, refFid, value);
}

void setLongField(JNIEnv *env, jobject obj, jlong value) {
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
env->SetLongField(obj, refFid, value);
}
#endif
136 changes: 135 additions & 1 deletion scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jo

//IO funcs
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
(JNIEnv * env, jobject obj, jobjectArray creators) {
(JNIEnv * env, jobject obj, jobject creators) {
// Base.FunctionHandle.constructor
jclass chClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jmethodID chConstructor = env->GetMethodID(chClass,"<init>","(J)V");
Expand All @@ -413,3 +413,137 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDateIterCreateIter
(JNIEnv * env, jobject obj, jobject creator,
jobjectArray jkeys, jobjectArray jvals, jobject dataIterHandle) {
//keys and values
int paramSize = env->GetArrayLength(jkeys);
char** keys = new char*[paramSize];
char** vals = new char*[paramSize];
jstring jkey, jval;
for(int i=0; i<paramSize; i++) {
jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
keys[i] = (char*)env->GetStringUTFChars(jkey, 0);
jval = (jstring) env->GetObjectArrayElement(jvals, i);
vals[i] = (char*)env->GetStringUTFChars(jval, 0);
}

//create iter
jlong creatorPtr = getLongField(env, creator);
DataIterHandle out;
int ret = MXDataIterCreateIter((DataIterCreator)creator,
(mx_uint) paramSize,
(const char**) keys,
(const char**) vals,
&out);
jclass hClass = env->GetObjectClass(dataIterHandle);
jfieldID ptr = env->GetFieldID(hClass, "value", "J");
env->SetLongField(dataIterHandle, ptr, (long)out);

//release const char*
for(int i=0; i<paramSize; i++) {
jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
env->ReleaseStringUTFChars(jkey,(const char*)keys[i]);
jval = (jstring) env->GetObjectArrayElement(jvals, i);
env->ReleaseStringUTFChars(jval,(const char*)vals[i]);
}
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIterInfo
(JNIEnv * env, jobject obj, jobject creator, jobject jname,
jobject jdesc, jobject jargNames, jobject jargTypeInfos, jobject jargDescs) {
jlong creatorPtr = getLongField(env, creator);
const char* name;
const char* description;
mx_uint numArgs;
const char** argNames;
const char** argTypeInfos;
const char** argDescs;
int ret = MXDataIterGetIterInfo((DataIterCreator)creatorPtr,
&name,
&description,
&numArgs,
&argNames,
&argTypeInfos,
&argDescs);

jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString");
jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
//set params
env->SetObjectField(jname, valueStr, env->NewStringUTF(name));
env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description));
jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer");
jmethodID listAppend = env->GetMethodID(listClass,
"$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");
for(int i=0; i<numArgs; i++) {
env->CallObjectMethod(jargNames, listAppend, env->NewStringUTF(argNames[i]));
env->CallObjectMethod(jargTypeInfos, listAppend, env->NewStringUTF(argTypeInfos[i]));
env->CallObjectMethod(jargDescs, listAppend, env->NewStringUTF(argDescs[i]));
}
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterFree
(JNIEnv *env, jobject obj, jobject handle) {
jlong handlePtr = getLongField(env, handle);
int ret = MXDataIterFree((DataIterHandle) handlePtr);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterBeforeFirst
(JNIEnv *env, jobject obj, jobject handle) {
jlong handlePtr = getLongField(env, handle);
int ret = MXDataIterBeforeFirst((DataIterHandle) handlePtr);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterNext
(JNIEnv *env, jobject obj, jobject handle, jobject out) {
jlong handlePtr = getLongField(env, handle);
int cout;
int ret = MXDataIterNext((DataIterHandle)handlePtr, &cout);
setIntField(env, out, cout);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetLabel
(JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) {
jlong handlePtr = getLongField(env, handle);
NDArrayHandle out;
int ret = MXDataIterGetLabel((DataIterHandle)handlePtr, &out);
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
env->SetLongField(ndArrayHandle, refLongFid, (jlong)out);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData
(JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) {
jlong handlePtr = getLongField(env, handle);
NDArrayHandle out;
int ret = MXDataIterGetData((DataIterHandle)handlePtr, &out);
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
env->SetLongField(ndArrayHandle, refLongFid, (jlong)out);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
(JNIEnv *env, jobject obj, jobject handle, jobject outIndex, jobject outSize) {
jlong handlePtr = getLongField(env, handle);
uint64_t* coutIndex;
uint64_t coutSize;
int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize);
//to do
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum
(JNIEnv *env, jobject obj, jobject handle, jobject pad) {
jlong handlePtr = getLongField(env, handle);
int cpad;
int ret = MXDataIterGetPadNum((DataIterHandle)handlePtr, &cpad);
setIntField(env, pad, cpad);
return ret;
}

0 comments on commit 949a611

Please sign in to comment.