diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh index ea53c1ade66f..dbf6ae854192 100755 --- a/scala-package/examples/scripts/run_train_mnist.sh +++ b/scala-package/examples/scripts/run_train_mnist.sh @@ -19,15 +19,31 @@ set -e +hw_type=cpu +if [[ $1 = gpu ]] +then + hw_type=gpu +fi + +platform=linux-x86_64 + +if [[ $OSTYPE = [darwin]* ]] +then + platform=osx-x86_64 + hw_type=cpu +fi + MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) echo $MXNET_ROOT -CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* # model dir DATA_PATH=$2 -java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \ - org.apache.mxnetexamples.imclassification.TrainMnist \ - --data-dir /home/ubuntu/mxnet_scala/scala-package/examples/mnist/ \ +java -XX:+PrintGC -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \ + org.apache.mxnetexamples.imclassification.TrainModel \ + --data-dir $MXNET_ROOT/scala-package/examples/mnist/ \ + --network mlp \ + --num-layers 50 \ --num-epochs 10000000 \ --batch-size 1024 \ No newline at end of file