Skip to content

Commit

Permalink
Monitor callback
Browse files — browse the repository at this point in the history
  • Loading branch information
yzhliu committed Mar 6, 2016
1 parent 3f40cd4 commit 0f50e29
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 48 deletions.
13 changes: 12 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,16 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
// internal helper state
var predExec: Executor = null

private var monitor: Option[Monitor] = None

/**
 * Install a [[Monitor]] that will be attached to the training executors
 * so internal outputs can be inspected every N batches.
 * Passing `null` is equivalent to installing no monitor.
 *
 * @param m the monitor to attach, or `null` to clear it
 */
def setMonitor(m: Monitor): Unit = {
  // Guard against a null argument coming from Java-style callers.
  monitor = if (m == null) None else Some(m)
}

/**
 * Detach any previously installed [[Monitor]].
 * Training will proceed without monitoring after this call.
 */
def unsetMonitor(): Unit = {
  // Clear the state directly instead of funneling a null through
  // setMonitor(null) — avoids a null sentinel entirely (idiomatic Scala).
  monitor = None
}

// Initialize weight parameters and auxiliary states
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
: (Seq[String], Seq[String], Seq[String]) = {
Expand Down Expand Up @@ -601,7 +611,8 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
evalMetric = evalMetric,
epochEndCallback = Option(epochEndCallback),
batchEndCallback = Option(batchEndCallback),
logger = logger, workLoadList = workLoadList)
logger = logger, workLoadList = workLoadList,
monitor = monitor)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ object ModelTrain {
network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, DataIter),
kvStore: String, numEpochs: Int, modelPrefix: String = null, loadEpoch: Int = -1,
lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
clipGradient: Float = 0f): Unit = {
clipGradient: Float = 0f, monitorSize: Int = -1): Unit = {
// kvstore
// TODO: if local mode and no gpu is used, set kv = null
val kv = KVStore.create(kvStore)
Expand Down Expand Up @@ -71,6 +71,9 @@ object ModelTrain {
auxParams = auxParams,
beginEpoch = beginEpoch,
epochSize = epochSize)
if (monitorSize > 0) {
model.setMonitor(new Monitor(monitorSize))
}
model.fit(trainData = train,
evalData = validation,
evalMetric = new Accuracy(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ object TrainMnist {
network = net, dataLoader = getIterator(dataShape),
kvStore = inst.kvStore, numEpochs = inst.numEpochs,
modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch)
lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch = inst.lrFactorEpoch,
monitorSize = inst.monitor)
logger.info("Finish fit ...")
} catch {
case ex: Exception => {
Expand All @@ -110,30 +111,32 @@ object TrainMnist {

/**
 * Command-line options for the MNIST training example, populated by the
 * args4j-style `@Option` annotation processor via reflection.
 *
 * NOTE(review): the scraped diff showed each field twice (old `var` / new
 * `val`); this keeps exactly one declaration per option, the `val` side.
 * Defaults below are the values used when a flag is not supplied.
 */
class TrainMnist {
  @Option(name = "--network", usage = "the cnn to use: ['mlp', 'lenet']")
  private val network: String = "mlp"
  @Option(name = "--data-dir", usage = "the input data directory")
  private val dataDir: String = "mnist/"
  @Option(name = "--gpus", usage = "the gpus will be used, e.g. '0,1,2,3'")
  private val gpus: String = null
  @Option(name = "--cpus", usage = "the cpus will be used, e.g. '0,1,2,3'")
  private val cpus: String = null
  @Option(name = "--num-examples", usage = "the number of training examples")
  private val numExamples: Int = 60000
  @Option(name = "--batch-size", usage = "the batch size")
  private val batchSize: Int = 128
  @Option(name = "--lr", usage = "the initial learning rate")
  private val lr: Float = 0.1f
  @Option(name = "--model-prefix", usage = "the prefix of the model to load/save")
  private val modelPrefix: String = null
  @Option(name = "--num-epochs", usage = "the number of training epochs")
  private val numEpochs = 10
  @Option(name = "--load-epoch", usage = "load the model on an epoch using the model-prefix")
  private val loadEpoch: Int = -1
  @Option(name = "--kv-store", usage = "the kvstore type")
  private val kvStore = "local"
  @Option(name = "--lr-factor",
          usage = "times the lr with a factor for every lr-factor-epoch epoch")
  private val lrFactor: Float = 1f
  @Option(name = "--lr-factor-epoch", usage = "the number of epoch to factor the lr, could be .5")
  private val lrFactorEpoch: Float = 1f
  // -1 (the default) disables monitoring; see ModelTrain's monitorSize handling.
  @Option(name = "--monitor", usage = "monitor the training process every N batch")
  private val monitor: Int = -1
}
58 changes: 26 additions & 32 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ extern "C" void KVStoreUpdaterCallbackFunc
env->DeleteLocalRef(ndRecv);
env->DeleteLocalRef(ndObjClass);
env->DeleteLocalRef(updtClass);
// FIXME: This function can be called multiple times,
// can we find a way to safely destroy these two objects ?
// FIXME(Yizhi): This function can be called multiple times,
// can we find a way to safely destroy this object ?
// env->DeleteGlobalRef(updaterFuncObjGlb);
}

Expand Down Expand Up @@ -490,39 +490,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint
return ret;
}

// Native-side monitor callback handed to the MXNet engine: forwards each
// monitored (name, NDArray handle) pair to a Java callback object.
// `handle` carries the global reference to that Java object, created at
// registration time (see mxExecutorSetMonitorCallback below).
extern "C" void ExecutorMonitorCallbackFunc
(const char *name, NDArrayHandle arr, void *handle) {
// Recover the Java callback object from the opaque payload.
jobject callbackFuncObjGlb = static_cast<jobject>(handle);

// The engine may invoke this from a non-JVM thread, so obtain a JNIEnv
// for the current thread first.
// NOTE(review): the AttachCurrentThread return code is not checked, and the
// thread is never detached — confirm whether that matters for this engine.
JNIEnv *env;
_jvm->AttachCurrentThread(reinterpret_cast<void **>(&env), NULL);

// find java callback method
// Signature "(Ljava/lang/String;J)V": takes the output name and the raw
// NDArray handle as a jlong.
jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb);
jmethodID callbackFunc = env->GetMethodID(callbackClass, "invoke", "(Ljava/lang/String;J)V");

// invoke java callback method
jstring jname = env->NewStringUTF(name);
env->CallVoidMethod(callbackFuncObjGlb, callbackFunc, jname, reinterpret_cast<jlong>(arr));
// Release the local string ref eagerly; attached threads do not pop JNI
// local frames automatically.
env->DeleteLocalRef(jname);

env->DeleteLocalRef(callbackClass);
// FIXME(Yizhi): This function can be called multiple times,
// can we find a way to safely destroy this global ref ?
// env->DeleteGlobalRef(callbackFuncObjGlb);
}
// Registers a Java monitor callback on an executor.
// Returns whatever MXExecutorSetMonitorCallback returns (0 on success,
// non-zero on failure, per the MXNet C API convention used elsewhere here).
//
// The scraped span mixed the superseded lambda implementation (which captured
// the per-call `env` — invalid on other threads — and destroyed the global
// ref after the first invocation) with the working registration call; only
// the working path is kept.
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback
  (JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) {
  // Promote the callback to a global ref so it outlives this JNI call;
  // ownership passes to the native layer through the void* payload.
  // NOTE(review): this ref is never released (see the FIXME in
  // ExecutorMonitorCallbackFunc), so each registration leaks one reference.
  jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
  return MXExecutorSetMonitorCallback(reinterpret_cast<ExecutorHandle>(executorPtr),
                                      ExecutorMonitorCallbackFunc,
                                      reinterpret_cast<void *>(callbackFuncObjGlb));
}

JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
Expand Down

0 comments on commit 0f50e29

Please sign in to comment.