Skip to content

Commit

Permalink
Merge pull request apache#28 from yanqingmen/scala
Browse files Browse the repository at this point in the history
small change for DataIter's Iterator implementation and add DataPack
  • Loading branch information
yzhliu committed Feb 6, 2016
2 parents 0b597c8 + b8c3ec1 commit 77e79c9
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 251 deletions.
262 changes: 26 additions & 236 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.io.{MXDataPack, MXDataIter}
import org.slf4j.LoggerFactory

import scala.collection.mutable.ListBuffer
Expand All @@ -11,6 +12,7 @@ import scala.collection.mutable.ListBuffer
*/
object IO {
// Factory signature: builds a DataIter from a string-to-string parameter map.
type IterCreateFunc = (Map[String, String]) => DataIter
// Factory signature: builds a DataPack from a string-to-string parameter map.
type PackCreateFunc = (Map[String, String]) => DataPack

// NOTE(review): the logger is registered under DataIter's class rather than IO's - confirm intended.
private val logger = LoggerFactory.getLogger(classOf[DataIter])
// Iterator factories keyed by iterator name; filled once by _initIOModule() when IO is initialized.
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()
Expand All @@ -19,6 +21,11 @@ object IO {
// Iterator factories, looked up by iterator name in iterCreateFuncs.
// Map.apply throws NoSuchElementException when the name is absent.
def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter")
def CSVIter: IterCreateFunc = iterCreateFuncs("CSVIter")

// DataPack factories: each one is createMXDataPack with the iterator name already applied,
// leaving only the parameter map to be supplied by the caller.
def MNISTPack: PackCreateFunc = createMXDataPack("MNISTIter")
def ImageRecordPack: PackCreateFunc = createMXDataPack("ImageRecordIter")
// Misspelled name retained so existing callers keep compiling; prefer ImageRecordPack.
def ImageRecodePack: PackCreateFunc = ImageRecordPack
def CSVPack: PackCreateFunc = createMXDataPack("CSVIter")


/**
* create iterator via iterName and params
* @param iterName name of iterator, e.g. "MNISTIter", "ImageRecordIter" or "CSVIter"
Expand All @@ -29,6 +36,16 @@ object IO {
iterCreateFuncs(iterName)(params)
}

/**
 * Build a DataPack backed by the named iterator.
 * The two parameter lists let a caller fix the iterator name first and supply params later.
 * @param iterName name of the iterator, e.g. "MNISTIter", "ImageRecordIter" or "CSVIter"
 * @param params parameters handed to the MXDataPack constructor
 * @return a new MXDataPack constructed from iterName and params
 */
def createMXDataPack(iterName: String)(params: Map[String, String]): DataPack =
  new MXDataPack(iterName, params)

/**
* Initialize all IO creator functions
* @return
Expand All @@ -53,7 +70,7 @@ object IO {
}

/**
*
* DataIter creator
* @param handle
* @param params
* @return
Expand Down Expand Up @@ -89,21 +106,16 @@ case class DataBatch(data: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int)


/**
* DataIter object in mxnet.
*/
abstract class DataIter(val batchSize: Int = 0) {
abstract class DataIter(val batchSize: Int = 0) extends Iterator[DataBatch] {
/**
* reset the iterator
*/
def reset(): Unit

/**
* Iterate to next batch
* @return whether the move is successful
*/
def iterNext(): Boolean

/**
* get next data batch from iterator
* @return
Expand Down Expand Up @@ -145,236 +157,14 @@ abstract class DataIter(val batchSize: Int = 0) {
}

/**
* DataIter built in MXNet.
* @param handle the handle to the underlying C++ Data Iterator
* A pack of DataIter, usable as an Iterable of DataBatch
*/
// scalastyle:off finalize
class MXDataIter(private[mxnet] val handle: DataIterHandle,
private val dataName: String = "data",
private val labelName: String = "label") extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

// load the first batch to get shape information
private var firstBatch: DataBatch = next()
private val data = firstBatch.data(0)
private val label = firstBatch.label(0)

// properties
val _provideData: Map[String, Shape] = Map(dataName -> data.shape)
val _provideLabel: Map[String, Shape] = Map(labelName -> label.shape)
override val batchSize = data.shape(0)

// Finalizer hook: hands the native iterator handle to mxDataIterFree.
// checkCall wraps the native call's result; presumably it raises on a failing status - confirm.
override def finalize(): Unit = {
checkCall(_LIB.mxDataIterFree(handle))
}

abstract class DataPack() extends Iterable[DataBatch] {
/**
* reset the iterator
*/
override def reset(): Unit = {
// TODO: self._debug_at_begin = True
// Drop the cached first batch so next()/iterNext() go back to the native iterator.
firstBatch = null
// Presumably rewinds the native iterator to before its first batch - confirm against the C API.
checkCall(_LIB.mxDataIterBeforeFirst(handle))
}

/**
 * Return the next batch.
 * Hands out the cached first batch if one is pending; otherwise advances the native
 * iterator and assembles a DataBatch from the current data, label, index and pad.
 * @return the next DataBatch, or null when the native iterator reports no further batch
 */
override def next(): DataBatch = {
  // TODO
  // if self._debug_skip_load and not self._debug_at_begin:
  // return DataBatch(data =[self.getdata()], label =[self.getlabel()],
  // pad = self.getpad(), index = self.getindex())
  if (firstBatch == null) {
    // self._debug_at_begin = False
    val status = new RefInt
    checkCall(_LIB.mxDataIterNext(handle, status))
    if (status.value > 0) {
      new DataBatch(data = getData(), label = getLabel(), index = getIndex(), pad = getPad())
    } else {
      // TODO raise StopIteration
      null
    }
  } else {
    // Hand the cached batch out exactly once.
    val cached = firstBatch
    firstBatch = null
    cached
  }
}

/**
 * Move to the next batch.
 * While the cached first batch is still pending this reports success without touching
 * the native iterator; otherwise it advances the native iterator.
 * @return whether the move is successful
 */
override def iterNext(): Boolean = {
  // FIXME: this implementation is confusing,
  // if we call iterNext() continuously from the very beginning,
  // it always returns true but never moves forward
  firstBatch != null || {
    val status = new RefInt
    checkCall(_LIB.mxDataIterNext(handle, status))
    status.value > 0
  }
}

/**
 * Fetch the data of the current batch from the native iterator.
 * @return a one-element sequence holding the data as a non-writable NDArray
 */
override def getData(): IndexedSeq[NDArray] = {
  val dataHandle = new NDArrayHandleRef
  checkCall(_LIB.mxDataIterGetData(handle, dataHandle))
  val current = new NDArray(dataHandle.value, writable = false)
  IndexedSeq(current)
}

/**
 * Fetch the label of the current batch from the native iterator.
 * @return a one-element sequence holding the label as a non-writable NDArray
 */
override def getLabel(): IndexedSeq[NDArray] = {
  val labelHandle = new NDArrayHandleRef
  checkCall(_LIB.mxDataIterGetLabel(handle, labelHandle))
  val current = new NDArray(labelHandle.value, writable = false)
  IndexedSeq(current)
}

/**
 * Fetch the index of the current batch from the native iterator.
 * @return the indices collected by the native call, as an IndexedSeq
 */
override def getIndex(): IndexedSeq[Long] = {
  val indices = new ListBuffer[Long]
  // The size reference is passed to the native call but is not read back here.
  val count = new RefLong
  checkCall(_LIB.mxDataIterGetIndex(handle, indices, count))
  indices.toIndexedSeq
}

/**
 * Fetch the number of padding examples in the current batch from the native iterator.
 * @return number of padding examples in current batch
 */
override def getPad(): MXUint = {
  val padCount = new MXUintRef
  checkCall(_LIB.mxDataIterGetPadNum(handle, padCount))
  padCount.value
}

// The name and shape of data provided by this iterator.
// Returns the _provideData map built from dataName and the first batch's data shape.
override def provideData: Map[String, Shape] = _provideData

// The name and shape of label provided by this iterator.
// Returns the _provideLabel map built from labelName and the first batch's label shape.
override def provideLabel: Map[String, Shape] = _provideLabel
* get data iterator
* @return DataIter
*/
def iterator: DataIter
}
// scalastyle:on finalize

/**
 * Base class for prefetching iterators: takes one or more DataIters
 * (or any class with "reset" and "read" methods) and combines them with prefetching.
 * Every member is still a `???` placeholder, so calling any of them
 * throws scala.NotImplementedError.
 * @param iters list of DataIters
 * @param dataNames
 * @param labelNames
 */
class PrefetchingIter(val iters: List[DataIter],
                      val dataNames: Map[String, String] = null,
                      val labelNames: Map[String, String] = null) extends DataIter {
  // ---- iteration control ----

  /** Reset the iterator. */
  override def reset(): Unit = ???

  /**
   * Iterate to next batch.
   * @return whether the move is successful
   */
  override def iterNext(): Boolean = ???

  // ---- accessors for the current batch ----

  /** @return the data of current batch */
  override def getData(): IndexedSeq[NDArray] = ???

  /** @return the label of current batch */
  override def getLabel(): IndexedSeq[NDArray] = ???

  /** @return the index of current batch */
  override def getIndex(): IndexedSeq[Long] = ???

  /** @return number of padding examples in current batch */
  override def getPad(): Int = ???

  // ---- shape descriptions ----

  /** The name and shape of data provided by this iterator. */
  override def provideData: Map[String, Shape] = ???

  /** The name and shape of label provided by this iterator. */
  override def provideLabel: Map[String, Shape] = ???
}

/**
 * TODO
 * NDArrayIter object in mxnet. Taking NDArray or numpy array to get dataiter.
 * Every member is still a `???` placeholder, so calling any of them
 * throws scala.NotImplementedError.
 * @param data NDArrayIter supports single or multiple data and label.
 * @param label Same as data, but is not fed to the model during testing.
 * @param batchSize Batch Size
 * @param shuffle Whether to shuffle the data
 * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the last batch
 * @note
 * This iterator will pad, discard or roll over the last batch if
 * the size of data does not match batch_size. Roll over is intended
 * for training and can cause problems if used for prediction.
 */
class NDArrayIter(data: NDArray, label: NDArray = null,
                  batchSize: Int = 1, shuffle: Boolean = false,
                  lastBatchHandle: String = "pad") extends DataIter(batchSize) {
  // ---- iteration control ----

  /** Reset the iterator. */
  override def reset(): Unit = ???

  /**
   * Iterate to next batch.
   * @return whether the move is successful
   */
  override def iterNext(): Boolean = ???

  // ---- accessors for the current batch ----

  /** @return the data of current batch */
  override def getData(): IndexedSeq[NDArray] = ???

  /** @return the label of current batch */
  override def getLabel(): IndexedSeq[NDArray] = ???

  /** @return the index of current batch */
  override def getIndex(): IndexedSeq[Long] = ???

  /** @return number of padding examples in current batch */
  override def getPad(): MXUint = ???

  // ---- shape descriptions ----

  /** The name and shape of data provided by this iterator. */
  override def provideData: Map[String, Shape] = ???

  /** The name and shape of label provided by this iterator. */
  override def provideLabel: Map[String, Shape] = ???
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base.Shape
import ml.dmlc.mxnet.io.NDArrayIter
import ml.dmlc.mxnet.optimizer.SGD
import org.slf4j.{Logger, LoggerFactory}

Expand Down
Loading

0 comments on commit 77e79c9

Please sign in to comment.