From 3f40cd4d633ef49810923a61759ef4af8f3960a2 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 6 Mar 2016 14:55:59 +0800 Subject: [PATCH 1/2] Fix MXDataIter NPE bug. Throw NoSuchElementException when calling next() at the end of the dataIter, following the same convention as Java --- scala-package/README.md | 6 +- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 5 +- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 24 ++--- .../scala/ml/dmlc/mxnet/io/MXDataIter.scala | 36 +++---- .../scala/ml/dmlc/mxnet/io/NDArrayIter.scala | 8 +- .../ml/dmlc/mxnet/io/PrefetchingIter.scala | 2 + .../dmlc/mxnet/util/NativeLibraryLoader.scala | 2 +- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 99 +++++++++---------- .../scala/ml/dmlc/mxnet/train/ConvSuite.scala | 5 +- 9 files changed, 94 insertions(+), 93 deletions(-) diff --git a/scala-package/README.md b/scala-package/README.md index f4fc412d8bdb..8f95bdcf8c81 100644 --- a/scala-package/README.md +++ b/scala-package/README.md @@ -45,7 +45,6 @@ Here is a Scala example of how training a simple 3-layer MLP on MNIST looks like ```scala import ml.dmlc.mxnet._ import ml.dmlc.mxnet.optimizer.SGD -import org.slf4j.LoggerFactory // model definition val data = Symbol.Variable("data") @@ -95,10 +94,9 @@ val prob = probArrays(0) import scala.collection.mutable.ListBuffer valDataIter.reset() val labels = ListBuffer.empty[NDArray] -var evalData = valDataIter.next() -while (evalData != null) { +while (valDataIter.hasNext) { + val evalData = valDataIter.next() labels += evalData.label(0).copy() - evalData = valDataIter.next() } val y = NDArray.concatenate(labels) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index b8a0eeb4548d..183967266c0e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -118,16 +118,19 @@ case class DataBatch(data: IndexedSeq[NDArray], /** * DataIter object in mxnet. 
*/ -abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] { +abstract class DataIter extends Iterator[DataBatch] { /** * reset the iterator */ def reset(): Unit + def batchSize: Int + /** * get next data batch from iterator * @return */ + @throws(classOf[NoSuchElementException]) def next(): DataBatch = { new DataBatch(getData(), getLabel(), getIndex(), getPad()) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 83294e57f59b..3aa573afb4f6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -242,9 +242,8 @@ object Model { trainData.reset() while (!epochDone) { var doReset = true - // TODO: make DataIter implement Iterator - var dataBatch = trainData.next() - while (doReset && dataBatch != null) { + while (doReset && trainData.hasNext) { + val dataBatch = trainData.next() executorManager.loadDataBatch(dataBatch) monitor.foreach(_.tic()) executorManager.forward(isTrain = true) @@ -271,7 +270,6 @@ object Model { doReset = false } dataBatch.dispose() - dataBatch = trainData.next() } if (doReset) { trainData.reset() @@ -290,13 +288,12 @@ object Model { evalMetric.reset() evalDataIter.reset() // TODO: make DataIter implement Iterator - var evalBatch = evalDataIter.next() - while (evalBatch != null) { + while (evalDataIter.hasNext) { + val evalBatch = evalDataIter.next() executorManager.loadDataBatch(evalBatch) executorManager.forward(isTrain = false) evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays) evalBatch.dispose() - evalBatch = evalDataIter.next() } val (name, value) = evalMetric.get @@ -477,8 +474,8 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp val outputs = Array.fill(predExec.outputs.length)(ListBuffer.empty[NDArray]) var i = 0 - var batch = data.next() - while (batch != null && i != numBatch) { + while 
(data.hasNext && i != numBatch) { + val batch = data.next() i += 1 Executor.loadData(batch, dataArrays) predExec.forward(isTrain = false) @@ -487,10 +484,13 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp for ((list, nd) <- outputs zip predExec.outputs) { list += nd.slice(0, realSize).copy() } - batch = data.next() } - // TODO: we can use Symbol.concat to do the same thing. Can it be more efficient? - outputs.map(NDArray.concatenate(_)) + // TODO(Yizhi): we can use Symbol.concat to do the same thing. Can it be more efficient? + val results = outputs.map(NDArray.concatenate(_)) + for (output <- outputs) { + output.foreach(_.dispose()) + } + results } /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala index c71f9dfa3afd..41e9ef1cf9b4 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala @@ -21,17 +21,22 @@ class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle, // (may be this is not the best way to do this work, // fix me if any better way found) private var currentBatch: DataBatch = null - iterNext() - private val data = currentBatch.data(0) - private val label = currentBatch.label(0) - reset() - - // properties - val _provideData: Map[String, Shape] = Map(dataName -> data.shape) - val _provideLabel: Map[String, Shape] = Map(labelName -> label.shape) - override val batchSize = data.shape(0) - private var disposed = false + private val (_provideData: Map[String, Shape], + _provideLabel: Map[String, Shape], + _batchSize: Int) = + if (hasNext) { + iterNext() + val data = currentBatch.data(0) + val label = currentBatch.label(0) + reset() + // properties + (Map(dataName -> data.shape), Map(labelName -> label.shape), data.shape(0)) + } else { + (null, null, 0) + } + + private var disposed = false override protected 
def finalize(): Unit = { dispose() } @@ -51,16 +56,12 @@ class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle, * reset the iterator */ override def reset(): Unit = { - // TODO: self._debug_at_begin = True currentBatch = null checkCall(_LIB.mxDataIterBeforeFirst(handle)) } + @throws(classOf[NoSuchElementException]) override def next(): DataBatch = { - // TODO - // if self._debug_skip_load and not self._debug_at_begin: - // return DataBatch(data =[self.getdata()], label =[self.getlabel()], - // pad = self.getpad(), index = self.getindex()) if (currentBatch == null) { iterNext() } @@ -70,8 +71,7 @@ class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle, currentBatch = null batch } else { - // TODO raise StopIteration - null + throw new NoSuchElementException } } @@ -145,6 +145,8 @@ class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle, iterNext() } } + + override def batchSize: Int = _batchSize } // scalastyle:on finalize diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala index 32997dadf33d..6d8a1283f856 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala @@ -8,7 +8,7 @@ import ml.dmlc.mxnet.{DataIter, NDArray, Shape} * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter. * @param data NDArrayIter supports single or multiple data and label. * @param label Same as data, but is not fed to the model during testing. - * @param batchSize Batch Size + * @param dataBatchSize Batch Size * @param shuffle Whether to shuffle the data * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch * @note @@ -17,8 +17,8 @@ import ml.dmlc.mxnet.{DataIter, NDArray, Shape} * for training and can cause problems if used for prediction. 
*/ class NDArrayIter(data: NDArray, label: NDArray = null, - batchSize: Int = 1, shuffle: Boolean = false, - lastBatchHandle: String = "pad") extends DataIter(batchSize) { + private val dataBatchSize: Int = 1, shuffle: Boolean = false, + lastBatchHandle: String = "pad") extends DataIter { /** * reset the iterator */ @@ -56,4 +56,6 @@ class NDArrayIter(data: NDArray, label: NDArray = null, override def provideLabel: Map[String, Shape] = ??? override def hasNext: Boolean = ??? + + override def batchSize: Int = dataBatchSize } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala index 1419e76ff98a..e14764f1785a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala @@ -51,4 +51,6 @@ class PrefetchingIter(val iters: List[DataIter], override def provideData: Map[String, Shape] = ??? override def hasNext: Boolean = ??? + + override def batchSize: Int = ??? 
} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/util/NativeLibraryLoader.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/util/NativeLibraryLoader.scala index 744697de3344..2070240a09f4 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/util/NativeLibraryLoader.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/util/NativeLibraryLoader.scala @@ -80,7 +80,7 @@ object NativeLibraryLoader { */ private def unifyOSName(osname: String): String = { if (osname.startsWith("Windows")) { - return "Windows" + "Windows" } osname } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index 326248114fcb..78fd57a62981 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -39,10 +39,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test_loop mnistIter.reset() batchCount = 0 - var batch = mnistIter.next() - while (batch != null) { + while (mnistIter.hasNext) { + mnistIter.next() batchCount += 1 - batch = mnistIter.next() } // test loop assert(nBatch === batchCount) @@ -66,54 +65,50 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { /** * default skip this test for saving time */ -// test("test ImageRecordIter") { -// // get data -// "./scripts/get_cifar_data.sh" ! 
-// -// val params = Map( -// "path_imgrec" -> "data/cifar/train.rec", -// "mean_img" -> "data/cifar/cifar10_mean.bin", -// "rand_crop" -> "False", -// "and_mirror" -> "False", -// "shuffle" -> "False", -// "data_shape" -> "(3,28,28)", -// "batch_size" -> "100", -// "preprocess_threads" -> "4", -// "prefetch_buffer" -> "1" -// ) -// val imgRecIter = IO.ImageRecordIter(params) -// val nBatch = 500 -// var batchCount = 0 -// // test provideData -// val provideData = imgRecIter.provideData -// val provideLabel = imgRecIter.provideLabel -// assert(provideData("data") === Array(100, 3, 28, 28)) -// assert(provideLabel("label") === Array(100)) -// -// imgRecIter.reset() -// while(imgRecIter.hasNext()) { -// val batch = imgRecIter.next() -// batchCount += 1 -// } -// // test loop -// assert(batchCount === nBatch) -// // test reset -// imgRecIter.reset() -// imgRecIter.next() -// val label0 = imgRecIter.getLabel().head.toArray -// val data0 = imgRecIter.getData().head.toArray -// imgRecIter.reset() -// imgRecIter.reset() -// imgRecIter.reset() -// imgRecIter.reset() -// imgRecIter.reset() -// val label1 = imgRecIter.getLabel().head.toArray -// val data1 = imgRecIter.getData().head.toArray -// assert(label0 === label1) -// assert(data0 === data1) -// } + test("test ImageRecordIter") { + // get data + "./scripts/get_cifar_data.sh" ! 
+ + val params = Map( + "path_imgrec" -> "data/cifar/train.rec", + "mean_img" -> "data/cifar/cifar10_mean.bin", + "rand_crop" -> "False", + "and_mirror" -> "False", + "shuffle" -> "False", + "data_shape" -> "(3,28,28)", + "batch_size" -> "100", + "preprocess_threads" -> "4", + "prefetch_buffer" -> "1" + ) + val imgRecIter = IO.ImageRecordIter(params) + val nBatch = 500 + var batchCount = 0 + // test provideData + val provideData = imgRecIter.provideData + val provideLabel = imgRecIter.provideLabel + assert(provideData("data").toArray === Array(100, 3, 28, 28)) + assert(provideLabel("label").toArray === Array(100)) -// test("test NDarryIter") { -// -// } + imgRecIter.reset() + while (imgRecIter.hasNext) { + imgRecIter.next() + batchCount += 1 + } + // test loop + assert(batchCount === nBatch) + // test reset + imgRecIter.reset() + imgRecIter.next() + val label0 = imgRecIter.getLabel().head.toArray + val data0 = imgRecIter.getData().head.toArray + imgRecIter.reset() + imgRecIter.reset() + imgRecIter.reset() + imgRecIter.reset() + imgRecIter.reset() + val label1 = imgRecIter.getLabel().head.toArray + val data1 = imgRecIter.getData().head.toArray + assert(label0 === label1) + assert(data0 === data1) + } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala index 4ef62a5f23f5..45f1726642e1 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/train/ConvSuite.scala @@ -70,10 +70,9 @@ class ConvSuite extends FunSuite with BeforeAndAfterAll { valDataIter.reset() val labels = ListBuffer.empty[NDArray] - var evalData = valDataIter.next() - while (evalData != null) { + while (valDataIter.hasNext) { + val evalData = valDataIter.next() labels += evalData.label(0).copy() - evalData = valDataIter.next() } val y = NDArray.concatenate(labels) From 0f50e29750611b68f387bf221c9789dc57bf27cc Mon Sep 17 
00:00:00 2001 From: Yizhi Liu Date: Sun, 6 Mar 2016 17:05:14 +0800 Subject: [PATCH 2/2] Monitor callback --- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 13 ++++- .../imclassification/ModelTrain.scala | 5 +- .../imclassification/TrainMnist.scala | 31 +++++----- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 58 +++++++++---------- 4 files changed, 59 insertions(+), 48 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 3aa573afb4f6..21b8fd3d11c3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -385,6 +385,16 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp // internal helper state var predExec: Executor = null + private var monitor: Option[Monitor] = None + + def setMonitor(m: Monitor): Unit = { + monitor = Option(m) + } + + def unsetMonitor(): Unit = { + setMonitor(null) + } + // Initialize weight parameters and auxiliary states private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) : (Seq[String], Seq[String], Seq[String]) = { @@ -601,7 +611,8 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp evalMetric = evalMetric, epochEndCallback = Option(epochEndCallback), batchEndCallback = Option(batchEndCallback), - logger = logger, workLoadList = workLoadList) + logger = logger, workLoadList = workLoadList, + monitor = monitor) } /** diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala index 2311d1e2a15b..d2605a152b4a 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala @@ -13,7 
+13,7 @@ object ModelTrain { network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, DataIter), kvStore: String, numEpochs: Int, modelPrefix: String = null, loadEpoch: Int = -1, lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f, - clipGradient: Float = 0f): Unit = { + clipGradient: Float = 0f, monitorSize: Int = -1): Unit = { // kvstore // TODO: if local mode and no gpu is used, set kv = null val kv = KVStore.create(kvStore) @@ -71,6 +71,9 @@ object ModelTrain { auxParams = auxParams, beginEpoch = beginEpoch, epochSize = epochSize) + if (monitorSize > 0) { + model.setMonitor(new Monitor(monitorSize)) + } model.fit(trainData = train, evalData = validation, evalMetric = new Accuracy(), diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala index 44facfab5820..88e0ce1c2ecd 100644 --- a/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/TrainMnist.scala @@ -96,7 +96,8 @@ object TrainMnist { network = net, dataLoader = getIterator(dataShape), kvStore = inst.kvStore, numEpochs = inst.numEpochs, modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch, - lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch) + lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch, + monitorSize = inst.monitor) logger.info("Finish fit ...") } catch { case ex: Exception => { @@ -110,30 +111,32 @@ object TrainMnist { class TrainMnist { @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet']") - private var network: String = "mlp" + private val network: String = "mlp" @Option(name = "--data-dir", usage = "the input data directory") - private var dataDir: String = "mnist/" + private val dataDir: String = "mnist/" @Option(name = "--gpus", 
usage = "the gpus will be used, e.g. '0,1,2,3'") - private var gpus: String = _ + private val gpus: String = null @Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'") - private var cpus: String = _ + private val cpus: String = null @Option(name = "--num-examples", usage = "the number of training examples") - private var numExamples: Int = 60000 + private val numExamples: Int = 60000 @Option(name = "--batch-size", usage = "the batch size") - private var batchSize: Int = 128 + private val batchSize: Int = 128 @Option(name = "--lr", usage = "the initial learning rate") - private var lr: Float = 0.1f + private val lr: Float = 0.1f @Option(name = "--model-prefix", usage = "the prefix of the model to load/save") - private var modelPrefix: String = _ + private val modelPrefix: String = null @Option(name = "--num-epochs", usage = "the number of training epochs") - private var numEpochs = 10 + private val numEpochs = 10 @Option(name = "--load-epoch", usage = "load the model on an epoch using the model-prefix") - private var loadEpoch: Int = -1 + private val loadEpoch: Int = -1 @Option(name = "--kv-store", usage = "the kvstore type") - private var kvStore = "local" + private val kvStore = "local" @Option(name = "--lr-factor", usage = "times the lr with a factor for every lr-factor-epoch epoch") - private var lrFactor: Float = 1f + private val lrFactor: Float = 1f @Option(name = "--lr-factor-epoch", usage = "the number of epoch to factor the lr, could be .5") - private var lrFactorEpoch: Float = 1f + private val lrFactorEpoch: Float = 1f + @Option(name = "--monitor", usage = "monitor the training process every N batch") + private val monitor: Int = -1 } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 7fe18f24f444..e3c0173a0d50 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ 
b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -317,8 +317,8 @@ extern "C" void KVStoreUpdaterCallbackFunc env->DeleteLocalRef(ndRecv); env->DeleteLocalRef(ndObjClass); env->DeleteLocalRef(updtClass); - // FIXME: This function can be called multiple times, - // can we find a way to safely destroy these two objects ? + // FIXME(Yizhi): This function can be called multiple times, + // can we find a way to safely destroy this object ? // env->DeleteGlobalRef(updaterFuncObjGlb); } @@ -490,39 +490,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint return ret; } +extern "C" void ExecutorMonitorCallbackFunc + (const char *name, NDArrayHandle arr, void *handle) { + jobject callbackFuncObjGlb = static_cast(handle); + + JNIEnv *env; + _jvm->AttachCurrentThread(reinterpret_cast(&env), NULL); + + // find java callback method + jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb); + jmethodID callbackFunc = env->GetMethodID(callbackClass, "invoke", "(Ljava/lang/String;J)V"); + + // invoke java callback method + jstring jname = env->NewStringUTF(name); + env->CallVoidMethod(callbackFuncObjGlb, callbackFunc, jname, reinterpret_cast(arr)); + env->DeleteLocalRef(jname); + + env->DeleteLocalRef(callbackClass); + // FIXME(Yizhi): This function can be called multiple times, + // can we find a way to safely destroy this global ref ? 
+ // env->DeleteGlobalRef(callbackFuncObjGlb); +} JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) { jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj); - std::function callback - = [env, callbackFuncObjGlb](const char *name, NDArrayHandle array) { - // find java callback method - jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb); - jmethodID invokeFunc = env->GetMethodID(callbackClass, - "invoke", "(Ljava/lang/String;Lml/dmlc/mxnet/Base$RefLong;)V"); - - jstring jname = env->NewStringUTF(name); - // ndArray handle - jclass ndHandleClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID ndHandleCont = env->GetMethodID(ndHandleClass, "", "(J)V"); - jobject jNDArrayHandle - = env->NewObject(ndHandleClass, ndHandleCont, reinterpret_cast(array)); - - env->CallVoidMethod(callbackFuncObjGlb, invokeFunc, jname, jNDArrayHandle); - env->DeleteGlobalRef(callbackFuncObjGlb); - }; - /* TODO: we need to modify Executor::SetMonitorCallback, make it take std::function as param - try { - mxnet::Executor *exec = static_cast((ExecutorHandle)executorPtr); - exec->SetMonitorCallback(callback); - } catch(dmlc::Error &except) { - // It'll be too complicated to set & get mx error in jni code. - // thus simply return -1 to indicate a failure. - // Notice that we'll NOT be able to run MXGetLastError - // to get the error message after this function fails. - return -1; - } - */ - return 0; + return MXExecutorSetMonitorCallback(reinterpret_cast(executorPtr), + ExecutorMonitorCallbackFunc, + reinterpret_cast(callbackFuncObjGlb)); } JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {