Skip to content

Commit

Permalink
Merge pull request apache#22 from javelinjs/scala-package-l
Browse files Browse the repository at this point in the history
NDArray finished
  • Loading branch information
terrytangyuan committed Jan 14, 2016
2 parents 968785b + 07f726d commit 0013f85
Show file tree
Hide file tree
Showing 11 changed files with 612 additions and 124 deletions.
9 changes: 4 additions & 5 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ml.dmlc.mxnet

object Context {
val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)
val defaultCtx = new Context("cpu", 0)
}

Expand All @@ -10,10 +12,7 @@ object Context {
* @param deviceId (default=0) The device id of the device, needed for GPU
*/
class Context(deviceTypeName: String, val deviceId: Int = 0) {
private val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
private val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)

val deviceTypeid: Int = devstr2type(deviceTypeName)
val deviceTypeid: Int = Context.devstr2type(deviceTypeName)

def this(context: Context) = {
this(context.deviceType, context.deviceId)
Expand All @@ -23,5 +22,5 @@ class Context(deviceTypeName: String, val deviceId: Int = 0) {
* Return device type of current context.
* @return device_type
*/
def deviceType: String = devtype2str(deviceTypeid)
def deviceType: String = Context.devtype2str(deviceTypeid)
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.NDArray.array


/**
*
* Base class for Initializer.
Expand Down Expand Up @@ -46,13 +43,13 @@ abstract class Initializer {
val f = shape(3) / 2.0f
val c = (2 * f - 1 - f % 2) / (2.0f * f)

(0 to (arr.size)).foreach { i =>
(0 to arr.size).foreach { i =>
val x = i % shape(3)
val y = (i / shape(3)) % shape(2)
weight(i) = (1 - math.abs(x / f - c)) * (1 - math.abs(y / f - c))
}

arr.set(array(weight))
arr.set(NDArray.array(weight, shape))
}

def _initZero(name: String, arr: NDArray): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class KVStore(private val handle: KVStoreHandle) {
*/
def setUpdater(updater: MXKVStoreUpdater): Unit = {
this.updaterFunc = updater
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null))
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc))
}

/**
Expand Down
16 changes: 13 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* @author Yizhi Liu
*/
class LibInfo {
// NDArray
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxGetLastError(): String
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
Expand All @@ -19,6 +20,7 @@ class LibInfo {
delayAlloc: Int,
out: NDArrayHandleRef): Int
@native def mxNDArrayWaitAll(): Int
@native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
@native def mxFuncDescribe(handle: FunctionHandle,
nUsedVars: MXUintRef,
Expand Down Expand Up @@ -49,6 +51,16 @@ class LibInfo {
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
@native def mxNDArrayLoad(fname: String,
outSize: MXUintRef,
handles: ArrayBuffer[NDArrayHandle],
outNameSize: MXUintRef,
names: ArrayBuffer[String]): Int
@native def mxNDArraySave(fname: String,
handles: Array[NDArrayHandle],
keys: Array[String]): Int
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, devId: RefInt): Int
// KVStore
@native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int
@native def mxKVStoreInit(handle: KVStoreHandle,
len: MXUint,
Expand All @@ -64,9 +76,7 @@ class LibInfo {
keys: Array[Int],
outs: Array[NDArrayHandle],
priority: Int): Int
@native def mxKVStoreSetUpdater(handle: KVStoreHandle,
updaterFunc: MXKVStoreUpdater,
updaterHandle: AnyRef): Int
@native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater): Int
@native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
@native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
@native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle,
Expand Down
Loading

0 comments on commit 0013f85

Please sign in to comment.