
Merge pull request apache#38 from javelinjs/scala-package-l
DataIter bug fix & monitor callback setup
yanqingmen committed Mar 6, 2016
2 parents 9e5137f + 0f50e29 commit 15864ff
Showing 12 changed files with 153 additions and 141 deletions.
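The change common to most files below: `DataIter` now honors the standard Scala `Iterator[DataBatch]` contract, so callers loop with `hasNext`/`next()` instead of polling `next()` for `null`. A minimal before/after sketch (the `drainOld`/`drainNew` helpers are illustrative, not from the diff):

```scala
import ml.dmlc.mxnet.{DataBatch, DataIter}

// Before this commit: next() returned null once the epoch was exhausted.
def drainOld(iter: DataIter): Unit = {
  var batch: DataBatch = iter.next()
  while (batch != null) {
    // ... use batch ...
    batch = iter.next()
  }
}

// After this commit: next() past the end throws NoSuchElementException,
// so iteration is guarded by hasNext, as with any Scala Iterator.
def drainNew(iter: DataIter): Unit = {
  while (iter.hasNext) {
    val batch = iter.next()
    // ... use batch ...
  }
}
```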
6 changes: 2 additions & 4 deletions scala-package/README.md
@@ -45,7 +45,6 @@ Here is a Scala example of how training a simple 3-layer MLP on MNIST looks like
 ```scala
 import ml.dmlc.mxnet._
 import ml.dmlc.mxnet.optimizer.SGD
-import org.slf4j.LoggerFactory
 
 // model definition
 val data = Symbol.Variable("data")
@@ -95,10 +94,9 @@ val prob = probArrays(0)
 import scala.collection.mutable.ListBuffer
 valDataIter.reset()
 val labels = ListBuffer.empty[NDArray]
-var evalData = valDataIter.next()
-while (evalData != null) {
+while (valDataIter.hasNext) {
+  val evalData = valDataIter.next()
   labels += evalData.label(0).copy()
-  evalData = valDataIter.next()
 }
 val y = NDArray.concatenate(labels)
 
5 changes: 4 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -118,16 +118,19 @@ case class DataBatch(data: IndexedSeq[NDArray],
 /**
  * DataIter object in mxnet.
  */
-abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] {
+abstract class DataIter extends Iterator[DataBatch] {
   /**
    * reset the iterator
    */
   def reset(): Unit
 
+  def batchSize: Int
+
   /**
    * get next data batch from iterator
    * @return
    */
+  @throws(classOf[NoSuchElementException])
   def next(): DataBatch = {
     new DataBatch(getData(), getLabel(), getIndex(), getPad())
   }
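Since `batchSize` is now an abstract method rather than a constructor field, every concrete iterator supplies it explicitly. A sketch of a custom iterator under the new contract (the class is hypothetical; signatures of the stubbed members are inferred from `DataBatch` and the overrides visible elsewhere in this commit):

```scala
import ml.dmlc.mxnet.{DataBatch, DataIter, NDArray, Shape}

// Hypothetical iterator replaying a fixed sequence of batches.
class ListDataIter(batches: IndexedSeq[DataBatch], size: Int) extends DataIter {
  private var cursor = 0

  // Previously a constructor argument, DataIter(batchSize); now an override.
  override def batchSize: Int = size

  override def reset(): Unit = { cursor = 0 }
  override def hasNext: Boolean = cursor < batches.length

  @throws(classOf[NoSuchElementException])
  override def next(): DataBatch = {
    if (!hasNext) throw new NoSuchElementException
    val batch = batches(cursor)
    cursor += 1
    batch
  }

  // Remaining members of the contract, stubbed for brevity.
  override def getData(): IndexedSeq[NDArray] = ???
  override def getLabel(): IndexedSeq[NDArray] = ???
  override def getIndex(): IndexedSeq[Long] = ???
  override def getPad(): Int = 0
  override def provideData: Map[String, Shape] = ???
  override def provideLabel: Map[String, Shape] = ???
}
```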
37 changes: 24 additions & 13 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -242,9 +242,8 @@ object Model {
       trainData.reset()
       while (!epochDone) {
         var doReset = true
-        // TODO: make DataIter implement Iterator
-        var dataBatch = trainData.next()
-        while (doReset && dataBatch != null) {
+        while (doReset && trainData.hasNext) {
+          val dataBatch = trainData.next()
           executorManager.loadDataBatch(dataBatch)
           monitor.foreach(_.tic())
           executorManager.forward(isTrain = true)
@@ -271,7 +270,6 @@
             doReset = false
           }
           dataBatch.dispose()
-          dataBatch = trainData.next()
         }
         if (doReset) {
           trainData.reset()
@@ -290,13 +288,12 @@
       evalMetric.reset()
       evalDataIter.reset()
       // TODO: make DataIter implement Iterator
-      var evalBatch = evalDataIter.next()
-      while (evalBatch != null) {
+      while (evalDataIter.hasNext) {
+        val evalBatch = evalDataIter.next()
         executorManager.loadDataBatch(evalBatch)
         executorManager.forward(isTrain = false)
         evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays)
         evalBatch.dispose()
-        evalBatch = evalDataIter.next()
       }
 
       val (name, value) = evalMetric.get
@@ -388,6 +385,16 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
   // internal helper state
   var predExec: Executor = null
 
+  private var monitor: Option[Monitor] = None
+
+  def setMonitor(m: Monitor): Unit = {
+    monitor = Option(m)
+  }
+
+  def unsetMonitor(): Unit = {
+    setMonitor(null)
+  }
+
   // Initialize weight parameters and auxiliary states
   private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
   : (Seq[String], Seq[String], Seq[String]) = {
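The new setter pair lets a caller attach a `Monitor` before fitting and detach it afterwards. A hedged usage sketch (model construction elided; `Monitor`'s single-`Int` constructor matches the `new Monitor(monitorSize)` call in ModelTrain.scala below, and the interval interpretation is an assumption):

```scala
import ml.dmlc.mxnet.{FeedForward, Monitor}

def fitWithMonitor(model: FeedForward): Unit = {
  model.setMonitor(new Monitor(100)) // assumed: inspect tensors every 100 batches
  // ... model.fit(...) now threads the monitor into trainMultiDevice ...
  model.unsetMonitor()               // Option(null) == None, so monitoring is off again
}
```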
@@ -477,8 +484,8 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
     val outputs = Array.fill(predExec.outputs.length)(ListBuffer.empty[NDArray])
 
     var i = 0
-    var batch = data.next()
-    while (batch != null && i != numBatch) {
+    while (data.hasNext && i != numBatch) {
+      val batch = data.next()
       i += 1
       Executor.loadData(batch, dataArrays)
       predExec.forward(isTrain = false)
@@ -487,10 +494,13 @@
       for ((list, nd) <- outputs zip predExec.outputs) {
         list += nd.slice(0, realSize).copy()
       }
-      batch = data.next()
     }
-    // TODO: we can use Symbol.concat to do the same thing. Can it be more efficient?
-    outputs.map(NDArray.concatenate(_))
+    // TODO(Yizhi): we can use Symbol.concat to do the same thing. Can it be more efficient?
+    val results = outputs.map(NDArray.concatenate(_))
+    for (output <- outputs) {
+      output.foreach(_.dispose())
+    }
+    results
   }
 
   /**
@@ -601,7 +611,8 @@
       evalMetric = evalMetric,
       epochEndCallback = Option(epochEndCallback),
       batchEndCallback = Option(batchEndCallback),
-      logger = logger, workLoadList = workLoadList)
+      logger = logger, workLoadList = workLoadList,
+      monitor = monitor)
   }
 
   /**
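One subtlety in the new `predict` tail: `NDArray.concatenate` copies its inputs into a fresh array, so the per-batch slices buffered in `outputs` can be disposed immediately; `NDArray`s wrap native memory, so leaving them to the JVM garbage collector would leak it. The same copy-then-dispose discipline, distilled into a standalone sketch built only from calls that appear in this commit (`collectLabels` itself is illustrative):

```scala
import scala.collection.mutable.ListBuffer
import ml.dmlc.mxnet.{DataIter, NDArray}

def collectLabels(iter: DataIter): NDArray = {
  val parts = ListBuffer.empty[NDArray]
  while (iter.hasNext) {
    val batch = iter.next()
    parts += batch.label(0).copy() // copy: the iterator may reuse its buffers
    batch.dispose()
  }
  val all = NDArray.concatenate(parts)
  parts.foreach(_.dispose())       // safe: concatenate copied the data
  all
}
```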
36 changes: 19 additions & 17 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/io/MXDataIter.scala
@@ -21,17 +21,22 @@ class MXDataIter private[mxnet](private[mxnet] val handle: DataIterHandle,
   // (may be this is not the best way to do this work,
   // fix me if any better way found)
   private var currentBatch: DataBatch = null
-  iterNext()
-  private val data = currentBatch.data(0)
-  private val label = currentBatch.label(0)
-  reset()
-
-  // properties
-  val _provideData: Map[String, Shape] = Map(dataName -> data.shape)
-  val _provideLabel: Map[String, Shape] = Map(labelName -> label.shape)
-  override val batchSize = data.shape(0)
-  private var disposed = false
+
+  private val (_provideData: Map[String, Shape],
+               _provideLabel: Map[String, Shape],
+               _batchSize: Int) =
+    if (hasNext) {
+      iterNext()
+      val data = currentBatch.data(0)
+      val label = currentBatch.label(0)
+      reset()
+      // properties
+      (Map(dataName -> data.shape), Map(labelName -> label.shape), data.shape(0))
+    } else {
+      (null, null, 0)
+    }
+
+  private var disposed = false
 
   override protected def finalize(): Unit = {
     dispose()
   }
@@ -51,16 +56,12 @@
   * reset the iterator
   */
  override def reset(): Unit = {
-    // TODO: self._debug_at_begin = True
    currentBatch = null
    checkCall(_LIB.mxDataIterBeforeFirst(handle))
  }
 
+  @throws(classOf[NoSuchElementException])
  override def next(): DataBatch = {
-    // TODO
-    // if self._debug_skip_load and not self._debug_at_begin:
-    //   return DataBatch(data =[self.getdata()], label =[self.getlabel()],
-    //     pad = self.getpad(), index = self.getindex())
    if (currentBatch == null) {
      iterNext()
    }
@@ -70,8 +71,7 @@
      currentBatch = null
      batch
    } else {
-      // TODO raise StopIteration
-      null
+      throw new NoSuchElementException
    }
  }
 
@@ -145,6 +145,8 @@
      iterNext()
    }
  }
+
+  override def batchSize: Int = _batchSize
 }
 
 // scalastyle:on finalize
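Because the constructor now probes one batch and then calls `reset()`, shape metadata is available before iteration starts, and constructing an `MXDataIter` over an empty source no longer dereferences a null `currentBatch`. A hedged sketch (the `IO.MNISTIter` factory and its parameter keys follow the MNIST iterator used in IOSuite.scala; the paths are placeholders):

```scala
import ml.dmlc.mxnet.IO

val mnistIter = IO.MNISTIter(Map(
  "image" -> "data/train-images-idx3-ubyte",  // placeholder path
  "label" -> "data/train-labels-idx1-ubyte",  // placeholder path
  "batch_size" -> "100"
))
// Usable before the first next(): populated by the constructor's probe-and-reset.
println(mnistIter.provideData) // data-name -> Shape map from the probed batch
println(mnistIter.batchSize)   // 100, the probed batch's leading dimension
```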
scala-package/core/src/main/scala/ml/dmlc/mxnet/io/NDArrayIter.scala
@@ -8,7 +8,7 @@ import ml.dmlc.mxnet.{DataIter, NDArray, Shape}
  * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter.
  * @param data NDArrayIter supports single or multiple data and label.
  * @param label Same as data, but is not fed to the model during testing.
- * @param batchSize Batch Size
+ * @param dataBatchSize Batch Size
  * @param shuffle Whether to shuffle the data
  * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
  * @note
@@ -17,8 +17,8 @@
  * for training and can cause problems if used for prediction.
  */
 class NDArrayIter(data: NDArray, label: NDArray = null,
-                  batchSize: Int = 1, shuffle: Boolean = false,
-                  lastBatchHandle: String = "pad") extends DataIter(batchSize) {
+                  private val dataBatchSize: Int = 1, shuffle: Boolean = false,
+                  lastBatchHandle: String = "pad") extends DataIter {
   /**
    * reset the iterator
    */
@@ -56,4 +56,6 @@
   override def provideLabel: Map[String, Shape] = ???
 
   override def hasNext: Boolean = ???
+
+  override def batchSize: Int = dataBatchSize
 }
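Note that most members here remain `???` (scala.Predef's unimplemented marker, which throws `scala.NotImplementedError` when invoked); this commit only wires up `batchSize`. In effect (a sketch; `arr` stands in for any NDArray):

```scala
import ml.dmlc.mxnet.NDArray
import ml.dmlc.mxnet.io.NDArrayIter

def demo(arr: NDArray): Unit = {
  val iter = new NDArrayIter(data = arr)
  println(iter.batchSize) // 1, the default dataBatchSize
  iter.hasNext            // throws scala.NotImplementedError: the body is ???
}
```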
scala-package/core/src/main/scala/ml/dmlc/mxnet/io/PrefetchingIter.scala
@@ -51,4 +51,6 @@ class PrefetchingIter(val iters: List[DataIter],
   override def provideData: Map[String, Shape] = ???
 
   override def hasNext: Boolean = ???
+
+  override def batchSize: Int = ???
 }
@@ -80,7 +80,7 @@ object NativeLibraryLoader {
    */
   private def unifyOSName(osname: String): String = {
     if (osname.startsWith("Windows")) {
-      return "Windows"
+      "Windows"
     }
     osname
   }
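Worth flagging: unlike the mechanical style cleanups elsewhere, this edit changes behavior. Without `return`, the `if` expression's value is discarded and the method always returns `osname`, so `unifyOSName("Windows 10")` now yields `"Windows 10"` instead of `"Windows"`. A single-expression form that keeps the old semantics and still avoids `return` would be (a suggested alternative, not part of this commit):

```scala
private def unifyOSName(osname: String): String =
  if (osname.startsWith("Windows")) "Windows" else osname
```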
99 changes: 47 additions & 52 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
@@ -39,10 +39,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     // test_loop
     mnistIter.reset()
     batchCount = 0
-    var batch = mnistIter.next()
-    while (batch != null) {
+    while (mnistIter.hasNext) {
+      mnistIter.next()
       batchCount += 1
-      batch = mnistIter.next()
     }
     // test loop
     assert(nBatch === batchCount)
@@ -66,54 +65,50 @@
   /**
    * default skip this test for saving time
    */
-  // test("test ImageRecordIter") {
-  //   // get data
-  //   "./scripts/get_cifar_data.sh" !
-  //
-  //   val params = Map(
-  //     "path_imgrec" -> "data/cifar/train.rec",
-  //     "mean_img" -> "data/cifar/cifar10_mean.bin",
-  //     "rand_crop" -> "False",
-  //     "and_mirror" -> "False",
-  //     "shuffle" -> "False",
-  //     "data_shape" -> "(3,28,28)",
-  //     "batch_size" -> "100",
-  //     "preprocess_threads" -> "4",
-  //     "prefetch_buffer" -> "1"
-  //   )
-  //   val imgRecIter = IO.ImageRecordIter(params)
-  //   val nBatch = 500
-  //   var batchCount = 0
-  //   // test provideData
-  //   val provideData = imgRecIter.provideData
-  //   val provideLabel = imgRecIter.provideLabel
-  //   assert(provideData("data") === Array(100, 3, 28, 28))
-  //   assert(provideLabel("label") === Array(100))
-  //
-  //   imgRecIter.reset()
-  //   while(imgRecIter.hasNext()) {
-  //     val batch = imgRecIter.next()
-  //     batchCount += 1
-  //   }
-  //   // test loop
-  //   assert(batchCount === nBatch)
-  //   // test reset
-  //   imgRecIter.reset()
-  //   imgRecIter.next()
-  //   val label0 = imgRecIter.getLabel().head.toArray
-  //   val data0 = imgRecIter.getData().head.toArray
-  //   imgRecIter.reset()
-  //   imgRecIter.reset()
-  //   imgRecIter.reset()
-  //   imgRecIter.reset()
-  //   imgRecIter.reset()
-  //   val label1 = imgRecIter.getLabel().head.toArray
-  //   val data1 = imgRecIter.getData().head.toArray
-  //   assert(label0 === label1)
-  //   assert(data0 === data1)
-  // }
+  test("test ImageRecordIter") {
+    // get data
+    "./scripts/get_cifar_data.sh" !
+
+    val params = Map(
+      "path_imgrec" -> "data/cifar/train.rec",
+      "mean_img" -> "data/cifar/cifar10_mean.bin",
+      "rand_crop" -> "False",
+      "and_mirror" -> "False",
+      "shuffle" -> "False",
+      "data_shape" -> "(3,28,28)",
+      "batch_size" -> "100",
+      "preprocess_threads" -> "4",
+      "prefetch_buffer" -> "1"
+    )
+    val imgRecIter = IO.ImageRecordIter(params)
+    val nBatch = 500
+    var batchCount = 0
+    // test provideData
+    val provideData = imgRecIter.provideData
+    val provideLabel = imgRecIter.provideLabel
+    assert(provideData("data").toArray === Array(100, 3, 28, 28))
+    assert(provideLabel("label").toArray === Array(100))
+
+    imgRecIter.reset()
+    while (imgRecIter.hasNext) {
+      imgRecIter.next()
+      batchCount += 1
+    }
+    // test loop
+    assert(batchCount === nBatch)
+    // test reset
+    imgRecIter.reset()
+    imgRecIter.next()
+    val label0 = imgRecIter.getLabel().head.toArray
+    val data0 = imgRecIter.getData().head.toArray
+    imgRecIter.reset()
+    imgRecIter.reset()
+    imgRecIter.reset()
+    imgRecIter.reset()
+    imgRecIter.reset()
+    val label1 = imgRecIter.getLabel().head.toArray
+    val data1 = imgRecIter.getData().head.toArray
+    assert(label0 === label1)
+    assert(data0 === data1)
+  }
 
-  // test("test NDarryIter") {
-  //
-  // }
 }
@@ -70,10 +70,9 @@ class ConvSuite extends FunSuite with BeforeAndAfterAll {
 
     valDataIter.reset()
     val labels = ListBuffer.empty[NDArray]
-    var evalData = valDataIter.next()
-    while (evalData != null) {
+    while (valDataIter.hasNext) {
+      val evalData = valDataIter.next()
       labels += evalData.label(0).copy()
-      evalData = valDataIter.next()
     }
     val y = NDArray.concatenate(labels)
 
scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/imclassification/ModelTrain.scala
@@ -13,7 +13,7 @@ object ModelTrain {
            network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, DataIter),
            kvStore: String, numEpochs: Int, modelPrefix: String = null, loadEpoch: Int = -1,
            lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
-           clipGradient: Float = 0f): Unit = {
+           clipGradient: Float = 0f, monitorSize: Int = -1): Unit = {
     // kvstore
     // TODO: if local mode and no gpu is used, set kv = null
     val kv = KVStore.create(kvStore)
@@ -71,6 +71,9 @@
       auxParams = auxParams,
       beginEpoch = beginEpoch,
       epochSize = epochSize)
+    if (monitorSize > 0) {
+      model.setMonitor(new Monitor(monitorSize))
+    }
     model.fit(trainData = train,
       evalData = validation,
       evalMetric = new Accuracy(),
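End to end, the monitor path introduced by this commit runs: `ModelTrain.fit(..., monitorSize)` installs `new Monitor(monitorSize)` on the `FeedForward` model, `FeedForward.fit` forwards it as `monitor = monitor` to `trainMultiDevice`, and the training loop calls `monitor.foreach(_.tic())` around each forward pass. The install guard, in isolation (a sketch; `maybeMonitor` is illustrative):

```scala
import ml.dmlc.mxnet.{FeedForward, Monitor}

// monitorSize <= 0 (the -1 default) leaves monitoring disabled.
def maybeMonitor(model: FeedForward, monitorSize: Int): Unit = {
  if (monitorSize > 0) {
    model.setMonitor(new Monitor(monitorSize))
  }
}
```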