From e6d0c833123e54130ac4ff022034ef916b19b5bc Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 12 Jan 2016 18:35:46 +0800 Subject: [PATCH 1/3] NDArray save & load --- .../scala/ml/dmlc/mxnet/Initializer.scala | 7 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 9 + .../main/scala/ml/dmlc/mxnet/NDArray.scala | 208 +++++++++++++----- .../src/main/scala/ml/dmlc/mxnet/Random.scala | 6 +- .../scala/ml/dmlc/mxnet/NDArraySuite.scala | 48 ++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 76 +++++++ 6 files changed, 287 insertions(+), 67 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala index e1952236db94..399c35e42749 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala @@ -1,8 +1,5 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.NDArray.array - - /** * * Base class for Initializer. @@ -46,13 +43,13 @@ abstract class Initializer { val f = shape(3) / 2.0f val c = (2 * f - 1 - f % 2) / (2.0f * f) - (0 to (arr.size)).foreach { i => + (0 to arr.size).foreach { i => val x = i % shape(3) val y = (i / shape(3)) % shape(2) weight(i) = (1 - math.abs(x / f - c)) * (1 - math.abs(y / f - c)) } - arr.set(array(weight)) + arr.set(NDArray.array(weight, shape)) } def _initZero(name: String, arr: NDArray): Unit = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 8ca87dc2873f..3495f92b3a75 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -49,6 +49,15 @@ class LibInfo { @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], size: Int): Int + @native def mxNDArrayLoad(fname: String, + outSize: MXUintRef, + handles: ArrayBuffer[NDArrayHandle], + outNameSize: MXUintRef, + names: ArrayBuffer[String]): Int + @native def mxNDArraySave(fname: String, + handles: Array[NDArrayHandle], + keys: Array[String]): Int + // KVStore @native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int @native def mxKVStoreInit(handle: KVStoreHandle, len: MXUint, diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 83a45dded683..c12069d0df0f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -11,14 +11,13 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} */ object NDArray { private val logger = LoggerFactory.getLogger(classOf[NDArray]) - private val functions: Map[String, NDArrayFunction] = _initNdarrayModule() + private val functions: Map[String, NDArrayFunction] = initNDarrayModule() // Definition of internal functions. // Internal binary function - private[mxnet] def _binaryNDArrayFunction(funcName: String, - lhs: NDArray, - rhs: NDArray, - out: NDArray = null): NDArray = { + def invokeBinaryFunc(funcName: String, + lhs: NDArray, rhs: NDArray, + out: NDArray = null): NDArray = { var output = out val function = functions(funcName) require(function != null, s"invalid function name $funcName") @@ -27,21 +26,18 @@ object NDArray { case BinaryNDArrayFunction(handle: FunctionHandle, acceptEmptyMutate: Boolean) => if (output == null) { require(acceptEmptyMutate, s"argument out is required to call $funcName") - output = new NDArray(_newEmptyHandle()) + output = new NDArray(newEmptyHandle()) } checkCall(_LIB.mxFuncInvoke(handle, Array(lhs.handle, rhs.handle), Array[MXFloat](), Array(output.handle))) - case _ => throw new RuntimeException(s"call $funcName as binary function") + case _ => throw new IllegalArgumentException(s"call $funcName as binary function") } output } - // internal NDArray function - private[mxnet] def _unaryNDArrayFunction(funcName: String, - src: NDArray, - out: NDArray = null): NDArray = { + def invokeUnaryFunc(funcName: String, src: NDArray, out: NDArray = null): NDArray = { var output = out val function = functions(funcName) require(function != null, s"invalid function name $funcName") @@ -50,13 +46,13 @@ object NDArray { case UnaryNDArrayFunction(handle: NDArrayHandle, acceptEmptyMutate: Boolean) => if (output == null) { require(acceptEmptyMutate, s"argument out is required to call $funcName") - output = new NDArray(_newEmptyHandle()) + output = new NDArray(newEmptyHandle()) } checkCall(_LIB.mxFuncInvoke(handle, Array(src.handle), Array[MXFloat](), Array(output.handle))) - case _ => throw new RuntimeException(s"call $funcName as unary function") + case _ => throw new IllegalArgumentException(s"call $funcName as unary function") } output } @@ -69,9 +65,9 @@ object NDArray { * Output NDArray, used to hold the output result. * @return The result NDArray(tuple) of result of computation. */ - private[mxnet] def _genericNDArrayFunction(funcName: String, - args: Array[Any], - out: Array[NDArray] = null): Array[NDArray] = { + def invokeGenericFunc(funcName: String, + args: Array[Any], + out: Array[NDArray] = null): Array[NDArray] = { var mutateVars = out val function = functions(funcName) require(function != null, s"invalid function name $funcName") @@ -85,13 +81,13 @@ object NDArray { s"expect $nMutateVars in $funcName") if (mutateVars == null) { require(acceptEmptyMutate, s"argument out is required to call $funcName") - mutateVars = Array.fill[NDArray](nMutateVars)(new NDArray(_newEmptyHandle())) + mutateVars = Array.fill[NDArray](nMutateVars)(new NDArray(newEmptyHandle())) } checkCall(_LIB.mxFuncInvoke(handle, useVarsRange.map(args(_).asInstanceOf[NDArray].handle).toArray, scalarRange.map(args(_).asInstanceOf[MXFloat]).toArray, mutateVars.map(_.handle).array)) - case _ => throw new RuntimeException(s"call $funcName as generic function") + case _ => throw new IllegalArgumentException(s"call $funcName as generic function") } mutateVars } @@ -102,7 +98,7 @@ object NDArray { * * @return a new empty ndarray handle */ - private def _newEmptyHandle(): NDArrayHandle = { + private def newEmptyHandle(): NDArrayHandle = { val hdl = new NDArrayHandleRef checkCall(_LIB.mxNDArrayCreateNone(hdl)) hdl.value @@ -114,9 +110,9 @@ object NDArray { * * @return a new empty ndarray handle */ - private def _newAllocHandle(shape: Array[Int], - ctx: Context, - delayAlloc: Boolean): NDArrayHandle = { + private def newAllocHandle(shape: Array[Int], + ctx: Context, + delayAlloc: Boolean): NDArrayHandle = { val hdl = new NDArrayHandleRef checkCall(_LIB.mxNDArrayCreate( shape, @@ -137,7 +133,7 @@ object NDArray { } // Create a NDArray function from the FunctionHandle. - private def _makeNdarrayFunction(handle: FunctionHandle): (String, NDArrayFunction) = { + private def makeNdarrayFunction(handle: FunctionHandle): (String, NDArrayFunction) = { val NDARRAY_ARG_BEFORE_SCALAR = 1 val ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2 // Get the property of NDArray @@ -182,10 +178,10 @@ object NDArray { } // List and add all the ndarray functions to current module. - private def _initNdarrayModule(): Map[String, NDArrayFunction] = { + private def initNDarrayModule(): Map[String, NDArrayFunction] = { val functions = ListBuffer[FunctionHandle]() checkCall(_LIB.mxListFunctions(functions)) - functions.map(_makeNdarrayFunction).toMap + functions.map(makeNdarrayFunction).toMap } /** @@ -194,7 +190,9 @@ object NDArray { * @param out The result holder of the encoding. * @return Same as out. */ - def onehotEncode(indices: NDArray, out: NDArray): NDArray = ??? + def onehotEncode(indices: NDArray, out: NDArray): NDArray = { + NDArray.invokeBinaryFunc("_onehot_encode", indices, out, out) + } /** * Create an empty uninitialized new NDArray, with specified shape. @@ -206,7 +204,7 @@ object NDArray { */ def empty(shape: Array[Int], ctx: Context = null): NDArray = { val context = if (ctx == null) Context.defaultCtx else ctx - new NDArray(handle = NDArray._newAllocHandle(shape, context, delayAlloc = false)) + new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false)) } def empty(shape: Int *): NDArray = empty(shape.toArray) @@ -255,7 +253,7 @@ object NDArray { * @return a new clipped [[NDArray]] */ def clip(array: NDArray, min: Float, max: Float): NDArray = { - NDArray._genericNDArrayFunction("clip", Array(array, min, max))(0) + NDArray.invokeGenericFunc("clip", Array(array, min, max))(0) } /** @@ -264,31 +262,97 @@ object NDArray { * @return new [[NDArray]] */ def sqrt(src: NDArray): NDArray = { - NDArray._unaryNDArrayFunction("sqrt", src) + NDArray.invokeUnaryFunc("sqrt", src) + } + + def rsqrt(src: NDArray): NDArray = ??? + + /** + * Calculate 2D matrix multiplication + * @param lhs left ndarray + * @param rhs right ndarray + * @return a new [[NDArray]] + */ + def dot(lhs: NDArray, rhs: NDArray): NDArray = { + NDArray.invokeBinaryFunc("dot", lhs, rhs) } /** * Take L2 norm of the src. * @param src Source input to the function - * @return new [[NDArray]] of shape (1,) on the same device + * @return a new [[NDArray]] of shape (1,) on the same device */ def norm(src: NDArray): NDArray = { - NDArray._unaryNDArrayFunction("norm", src) + NDArray.invokeUnaryFunc("norm", src) + } + + /** + * Take absolute value of the src + * @param src Source nd-array + * @return a new [[NDArray]] + */ + def abs(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("abs", src) } - // TODO - def _randomUniform(low: Float, high: Float, out: NDArray): NDArray = ??? + def sign(src: NDArray): NDArray = ??? + + def round(src: NDArray): NDArray = ??? + + def ceil(src: NDArray): NDArray = ??? + + def floor(src: NDArray): NDArray = ??? + + def square(src: NDArray): NDArray = ??? + + def exp(src: NDArray): NDArray = ??? + + def log(src: NDArray): NDArray = ??? + + def cos(src: NDArray): NDArray = ??? + + def sin(src: NDArray): NDArray = ??? + + def max(src: NDArray): NDArray = ??? - def _randomGaussian(mean: Float, stdvar: Float, out: NDArray): NDArray = ??? + def min(src: NDArray): NDArray = ??? + def sum(src: NDArray): NDArray = ??? + + def argmaxChannel(src: NDArray): NDArray = ??? + + /** + * TODO: + * Choose one element from each line(row for python, column for R/Julia) + * in array according to the index. This function assume index uses 0-based index. + * @param array source array + * @param index index array + * @return a new [[NDArray]] + */ + def chooseElement0Index(array: NDArray, index: NDArray): NDArray = { + NDArray.invokeBinaryFunc("choose_element_0index", array, index) + } + + def randomUniform(low: Float, high: Float, out: NDArray): NDArray = { + NDArray.invokeGenericFunc("_random_uniform", Array(low, high), Array(out))(0) + } + + def randomGaussian(mean: Float, stdvar: Float, out: NDArray): NDArray = { + NDArray.invokeGenericFunc("_random_gaussian", Array(mean, stdvar), Array(out))(0) + } /** * Create a new NDArray that copies content from source_array. * @param sourceArr Source data to create NDArray from. + * @param shape shape of the NDArray * @param ctx The context of the NDArray, default to current default context. * @return The created NDArray. */ - def array(sourceArr: Array[Float], ctx: Context = null): NDArray = ??? + def array(sourceArr: Array[Float], shape: Array[Int], ctx: Context = null): NDArray = { + val arr = empty(shape, ctx) + arr.set(sourceArr) + arr + } /** * Load ndarray from binary file. @@ -306,7 +370,15 @@ object NDArray { * - `/path-to/my-local-ndarray` * @return dict of str->NDArray to be saved */ - def load(fname: String): Map[String, NDArray] = ??? + def load(fname: String): (Array[String], Array[NDArray]) = { + val outSize = new MXUintRef + val outNameSize = new MXUintRef + val handles = ArrayBuffer.empty[NDArrayHandle] + val names = ArrayBuffer.empty[String] + checkCall(_LIB.mxNDArrayLoad(fname, outSize, handles, outNameSize, names)) + require(outNameSize.value == 0 || outNameSize.value == outSize.value) + (names.toArray, handles.map(new NDArray(_)).toArray) + } /** * Save list of NDArray or dict of str->NDArray to binary file. @@ -324,7 +396,15 @@ object NDArray { * - `/path-to/my-local-ndarray` * @param data dict of str->NDArray */ - def save(fname: String, data: Map[String, NDArray]): Unit = ??? + def save(fname: String, data: Map[String, NDArray]): Unit = { + val keys = data.keys.toArray + val handles = data.values.map(_.handle).toArray + save(fname, keys, handles) + } + + private def save(fname: String, keys: Array[String], handles: Array[NDArrayHandle]): Unit = { + checkCall(_LIB.mxNDArraySave(fname, handles, keys)) + } } /** @@ -386,7 +466,7 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = */ def set(value: Float): NDArray = { require(writable, "trying to assign to a readonly NDArray") - NDArray._genericNDArrayFunction("_set_value", Array[Any](value), out = Array(this)) + NDArray.invokeGenericFunc("_set_value", Array[Any](value), out = Array(this)) this } @@ -402,98 +482,98 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = } def +(other: NDArray): NDArray = { - NDArray._binaryNDArrayFunction("_plus", this, other) + NDArray.invokeBinaryFunc("_plus", this, other) } def +(other: Float): NDArray = { - NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other))(0) + NDArray.invokeGenericFunc("_plus_scalar", Array[Any](this, other))(0) } def +=(other: NDArray): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to add to a readonly NDArray") } - NDArray._binaryNDArrayFunction("_plus", this, other, out = this) + NDArray.invokeBinaryFunc("_plus", this, other, out = this) } def +=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to add to a readonly NDArray") } - NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other), out = Array(this)) + NDArray.invokeGenericFunc("_plus_scalar", Array[Any](this, other), out = Array(this)) this } def -(other: NDArray): NDArray = { - NDArray._binaryNDArrayFunction("_minus", this, other) + NDArray.invokeBinaryFunc("_minus", this, other) } def -(other: Float): NDArray = { - NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other))(0) + NDArray.invokeGenericFunc("_minus_scalar", Array[Any](this, other))(0) } def -=(other: NDArray): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to subtract from a readonly NDArray") } - NDArray._binaryNDArrayFunction("_minus", this, other, out = this) + NDArray.invokeBinaryFunc("_minus", this, other, out = this) } def -=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to subtract from a readonly NDArray") } - NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other), out = Array(this)) + NDArray.invokeGenericFunc("_minus_scalar", Array[Any](this, other), out = Array(this)) this } def *(other: NDArray): NDArray = { - NDArray._binaryNDArrayFunction("_mul", this, other) + NDArray.invokeBinaryFunc("_mul", this, other) } def *(other: Float): NDArray = { - NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other))(0) + NDArray.invokeGenericFunc("_mul_scalar", Array[Any](this, other))(0) } def unary_-(): NDArray = { - NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, -1f))(0) + NDArray.invokeGenericFunc("_mul_scalar", Array[Any](this, -1f))(0) } def *=(other: NDArray): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to multiply to a readonly NDArray") } - NDArray._binaryNDArrayFunction("_mul", this, other, out = this) + NDArray.invokeBinaryFunc("_mul", this, other, out = this) } def *=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to multiply to a readonly NDArray") } - NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other), out = Array(this)) + NDArray.invokeGenericFunc("_mul_scalar", Array[Any](this, other), out = Array(this)) this } def /(other: NDArray): NDArray = { - NDArray._binaryNDArrayFunction("_div", this, other) + NDArray.invokeBinaryFunc("_div", this, other) } def /(other: Float): NDArray = { - NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other))(0) + NDArray.invokeGenericFunc("_div_scalar", Array[Any](this, other))(0) } def /=(other: NDArray): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to divide from a readonly NDArray") } - NDArray._binaryNDArrayFunction("_div", this, other, out = this) + NDArray.invokeBinaryFunc("_div", this, other, out = this) } def /=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to divide from a readonly NDArray") } - NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other), out = Array(this)) + NDArray.invokeGenericFunc("_div_scalar", Array[Any](this, other), out = Array(this)) this } @@ -525,7 +605,14 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = * @param other Target NDArray or context we want to copy data to. * @return The copy target NDArray */ - def copyTo(other: NDArray): NDArray = ??? + def copyTo(other: NDArray): NDArray = { + if (other.handle == this.handle) { + NDArray.logger.warn("copy an array to itself, is it intended ?") + other + } else { + NDArray.invokeUnaryFunc("_copyto", this, out = other) + } + } /** * Copy the content of current array to a new NDArray in the context. @@ -533,7 +620,10 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = * @param ctx Target context we want to copy data to. * @return The copy target NDArray */ - def copyTo(ctx: Context): NDArray = ??? + def copyTo(ctx: Context): NDArray = { + val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true)) + copyTo(ret) + } /** * Get shape of current NDArray. @@ -564,7 +654,7 @@ class NDArrayConversions(val value: Float) { } def -(other: NDArray): NDArray = { - NDArray._genericNDArrayFunction("_rminus_scalar", Array[Any](other, value))(0) + NDArray.invokeGenericFunc("_rminus_scalar", Array[Any](other, value))(0) } def *(other: NDArray): NDArray = { @@ -572,7 +662,7 @@ class NDArrayConversions(val value: Float) { } def /(other: NDArray): NDArray = { - NDArray._genericNDArrayFunction("_rdiv_scalar", Array[Any](other, value))(0) + NDArray.invokeGenericFunc("_rdiv_scalar", Array[Any](other, value))(0) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala index b7abf9627429..d2935590f055 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala @@ -1,6 +1,6 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.NDArray.{_randomGaussian, _randomUniform, empty} +import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty} /** * Random Number interface of mxnet. @@ -29,7 +29,7 @@ object Random { require(shape != null, "shape is required when out is not specified") outCopy = empty(shape, ctx) } - _randomUniform(low, high, outCopy) + randomUniform(low, high, outCopy) } @@ -55,7 +55,7 @@ object Random { require(shape != null, "shape is required when out is not specified") outCopy = empty(shape, ctx) } - _randomGaussian(mean, stdvar, outCopy) + randomGaussian(mean, stdvar, outCopy) } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index 5fa02ece10bf..d8664488a5b4 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -1,5 +1,7 @@ package ml.dmlc.mxnet +import java.io.File + import ml.dmlc.mxnet.NDArrayConversions._ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalactic.Tolerance._ @@ -105,4 +107,50 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { assert(normed.shape === Array(1)) assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f) } + + test("one hot encode") { + // TODO + } + + test("dot") { + // TODO + } + + test("choose_element_0index") { + // TODO + } + + test("copy to") { + // TODO + } + + test("random uniform") { + // TODO + } + + test("random gaussian") { + // TODO + } + + test("abs") { + // TODO + } + + test("save and load") { + val filename = "ndarray.bin" + try { + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + NDArray.save(filename, Map("local" -> ndarray)) + val (keys, arrays) = NDArray.load(filename) + assert(keys.length === 1) + assert(keys(0) === "local") + assert(arrays.length === 1) + val loadedArray = arrays(0) + assert(loadedArray.shape === Array(3, 1)) + assert(loadedArray.toArray === Array(1f, 2f, 3f)) + } finally { + val file = new File(filename) + file.delete() + } + } } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index c844ef36b950..0b316ebcdf3b 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -188,6 +188,82 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree return MXNDArrayFree((NDArrayHandle) ndArrayHandle); } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayLoad + (JNIEnv * env, jobject obj, jstring jfname, jobject joutSize, + jobject jhandles, jobject joutNameSize, jobject jnames) { + mx_uint outSize; + NDArrayHandle *outArr; + mx_uint outNameSize; + const char **outNames; + + const char *fname = env->GetStringUTFChars(jfname, 0); + int ret = MXNDArrayLoad(fname, &outSize, &outArr, &outNameSize, &outNames); + env->ReleaseStringUTFChars(jfname, fname); + + if (ret) { + return ret; + } + + // fill sizes + jclass refIntClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I"); + env->SetIntField(joutSize, valueInt, outSize); + env->SetIntField(joutNameSize, valueInt, outNameSize); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + + // fill handles + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); + for (int i = 0; i < outSize; ++i) { + jobject handle = env->NewObject(longCls, longConst, outArr[i]); + env->CallObjectMethod(jhandles, arrayAppend, handle); + } + + // fill names + for (int i = 0; i < outNameSize; ++i) { + jstring jname = env->NewStringUTF(outNames[i]); + env->CallObjectMethod(jnames, arrayAppend, jname); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySave + (JNIEnv * env, jobject obj, jstring jfname, jlongArray jhandles, jobjectArray jkeys) { + int numArgs = env->GetArrayLength(jhandles); + const char **keys = NULL; + if (jkeys != NULL) { + keys = new const char *[numArgs]; + for (int i = 0; i < numArgs; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + const char *key = env->GetStringUTFChars(jkey, 0); + keys[i] = key; + } + } + + const char *fname = env->GetStringUTFChars(jfname, 0); + jlong *handles = env->GetLongArrayElements(jhandles, NULL); + + int ret = MXNDArraySave(fname, (mx_uint) numArgs, (NDArrayHandle *) handles, keys); + + env->ReleaseLongArrayElements(jhandles, handles, 0); + env->ReleaseStringUTFChars(jfname, fname); + + // release allocated memory + if (jkeys != NULL) { + for (int i = 0; i < numArgs; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + env->ReleaseStringUTFChars(jkey, keys[i]); + } + delete[] keys; + } + + return ret; +} + // The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, // while we write java functions here in scala-package. // Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, From 28db218ea5b76a7c2ffded5024ea183bb8a49d94 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 13 Jan 2016 13:49:33 +0800 Subject: [PATCH 2/3] NDArray all functions finished --- .../main/scala/ml/dmlc/mxnet/Context.scala | 9 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 3 + .../main/scala/ml/dmlc/mxnet/NDArray.scala | 165 +++++++++++++++--- .../scala/ml/dmlc/mxnet/NDArraySuite.scala | 155 +++++++++++++++- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 17 ++ 5 files changed, 311 insertions(+), 38 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala index 4bb37a8dbce5..1ca12df2a975 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala @@ -1,6 +1,8 @@ package ml.dmlc.mxnet object Context { + val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned") + val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3) val defaultCtx = new Context("cpu", 0) } @@ -10,10 +12,7 @@ object Context { * @param deviceId (default=0) The device id of the device, needed for GPU */ class Context(deviceTypeName: String, val deviceId: Int = 0) { - private val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned") - private val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3) - - val deviceTypeid: Int = devstr2type(deviceTypeName) + val deviceTypeid: Int = Context.devstr2type(deviceTypeName) def this(context: Context) = { this(context.deviceType, context.deviceId) @@ -23,5 +22,5 @@ class Context(deviceTypeName: String, val deviceId: Int = 0) { * Return device type of current context. * @return device_type */ - def deviceType: String = devtype2str(deviceTypeid) + def deviceType: String = Context.devtype2str(deviceTypeid) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 3495f92b3a75..87ee3a729fb3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -9,6 +9,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * @author Yizhi Liu */ class LibInfo { + // NDArray @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxGetLastError(): String @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int @@ -19,6 +20,7 @@ class LibInfo { delayAlloc: Int, out: NDArrayHandleRef): Int @native def mxNDArrayWaitAll(): Int + @native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int @native def mxFuncDescribe(handle: FunctionHandle, nUsedVars: MXUintRef, @@ -57,6 +59,7 @@ class LibInfo { @native def mxNDArraySave(fname: String, handles: Array[NDArrayHandle], keys: Array[String]): Int + @native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int // KVStore @native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int @native def mxKVStoreInit(handle: KVStoreHandle, diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index c12069d0df0f..9a2c7912f104 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -11,7 +11,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} */ object NDArray { private val logger = LoggerFactory.getLogger(classOf[NDArray]) - private val functions: Map[String, NDArrayFunction] = initNDarrayModule() + private val functions: Map[String, NDArrayFunction] = initNDArrayModule() // Definition of internal functions. // Internal binary function @@ -178,7 +178,7 @@ object NDArray { } // List and add all the ndarray functions to current module. - private def initNDarrayModule(): Map[String, NDArrayFunction] = { + private def initNDArrayModule(): Map[String, NDArrayFunction] = { val functions = ListBuffer[FunctionHandle]() checkCall(_LIB.mxListFunctions(functions)) functions.map(makeNdarrayFunction).toMap @@ -265,7 +265,14 @@ object NDArray { NDArray.invokeUnaryFunc("sqrt", src) } - def rsqrt(src: NDArray): NDArray = ??? + /** + * Take rsqrt of the src + * @param src Source input to the function + * @return new [[NDArray]] + */ + def rsqrt(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("rsqrt", src) + } /** * Calculate 2D matrix multiplication @@ -288,43 +295,133 @@ object NDArray { /** * Take absolute value of the src - * @param src Source nd-array + * @param src Source ndarray * @return a new [[NDArray]] */ def abs(src: NDArray): NDArray = { NDArray.invokeUnaryFunc("abs", src) } - def sign(src: NDArray): NDArray = ??? + /** + * Take sign value of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def sign(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("sign", src) + } - def round(src: NDArray): NDArray = ??? + /** + * Take round value of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def round(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("round", src) + } - def ceil(src: NDArray): NDArray = ??? + /** + * Take ceil value of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def ceil(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("ceil", src) + } - def floor(src: NDArray): NDArray = ??? + /** + * Take floor value of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def floor(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("floor", src) + } - def square(src: NDArray): NDArray = ??? + /** + * Take square of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def square(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("square", src) + } - def exp(src: NDArray): NDArray = ??? + /** + * Take exp of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def exp(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("exp", src) + } - def log(src: NDArray): NDArray = ??? + /** + * Take log of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def log(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("log", src) + } - def cos(src: NDArray): NDArray = ??? + /** + * Take cos of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def cos(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("cos", src) + } - def sin(src: NDArray): NDArray = ??? + /** + * Take sin of the src + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def sin(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("sin", src) + } - def max(src: NDArray): NDArray = ??? + /** + * Take max of the src. The result will be ndarray of shape (1,) on the same device. + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def max(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("max", src) + } - def min(src: NDArray): NDArray = ??? + /** + * Take max of the src.The result will be ndarray of shape (1,) on the same device. + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def min(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("min", src) + } - def sum(src: NDArray): NDArray = ??? + /** + * Take sum of the src. The result will be ndarray of shape (1,) on the same device. + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def sum(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("sum", src) + } - def argmaxChannel(src: NDArray): NDArray = ??? + /** + * Take the argmax index of each channel (row) in src. + * @param src Source ndarray + * @return a new [[NDArray]] + */ + def argmaxChannel(src: NDArray): NDArray = { + NDArray.invokeUnaryFunc("argmax_channel", src) + } /** - * TODO: - * Choose one element from each line(row for python, column for R/Julia) - * in array according to the index. This function assume index uses 0-based index. + * Choose one element from each row in array according to the index. + * This function assume index uses 0-based index. * @param array source array * @param index index array * @return a new [[NDArray]] @@ -380,6 +477,16 @@ object NDArray { (names.toArray, handles.map(new NDArray(_)).toArray) } + def load2Map(fname: String): Map[String, NDArray] = { + val (keys, vals) = load(fname) + require(keys.length == vals.length, "Loaded NDArrays have no name") + (keys zip vals).toMap + } + + def load2Array(fname: String): Array[NDArray] = { + load(fname)._2 + } + /** * Save list of NDArray or dict of str->NDArray to binary file. * @@ -402,6 +509,10 @@ object NDArray { save(fname, keys, handles) } + def save(fname: String, data: Traversable[NDArray]): Unit = { + save(fname, null, data.map(_.handle).toArray) + } + private def save(fname: String, keys: Array[String], handles: Array[NDArrayHandle]): Unit = { checkCall(_LIB.mxNDArraySave(fname, handles, keys)) } @@ -451,13 +562,20 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = * NDArray finishes. There can still be pending read going on when the * function returns. */ - def waitToRead(): Unit = ??? + def waitToRead(): Unit = { + checkCall(_LIB.mxNDArrayWaitToRead(handle)) + } /** * Get context of current NDArray. * @return The context of current NDArray. */ - def context: Context = ??? + def context: Context = { + val devTypeId = new RefInt + val devId = new RefInt + checkCall(_LIB.mxNDArrayGetContext(handle, devTypeId, devId)) + new Context(Context.devtype2str(devTypeId.value), devId.value) + } /** * Set the values of the NDArray @@ -578,10 +696,9 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = } /** - * Return a copied flat java array of current array. + * Return a copied flat java array of current array (row-major). * @return A copy of array content. */ - // TODO: Shall we use column-major or row-major ? def toArray: Array[Float] = { val data = Array.ofDim[Float](size) checkCall(_LIB.mxNDArraySyncCopyToCPU(handle, data, size)) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index d8664488a5b4..74f7a021de92 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -1,12 +1,15 @@ package ml.dmlc.mxnet import java.io.File +import java.util.concurrent.atomic.AtomicInteger import ml.dmlc.mxnet.NDArrayConversions._ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalactic.Tolerance._ class NDArraySuite extends FunSuite with BeforeAndAfterAll { + private val sequence: AtomicInteger = new AtomicInteger(0) + test("to java array") { val ndarray = NDArray.zeros(2, 2) assert(ndarray.toArray === Array(0f, 0f, 0f, 0f)) @@ -100,6 +103,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f)) } + test("rsqrt") { + val ndarray = NDArray.array(Array(1f, 4f), shape = Array(2, 1)) + assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f)) + } + test("norm") { val ndarray = NDArray.empty(3, 1) ndarray.set(Array(1f, 2f, 3f)) @@ -109,35 +117,139 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("one hot encode") { - // TODO + val indices = NDArray.array(Array(1f, 0f, 2f), shape = Array(3)) + val array = NDArray.empty(3, 3) + NDArray.onehotEncode(indices, array) + assert(array.shape === Array(3, 3)) + assert(array.toArray === Array(0f, 1f, 0f, + 1f, 0f, 0f, + 0f, 0f, 1f)) } test("dot") { - // TODO + val arr1 = NDArray.array(Array(1f, 2f), shape = Array(1, 2)) + val arr2 = NDArray.array(Array(3f, 4f), shape = Array(2, 1)) + val res = NDArray.dot(arr1, arr2) + assert(res.shape === Array(1, 1)) + assert(res.toArray === Array(11f)) } test("choose_element_0index") { - // TODO + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Array(2, 3)) + val indices = NDArray.array(Array(0f, 1f), shape = Array(2)) + val res = NDArray.chooseElement0Index(arr, indices) + assert(res.toArray === Array(1f, 6f)) } test("copy to") { - // TODO + val source = NDArray.array(Array(1f, 2f, 3f), shape = Array(1, 3)) + val dest = NDArray.empty(1, 3) + source.copyTo(dest) + assert(dest.shape === Array(1, 3)) + assert(dest.toArray === Array(1f, 2f, 3f)) } test("random uniform") { - // TODO + val matrix = NDArray.empty(3, 2) + NDArray.randomUniform(0f, 1f, matrix) + assert(matrix.shape === Array(3, 2)) + val arr = matrix.toArray + // scalastyle:off println + println(s"Random Uniform: [${arr.mkString(",")}]") + // scalastyle:on println + arr.foreach { elem => + assert(elem > 0f && elem < 1f) + } } test("random gaussian") { - // TODO + val matrix = NDArray.empty(3, 2) + NDArray.randomGaussian(0f, 1f, matrix) + assert(matrix.shape === Array(3, 2)) + val arr = matrix.toArray + // scalastyle:off println + println(s"Random Gaussian: [${arr.mkString(",")}]") + // scalastyle:on println } test("abs") { - // TODO + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1)) + assert(NDArray.abs(arr).toArray === Array(1f, 2f, 3f)) + } + + test("sign") { + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1)) + assert(NDArray.sign(arr).toArray === Array(-1f, -1f, 1f)) + } + + test("round") { + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + assert(NDArray.round(arr).toArray === Array(2f, 2f, 4f)) + } + + test("ceil") { + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + assert(NDArray.ceil(arr).toArray === Array(2f, 3f, 4f)) + } + + test("floor") { + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + assert(NDArray.floor(arr).toArray === Array(1f, 2f, 3f)) } - test("save and load") { - val filename = "ndarray.bin" + test("square") { + val arr = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + assert(NDArray.square(arr).toArray === Array(1f, 4f, 9f)) + } + + test("exp") { + val arr = NDArray.ones(1) + assert(NDArray.exp(arr).toScalar === 2.71828f +- 1e-3f) + } + + test("log") { + val arr = NDArray.empty(1) + arr.set(10f) + assert(NDArray.log(arr).toScalar === 2.302585f +- 1e-5f) + } + + test("cos") { + val arr = NDArray.empty(1) + arr.set(12f) + assert(NDArray.cos(arr).toScalar === 0.8438539f +- 1e-5f) + } + + test("sin") { + val arr = NDArray.empty(1) + arr.set(12f) + assert(NDArray.sin(arr).toScalar === -0.536572918f +- 1e-5f) + } + + test("max") { + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + assert(NDArray.max(arr).toScalar === 3.7f +- 1e-3f) + } + + test("min") { + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + assert(NDArray.min(arr).toScalar === 1.5f +- 1e-3f) + } + + test("sum") { + val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Array(2, 2)) + assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f) + } + + test("argmaxChannel") { + val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Array(2, 2)) + val argmax = NDArray.argmaxChannel(arr) + assert(argmax.shape === Array(2)) + assert(argmax.toArray === Array(1f, 0f)) + } + + test("save and load with names") { + val filename + = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" try { val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) NDArray.save(filename, Map("local" -> ndarray)) @@ -153,4 +265,29 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { file.delete() } } + + test("save and load without names") { + val filename + = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" + try { + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + NDArray.save(filename, Array(ndarray)) + val (keys, arrays) = NDArray.load(filename) + assert(keys.length === 0) + assert(arrays.length === 1) + val loadedArray = arrays(0) + assert(loadedArray.shape === Array(3, 1)) + assert(loadedArray.toArray === Array(1f, 2f, 3f)) + } finally { + val file = new File(filename) + file.delete() + } + } + + test("get context") { + val ndarray = NDArray.ones(3, 2) + val ctx = ndarray.context + assert(ctx.deviceType === "cpu") + assert(ctx.deviceId === 0) + } } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 0b316ebcdf3b..7a830a073d54 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -40,6 +40,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, return MXNDArrayWaitAll(); } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayWaitToRead + (JNIEnv *env, jobject obj, jlong arrayPtr) { + return MXNDArrayWaitToRead((NDArrayHandle) arrayPtr); +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListFunctions (JNIEnv *env, jobject obj, jobject functions) { jclass longCls = env->FindClass("java/lang/Long"); @@ -183,6 +188,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetContext + (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) { + int outDevType; + int outDevId; + int ret = MXNDArrayGetContext((NDArrayHandle) arrayPtr, &outDevType, &outDevId); + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID refFid = env->GetFieldID(refClass, "value", "I"); + env->SetIntField(devTypeId, refFid, outDevType); + env->SetIntField(devId, refFid, outDevId); + return ret; +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree (JNIEnv * env, jobject obj, jlong ndArrayHandle) { return MXNDArrayFree((NDArrayHandle) ndArrayHandle); From 07f726dfede9100fab88052e6933f5a5ca101b41 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 13 Jan 2016 22:22:49 +0800 Subject: [PATCH 3/3] Use additional ptr to implement KVStore Updater callback --- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 2 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 4 +- .../main/scala/ml/dmlc/mxnet/Optimizer.scala | 5 +- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 2 +- .../native/src/main/native/jni_helper_func.h | 7 ++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 71 ++++++++----------- 6 files changed, 43 insertions(+), 48 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 40b62b652990..11f3376874ac 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -172,7 +172,7 @@ class KVStore(private val handle: KVStoreHandle) { */ def setUpdater(updater: MXKVStoreUpdater): Unit = { this.updaterFunc = updater - checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null)) + checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc)) } /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 87ee3a729fb3..8bfdf2f9cc2e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -76,9 +76,7 @@ class LibInfo { keys: Array[Int], outs: Array[NDArrayHandle], priority: Int): Int - @native def mxKVStoreSetUpdater(handle: KVStoreHandle, - updaterFunc: MXKVStoreUpdater, - updaterHandle: AnyRef): Int + @native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater): Int @native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int @native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle, diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index f9a58f5ca4db..0ebcd9123290 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -6,7 +6,7 @@ object Optimizer { def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { new MXKVStoreUpdater { val states = new scala.collection.mutable.HashMap[Int, AnyRef] - override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = { + override def update(index: Int, grad: NDArray, weight: NDArray): Unit = { val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) optimizer.update(index, weight, grad, state) } @@ -57,7 +57,6 @@ trait MXKVStoreUpdater { * @param key the key * @param recv the pushed value on this key * @param local the value stored on local on this key - * @param handle The additional handle to the updater */ - def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef = null): Unit + def update(key: Int, recv: NDArray, local: NDArray): Unit } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 145ba4d0b71b..2812623b2d02 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -27,7 +27,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { test("updater runs when push") { val kv = KVStore.create() val updater = new MXKVStoreUpdater { - override def update(key: Int, input: NDArray, stored: NDArray, handle: AnyRef): Unit = { + override def update(key: Int, input: NDArray, stored: NDArray): Unit = { // scalastyle:off println println(s"update on key $key") // scalastyle:on println diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index cce9cb0efe22..c86a451bbdce 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -3,6 +3,13 @@ #ifndef MXNET_SCALA_JNI_HELPER_FUNC_H #define MXNET_SCALA_JNI_HELPER_FUNC_H +// Define an env closure +// e.g. it can be used to implement java callback +typedef struct { + JNIEnv *env; + jobject obj; +} JNIClosure; + jlong getLongField(JNIEnv *env, jobject obj) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 7a830a073d54..cfc419f29812 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -281,48 +281,39 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySave return ret; } -// The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, -// while we write java functions here in scala-package. -// Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, -// which implies the function registered to KVStore must be stateful. -// This is why we re-implement MXKVStoreSetUpdater as follows. +extern "C" void KVStoreUpdaterCallbackFunc + (int key, NDArrayHandle recv, NDArrayHandle local, void *handle) { + JNIClosure *closure = (JNIClosure *) handle; + JNIEnv *env = closure->env; + jobject updaterFuncObjGlb = closure->obj; + + // find java updater method + jclass updtClass = env->GetObjectClass(updaterFuncObjGlb); + jmethodID updtFunc = env->GetMethodID(updtClass, + "update", "(ILml/dmlc/mxnet/NDArray;Lml/dmlc/mxnet/NDArray;)V"); + + // find java NDArray constructor + jclass ndObjClass = env->FindClass("ml/dmlc/mxnet/NDArray"); + jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "", "(JZ)V"); + + jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor, (jlong)recv, true); + jobject ndLocal = env->NewObject(ndObjClass, ndObjConstructor, (jlong)local, true); + + env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, ndRecv, ndLocal); + // FIXME: This function can be called multiple times, + // can we find a way to safely destroy these two objects ? + // env->DeleteGlobalRef(updaterFuncObjGlb); + // delete closure; +} + JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater - (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj, jobject updaterHandle) { + (JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj) { jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj); - jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle); - std::function updt - = [env, updaterFuncObjGlb, updaterHandleGlb](int key, const mxnet::NDArray& recv, mxnet::NDArray* local) { - // find java updater method - jclass updtClass = env->GetObjectClass(updaterFuncObjGlb); - jmethodID updtFunc = env->GetMethodID(updtClass, - "update", "(ILml/dmlc/mxnet/NDArray;Lml/dmlc/mxnet/NDArray;Ljava/lang/Object;)V"); - - // find java NDArray constructor - jclass ndObjClass = env->FindClass("ml/dmlc/mxnet/NDArray"); - jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "", "(JZ)V"); - - mxnet::NDArray *recvCopy = new mxnet::NDArray(); - *recvCopy = recv; - jobject jNdRecvCopy = env->NewObject(ndObjClass, ndObjConstructor, (jlong)recvCopy, true); - - mxnet::NDArray *localCopy = new mxnet::NDArray(); - *localCopy = *local; - jobject jNdLocalCopy = env->NewObject(ndObjClass, ndObjConstructor, (jlong)localCopy, true); - - env->CallVoidMethod(updaterFuncObjGlb, updtFunc, key, jNdRecvCopy, jNdLocalCopy, updaterHandleGlb); - env->DeleteGlobalRef(updaterFuncObjGlb); - env->DeleteGlobalRef(updaterHandleGlb); - }; - try { - static_cast((KVStoreHandle)kvStorePtr)->set_updater(updt); - } catch(dmlc::Error &except) { - // It'll be too complicated to set & get mx error in jni code. - // thus simply return -1 to indicate a failure. - // Notice that we'll NOT be able to run MXGetLastError - // to get the error message after this function fails. - return -1; - } - return 0; + JNIClosure *closure = new JNIClosure(); + closure->env = env; + closure->obj = updaterFuncObjGlb; + return MXKVStoreSetUpdater((KVStoreHandle) kvStorePtr, + KVStoreUpdaterCallbackFunc, (void *) closure); } JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate