diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 13f85a731dc4..60efd2ba62bd 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -830,6 +830,8 @@ object Symbol {
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
+ val api = SymbolAPI
+
def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
new file mode 100644
index 000000000000..49de9ae73218
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet
+
+
+@AddSymbolAPIs(false)
+/**
+ * typesafe Symbol API: Symbol.api._
+ * Main code will be generated during compile time through Macros
+ */
+object SymbolAPI {
+}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
index d1ec88d67c6b..e9171bd47c28 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
@@ -30,40 +30,40 @@ object TrainMnist {
// multi-layer perceptron
def getMlp: Symbol = {
val data = Symbol.Variable("data")
- val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128))
- val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
- val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64))
- val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
- val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10))
- val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3))
+
+ val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
+ val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
+ val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2")
+ val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
+ val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3")
+ val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3))
mlp
}
// LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
// Haffner. "Gradient-based learning applied to document recognition."
// Proceedings of the IEEE (1998)
+
def getLenet: Symbol = {
val data = Symbol.Variable("data")
// first conv
- val conv1 = Symbol.Convolution()()(
- Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20))
- val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
- val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max",
- "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
+ val tanh1 = Symbol.api.tanh(data = Some(conv1))
+ val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// second conv
- val conv2 = Symbol.Convolution()()(
- Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50))
- val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
- val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
- "kernel" -> "(2, 2)", "stride" -> "(2, 2)"))
+ val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50)
+ val tanh2 = Symbol.api.tanh(data = Some(conv2))
+ val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
// first fullc
- val flatten = Symbol.Flatten()()(Map("data" -> pool2))
- val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500))
- val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh"))
+ val flatten = Symbol.api.Flatten(data = Some(pool2))
+ val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500)
+ val tanh3 = Symbol.api.tanh(data = Some(fc1))
// second fullc
- val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10))
+ val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10)
// loss
- val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))
+ val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2))
lenet
}
diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala
index 7af2e052255c..7402dbd3bc1d 100644
--- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala
+++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala
@@ -37,7 +37,12 @@ object Base {
@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
- val baseDir = System.getProperty("user.dir") + "/init-native"
+ var baseDir = System.getProperty("user.dir") + "/init-native"
+ // TODO(lanKing520) Update this to use relative path to the MXNet director.
+ // TODO(lanking520) baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
+ if (System.getenv().containsKey("MXNET_BASEDIR")) {
+ baseDir = sys.env("MXNET_BASEDIR")
+ }
val os = System.getProperty("os.name")
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml
index 0aa3030e7ce3..59cc181bd360 100644
--- a/scala-package/macros/pom.xml
+++ b/scala-package/macros/pom.xml
@@ -52,4 +52,42 @@
${libtype}
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ ${project.parent.basedir}/init-native
+
+
+ -Djava.library.path=${project.parent.basedir}/native/${platform}/target \
+ -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties
+
+
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+
+
+
+
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index b6ddaafc7ad7..234a8604cb91 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -21,7 +21,6 @@ import scala.annotation.StaticAnnotation
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
-
import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.OperatorBuildUtils
@@ -29,18 +28,29 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota
private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs
}
+private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs
+}
+
private[mxnet] object SymbolImplMacros {
- case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
+ case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
+ case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
// scalastyle:off havetype
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- impl(c)(false, annottees: _*)
+ impl(c)(annottees: _*)
}
- // scalastyle:off havetype
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+ newAPIImpl(c)(annottees: _*)
+ }
+ // scalastyle:on havetype
- private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule()
+ private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
- private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = {
+ /**
+ * Implementation for fixed input API structure
+ */
+ private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
val isContrib: Boolean = c.prefix.tree match {
@@ -48,74 +58,106 @@ private[mxnet] object SymbolImplMacros {
}
val newSymbolFunctions = {
- if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_"))
- else symbolFunctions.filter(!_._1.startsWith("_contrib_"))
+ if (isContrib) symbolFunctions.filter(
+ func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+ else symbolFunctions.filter(!_.name.startsWith("_"))
}
- val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
- List(Ident(TypeName("String")), Ident(TypeName("Any"))))
- val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")),
- List(Ident(TypeName("String")), Ident(TypeName("String"))))
- val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree(
- Select(
- Select(Ident(termNames.ROOTPKG), TermName("scala")),
- TypeName("")
- ),
- List(Select(Select(Select(
- Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("Symbol")))
- )
-
- val functionDefs = newSymbolFunctions map { case (funcName, funcProp) =>
- val functionScope = {
- if (isContrib) Modifiers()
- else {
- if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers()
- }
- }
- val newName = {
- if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
- else funcName
+
+ val functionDefs = newSymbolFunctions map { symbolfunction =>
+ val funcName = symbolfunction.name
+ val tName = TermName(funcName)
+ q"""
+ def $tName(name : String = null, attr : Map[String, String] = null)
+ (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
+ : org.apache.mxnet.Symbol = {
+ createSymbolGeneral($funcName,name,attr,args,kwargs)
+ }
+ """.asInstanceOf[DefDef]
}
- // It will generate definition something like,
- // def Concat(name: String = null, attr: Map[String, String] = null)
- // (args: Symbol*)(kwargs: Map[String, Any] = null)
- DefDef(functionScope, TermName(newName), List(),
- List(
- List(
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("name"),
- Ident(TypeName("String")), Literal(Constant(null))),
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("attr"),
- AST_TYPE_MAP_STRING_STRING, Literal(Constant(null)))
- ),
- List(
- ValDef(Modifiers(), TermName("args"), AST_TYPE_SYMBOL_VARARG, EmptyTree)
- ),
- List(
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"),
- AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
- )
- ), TypeTree(),
- Apply(
- Ident(TermName("createSymbolGeneral")),
- List(
- Literal(Constant(funcName)),
- Ident(TermName("name")),
- Ident(TermName("attr")),
- Ident(TermName("args")),
- Ident(TermName("kwargs"))
- )
- )
- )
+ structGeneration(c)(functionDefs, annottees : _*)
+ }
+
+ /**
+ * Implementation for Dynamic typed API Symbol.api.
+ */
+ private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+ import c.universe._
+
+ val isContrib: Boolean = c.prefix.tree match {
+ case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
+ }
+
+ // TODO: Put Symbol.api.foo --> Stable APIs
+ // Symbol.contrib.bar--> Contrib APIs
+ val newSymbolFunctions = {
+ if (isContrib) symbolFunctions.filter(
+ func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+ else symbolFunctions.filter(!_.name.startsWith("_"))
+ }
+
+ val functionDefs = newSymbolFunctions map { symbolfunction =>
+
+ // Construct argument field
+ var argDef = ListBuffer[String]()
+ // Construct Implementation field
+ var impl = ListBuffer[String]()
+ impl += "val map = scala.collection.mutable.Map[String, Any]()"
+ symbolfunction.listOfArgs.foreach({ symbolarg =>
+ // var is a special word used to define variable in Scala,
+ // need to changed to something else in order to make it work
+ val currArgName = symbolarg.argName match {
+ case "var" => "vari"
+ case "type" => "typeOf"
+ case default => symbolarg.argName
+ }
+ if (symbolarg.isOptional) {
+ argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
+ }
+ else {
+ argDef += s"${currArgName} : ${symbolarg.argType}"
+ }
+ var base = "map(\"" + symbolarg.argName + "\") = " + currArgName
+ if (symbolarg.isOptional) {
+ base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
+ }
+ impl += base
+ })
+ argDef += "name : String = null"
+ argDef += "attr : Map[String, String] = null"
+ // scalastyle:off
+ // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated
+ impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)"
+ // scalastyle:on
+ // Combine and build the function string
+ val returnType = "org.apache.mxnet.Symbol"
+ var finalStr = s"def ${symbolfunction.name}"
+ finalStr += s" (${argDef.mkString(",")}) : $returnType"
+ finalStr += s" = {${impl.mkString("\n")}}"
+ c.parse(finalStr).asInstanceOf[DefDef]
}
+ structGeneration(c)(functionDefs, annottees : _*)
+ }
+ /**
+ * Generate class structure for all function APIs
+ * @param c
+ * @param funcDef DefDef type of function definitions
+ * @param annottees
+ * @return
+ */
+ private def structGeneration(c: blackbox.Context)
+ (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
+ : c.Expr[Any] = {
+ import c.universe._
val inputs = annottees.map(_.tree).toList
// pattern match on the inputs
val modDefs = inputs map {
case ClassDef(mods, name, something, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ functionDefs)
+ Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
@@ -123,7 +165,7 @@ private[mxnet] object SymbolImplMacros {
case ModuleDef(mods, name, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ functionDefs)
+ Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
@@ -136,20 +178,80 @@ private[mxnet] object SymbolImplMacros {
result
}
+ // Convert C++ Types to Scala Types
+ def typeConversion(in : String, argType : String = "") : String = {
+ in match {
+ case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+ case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol"
+ case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
+ => "Array[org.apache.mxnet.Symbol]"
+ case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
+ case "int" | "intorNone" | "int(non-negative)" => "Int"
+ case "long" | "long(non-negative)" => "Long"
+ case "double" | "doubleorNone" => "Double"
+ case "string" => "String"
+ case "boolean" => "Boolean"
+ case "tupleof" | "tupleof" | "ptr" | "" => "Any"
+ case default => throw new IllegalArgumentException(
+ s"Invalid type for args: $default, $argType")
+ }
+ }
+
+
+ /**
+ * By default, the argType come from the C++ API is a description more than a single word
+ * For Example:
+ * , ,
+ * The three field shown above do not usually come at the same time
+ * This function used the above format to determine if the argument is
+ * optional, what is it Scala type and possibly pass in a default value
+ * @param argType Raw arguement Type description
+ * @return (Scala_Type, isOptional)
+ */
+ def argumentCleaner(argType : String) : (String, Boolean) = {
+ val spaceRemoved = argType.replaceAll("\\s+", "")
+ var commaRemoved : Array[String] = new Array[String](0)
+ // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
+ if (spaceRemoved.charAt(0)== '{') {
+ val endIdx = spaceRemoved.indexOf('}')
+ commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
+ commaRemoved(0) = "string"
+ } else {
+ commaRemoved = spaceRemoved.split(",")
+ }
+ // Optional Field
+ if (commaRemoved.length >= 3) {
+ // arg: Type, optional, default = Null
+ require(commaRemoved(1).equals("optional"))
+ require(commaRemoved(2).startsWith("default="))
+ (typeConversion(commaRemoved(0), argType), true)
+ } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
+ val tempType = typeConversion(commaRemoved(0), argType)
+ val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
+ (tempType, tempOptional)
+ } else {
+ throw new IllegalArgumentException(
+ s"Unrecognized arg field: $argType, ${commaRemoved.length}")
+ }
+
+ }
+
+
// List and add all the atomic symbol functions to current module.
- private def initSymbolModule(): Map[String, SymbolFunction] = {
+ private def initSymbolModule(): List[SymbolFunction] = {
val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
+ // TODO: Add '_linalg_', '_sparse_', '_image_' support
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName)
- }).toMap
+ }).toList
}
// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String)
- : (String, SymbolFunction) = {
+ : SymbolFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
@@ -174,6 +276,10 @@ private[mxnet] object SymbolImplMacros {
println("Symbol function definition:\n" + docStr)
}
// scalastyle:on println
- (aliasName, new SymbolFunction(handle, keyVarNumArgs.value))
+ val argList = argNames zip argTypes map { case (argName, argType) =>
+ val typeAndOption = argumentCleaner(argType)
+ new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
+ }
+ new SymbolFunction(aliasName, argList.toList)
}
}
diff --git a/scala-package/macros/src/test/resources/log4j.properties b/scala-package/macros/src/test/resources/log4j.properties
new file mode 100644
index 000000000000..d82fd7ea4f3d
--- /dev/null
+++ b/scala-package/macros/src/test/resources/log4j.properties
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# for development debugging
+log4j.rootLogger = debug, stdout
+
+log4j.appender.stdout = org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target = System.out
+log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
new file mode 100644
index 000000000000..bc8be7df5fb1
--- /dev/null
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+class MacrosSuite extends FunSuite with BeforeAndAfterAll {
+
+ private val logger = LoggerFactory.getLogger(classOf[MacrosSuite])
+
+
+ test("MacrosSuite-testArgumentCleaner") {
+ val input = List(
+ "Symbol, optional, default = Null",
+ "int, required",
+ "Shape(tuple), optional, default = []",
+ "{'csr', 'default', 'row_sparse'}, optional, default = 'csr'",
+ ", required"
+ )
+ val output = List(
+ ("org.apache.mxnet.Symbol", true),
+ ("Int", false),
+ ("org.apache.mxnet.Shape", true),
+ ("String", true),
+ ("Any", false)
+ )
+
+ for (idx <- input.indices) {
+ val result = SymbolImplMacros.argumentCleaner(input(idx))
+ assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
+ }
+ }
+
+}