Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add two more tests and README
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jul 9, 2018
1 parent ef4d352 commit 776ebe8
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,104 +155,117 @@ object NeuralStyle {
Math.sqrt(array.map(x => x * x).sum.toDouble).toFloat
}

def main(args: Array[String]): Unit = {
val alle = new NeuralStyle
val parser: CmdLineParser = new CmdLineParser(alle)
try {
parser.parseArgument(args.toList.asJava)
assert(alle.contentImage != null && alle.styleImage != null
&& alle.modelPath != null && alle.outputDir != null)
//scalastyle:off
def runTraining(model : String, contentImage : String, styleImage: String, dev : Context,
modelPath : String, outputDir : String, styleWeight : Float,
contentWeight : Float, tvWeight : Float, gaussianRadius : Int,
lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
saveEpochs : Int, stopEps: Float) : Unit = {

val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
val size = (contentNp.shape(2), contentNp.shape(3))

val (style, content) = ModelVgg19.getSymbol
val (gram, gScale) = styleGramSymbol(size, style)
var modelExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size, dev)

modelExecutor.data.set(styleNp)
modelExecutor.executor.forward()

val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
modelExecutor.data.set(contentNp)
modelExecutor.executor.forward()
val contentArray = modelExecutor.content.copyTo(Context.cpu())

// delete the executor
modelExecutor = null

val (styleLoss, contentLoss) = getLoss(gram, content)
modelExecutor = ModelVgg19.getExecutor(
styleLoss, contentLoss, modelPath, size, dev)

val gradArray = {
var tmpGA = Array[NDArray]()
for (i <- 0 until styleArray.length) {
modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (styleWeight / gScale(i))
}
tmpGA :+ NDArray.ones(Shape(1), dev) * contentWeight
}

val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
val contentNp = preprocessContentImage(alle.contentImage, alle.maxLongEdge, dev)
val styleNp = preprocessStyleImage(alle.styleImage, contentNp.shape, dev)
val size = (contentNp.shape(2), contentNp.shape(3))
modelExecutor.argDict("target_content").set(contentArray)

val (style, content) = ModelVgg19.getSymbol
val (gram, gScale) = styleGramSymbol(size, style)
var modelExecutor = ModelVgg19.getExecutor(gram, content, alle.modelPath, size, dev)
// train
val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
val lrFS = new FactorScheduler(step = 10, factor = 0.9f)

modelExecutor.data.set(styleNp)
modelExecutor.executor.forward()
saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)

val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
modelExecutor.data.set(contentNp)
modelExecutor.executor.forward()
val contentArray = modelExecutor.content.copyTo(Context.cpu())
val optimizer = new Adam(
learningRate = lr,
wd = 0.005f,
lrScheduler = lrFS)
val optimState = optimizer.createState(0, img)

// delete the executor
modelExecutor = null
logger.info(s"start training arguments")

val (styleLoss, contentLoss) = getLoss(gram, content)
modelExecutor = ModelVgg19.getExecutor(
styleLoss, contentLoss, alle.modelPath, size, dev)
var oldImg = img.copyTo(dev)
val clipNorm = img.shape.toVector.reduce(_ * _)
val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
var eps = 0f
var trainingDone = false
var e = 0
while (e < maxNumEpochs && !trainingDone) {
modelExecutor.data.set(img)
modelExecutor.executor.forward()
modelExecutor.executor.backward(gradArray)

val gradArray = {
var tmpGA = Array[NDArray]()
for (i <- 0 until styleArray.length) {
modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (alle.styleWeight / gScale(i))
}
tmpGA :+ NDArray.ones(Shape(1), dev) * alle.contentWeight
val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
if (gNorm > clipNorm) {
modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
}

modelExecutor.argDict("target_content").set(contentArray)

// train
val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
val lr = new FactorScheduler(step = 10, factor = 0.9f)

saveImage(contentNp, s"${alle.outputDir}/input.jpg", alle.guassianRadius)
saveImage(styleNp, s"${alle.outputDir}/style.jpg", alle.guassianRadius)

val optimizer = new Adam(
learningRate = alle.lr,
wd = 0.005f,
lrScheduler = lr)
val optimState = optimizer.createState(0, img)

logger.info(s"start training arguments $alle")

var oldImg = img.copyTo(dev)
val clipNorm = img.shape.toVector.reduce(_ * _)
val tvGradExecutor = getTvGradExecutor(img, dev, alle.tvWeight)
var eps = 0f
var trainingDone = false
var e = 0
while (e < alle.maxNumEpochs && !trainingDone) {
modelExecutor.data.set(img)
modelExecutor.executor.forward()
modelExecutor.executor.backward(gradArray)

val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
if (gNorm > clipNorm) {
modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
}
tvGradExecutor match {
case Some(executor) => {
executor.forward()
optimizer.update(0, img,
modelExecutor.dataGrad + executor.outputs(0),
optimState)
}
case None =>
optimizer.update(0, img, modelExecutor.dataGrad, optimState)
tvGradExecutor match {
case Some(executor) => {
executor.forward()
optimizer.update(0, img,
modelExecutor.dataGrad + executor.outputs(0),
optimState)
}
eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
oldImg.set(img)
logger.info(s"epoch $e, relative change $eps")
case None =>
optimizer.update(0, img, modelExecutor.dataGrad, optimState)
}
eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
oldImg.set(img)
logger.info(s"epoch $e, relative change $eps")

if (eps < alle.stopEps) {
logger.info("eps < args.stop_eps, training finished")
trainingDone = true
}
if ((e + 1) % alle.saveEpochs == 0) {
saveImage(img, s"${alle.outputDir}/tmp_${e + 1}.jpg", alle.guassianRadius)
}
e = e + 1
if (eps < stopEps) {
logger.info("eps < args.stop_eps, training finished")
trainingDone = true
}
if ((e + 1) % saveEpochs == 0) {
saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
}
saveImage(img, s"${alle.outputDir}/out.jpg", alle.guassianRadius)
logger.info("Finish fit ...")
e = e + 1
}
saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
logger.info("Finish fit ...")
}

def main(args: Array[String]): Unit = {
val alle = new NeuralStyle
val parser: CmdLineParser = new CmdLineParser(alle)
try {
parser.parseArgument(args.toList.asJava)
assert(alle.contentImage != null && alle.styleImage != null
&& alle.modelPath != null && alle.outputDir != null)

val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
runTraining(alle.model, alle.contentImage, alle.styleImage, dev, alle.modelPath,
alle.outputDir, alle.styleWeight, alle.contentWeight, alle.tvWeight,
alle.gaussianRadius, alle.lr, alle.maxNumEpochs, alle.maxLongEdge,
alle.saveEpochs, alle.stopEps)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
Expand Down Expand Up @@ -292,6 +305,6 @@ class NeuralStyle {
private val outputDir: String = null
@Option(name = "--save-epochs", usage = "save the output every n epochs")
private val saveEpochs: Int = 50
@Option(name = "--guassian-radius", usage = "the gaussian blur filter radius")
private val guassianRadius: Int = 1
@Option(name = "--gaussian-radius", usage = "the gaussian blur filter radius")
private val gaussianRadius: Int = 1
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Neural Style Example for Scala

## Introduction
This example contains three important components:
- Boost Inference
- Boost Training
- Neural Style conversion

You can use the prebuilt VGG model to do the conversion.
By adding a style image, you can create several interesting images.

Original Image | Style Image
:-------------------------:|:-------------------------:
![](https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/IMG_4343.jpg) | ![](https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/starry_night.jpg)

Boost Inference Image (pretrained) | Epoch 150 Image
:-------------------------:|:-------------------------:
![](https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/out_3.jpg) | ![](https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/tmp_150.jpg)

## Setup
Please download the input image and style image following the links below:

Input image
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/IMG_4343.jpg
```
Style image
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/starry_night.jpg
```

VGG model --Boost inference
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/model.zip
```

VGG model --Boost Training
```bash
https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/vgg19.params
```

Please unzip the model before you use it.

## Boost Inference Example

Please provide the corresponding arguments before you execute the program
```bash
--input-image
<path>/IMG_4343.jpg
--model-path
<path>/model
--output-path
<outputPath>
```

## Boost Training Example
Please download your own training data for boost training.
You can use the 26k images sampled from the [MIT Places dataset](http://places.csail.mit.edu/).
```bash
--style-image
<path>/starry_night.jpg
--data-path
<path>/images
--vgg-model-path
<path>/vgg19.params
--save-model-path
<path>
```

## NeuralStyle Example
Please provide the corresponding arguments before you execute the program
```bash
--model-path
<path>/vgg19.params
--content-image
<path>/IMG_4343.jpg
--style-image
<path>/starry_night.jpg
--gpu
<num_of_gpus>
--output-dir
<path>
```
Loading

0 comments on commit 776ebe8

Please sign in to comment.