Commit

Spark extension POC
xyang16 committed Nov 1, 2022
1 parent d92d33e commit 3509c8b
Showing 6 changed files with 305 additions and 95 deletions.
107 changes: 84 additions & 23 deletions apache-spark/notebook/Image_Classification_Spark.ipynb
@@ -49,6 +49,7 @@
"import java.util\n",
"import ai.djl.Model\n",
"import ai.djl.modality.Classifications\n",
"import ai.djl.repository.zoo.ZooModel\n",
"import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
"import ai.djl.ndarray.types.{DataType, Shape}\n",
"import ai.djl.ndarray.{NDList, NDManager}\n",
@@ -83,7 +84,7 @@
"outputs": [],
"source": [
" // Translator: a class used to do preprocessing and post processing\n",
" class MyTranslator extends Translator[Row, Classifications] {\n",
" class SparkImageClassificationTranslator extends Translator[Row, Classifications] {\n",
"\n",
" private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
" private val pipeline: Pipeline = new Pipeline()\n",
@@ -95,7 +96,6 @@
" }\n",
"\n",
" override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
"\n",
" val height = ImageSchema.getHeight(row)\n",
" val width = ImageSchema.getWidth(row)\n",
" val channel = ImageSchema.getNChannels(row)\n",
@@ -129,30 +129,93 @@
"If you are using MXNet as the backend engine, plase uncomment the mxnet model url."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Spark application\n",
"\n",
"We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
]
},
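For reference, creating the session might look like the following. This is a minimal sketch, assuming the almond-spark dependency is already loaded in the kernel; the master and appName values are illustrative.

import org.apache.spark.sql.NotebookSparkSession

// Build a Spark session from inside the Almond notebook; the plugin takes
// care of shipping the kernel's jars to the worker nodes.
val spark = NotebookSparkSession.builder()
  .master("local[*]")
  .appName("Image Classification")
  .getOrCreate()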
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
"// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
"lazy val criteria = Criteria.builder\n",
" .setTypes(classOf[Row], classOf[Classifications])\n",
" .optModelUrls(modelUrl)\n",
" .optTranslator(new MyTranslator())\n",
" .optProgress(new ProgressBar)\n",
" .build()\n",
"lazy val model = ModelZoo.loadModel(criteria)"
"import ai.djl.inference.Predictor\n",
"import ai.djl.modality.Classifications\n",
"import ai.djl.repository.zoo.{Criteria, ModelZoo}\n",
"import ai.djl.training.util.ProgressBar\n",
"import org.apache.spark.sql.Row\n",
"\n",
"@SerialVersionUID(123456789L)\n",
"class SparkModel(val url : String) extends Serializable {\n",
"\n",
" private lazy val criteria = Criteria.builder\n",
" .setTypes(classOf[Row], classOf[Classifications])\n",
" .optModelUrls(url)\n",
" .optTranslator(new SparkImageClassificationTranslator())\n",
" .optProgress(new ProgressBar)\n",
" .build()\n",
" private lazy val model = ModelZoo.loadModel(criteria)\n",
"\n",
" // def load(url: String) = new SparkModel(url)\n",
"\n",
" def newPredictor(): Predictor[Row, Classifications] = {\n",
" model.newPredictor()\n",
" }\n",
"}"
]
},
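A note on the design: SparkModel is Serializable so that it can be shipped to executors inside task closures, and both criteria and model are lazy, so the actual model download and load happen on the executor after deserialization rather than on the driver. A hypothetical driver-side usage (URL as in this notebook):

// Construct on the driver; nothing is downloaded yet.
val sparkModel = new SparkModel(
  "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
// The first call to newPredictor() triggers the lazy model load.
val predictor = sparkModel.newPredictor()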
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Start Spark application\n",
"import org.apache.spark.ml.Transformer\n",
"import org.apache.spark.ml.param.{Param, ParamMap}\n",
"import org.apache.spark.ml.util.Identifiable\n",
"import org.apache.spark.sql.types.StructType\n",
"import org.apache.spark.sql.{DataFrame, Dataset, Encoders}\n",
"\n",
"We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
"class SparkPredictor(override val uid: String) extends Transformer {\n",
"\n",
" def this() = this(Identifiable.randomUID(\"SparkPredictor\"))\n",
"\n",
" final val inputCol = new Param[String](this, \"inputCol\", \"The input column\")\n",
" final val outputCol = new Param[String](this, \"outputCol\", \"The output column\")\n",
" final val modelUrl = new Param[String](this, \"modelUrl\", \"The model URL\")\n",
"\n",
" def setInputCol(value: String): this.type = set(inputCol, value)\n",
"\n",
" def setOutputCol(value: String): this.type = set(outputCol, value)\n",
"\n",
" def setModelUrl(value: String): this.type = set(modelUrl, value)\n",
"\n",
" def predict(dataset: Dataset[_]): DataFrame = {\n",
" transform(dataset)\n",
" }\n",
"\n",
" override def transform(dataset: Dataset[_]): DataFrame = {\n",
" val outputSchema = transformSchema(dataset.schema)\n",
" val model = new SparkModel($(modelUrl))\n",
" val outputDf = dataset.select($(inputCol)).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
" })(Encoders.STRING)\n",
" outputDf.select($(outputCol))\n",
" }\n",
"\n",
" override def transformSchema(schema: StructType) = schema\n",
"\n",
" override def copy(paramMap: ParamMap) = this\n",
"}"
]
},
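Because SparkPredictor extends Transformer, it can in principle also be composed with other stages in a standard Spark ML Pipeline instead of being called through predict() directly. A sketch, assuming the predictor and df created later in this notebook:

import org.apache.spark.ml.Pipeline

// A Transformer is a PipelineStage, so it can be chained with other stages.
val pipeline = new Pipeline().setStages(Array(predictor))
val outputDf = pipeline.fit(df).transform(df)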
{
@@ -199,14 +262,12 @@
"metadata": {},
"outputs": [],
"source": [
"val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
"})(Encoders.STRING)\n",
"println(result.collect().mkString(\"\\n\"))"
"val predictor = new SparkPredictor()\n",
" .setInputCol(\"image.*\")\n",
" .setOutputCol(\"value\")\n",
" .setModelUrl(\"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\")\n",
"val outputDf = predictor.predict(df)\n",
"println(outputDf.collect().mkString(\"\\n\"))"
]
}
],
@@ -222,7 +283,7 @@
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
"version": "2.12.11"
}
},
"nbformat": 4,
ImageClassificationExample.scala
@@ -12,91 +12,28 @@
*/
package com.examples

import java.util

import ai.djl.Model
import ai.djl.modality.Classifications
import ai.djl.modality.cv.transform.{ Resize, ToTensor}
import ai.djl.ndarray.types.{DataType, Shape}
import ai.djl.ndarray.{NDList, NDManager}
import ai.djl.repository.zoo.{Criteria, ZooModel}
import ai.djl.training.util.ProgressBar
import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}
import ai.djl.util.Utils
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Encoders, Row, SparkSession}

import org.apache.spark.sql.SparkSession

/**
* Example to run image classification on Spark.
*/
object ImageClassificationExample {

private lazy val model = loadModel()

def loadModel(): ZooModel[Row, Classifications] = {
val modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18"
val criteria = Criteria.builder
.setTypes(classOf[Row], classOf[Classifications])
.optModelUrls(modelUrl)
.optTranslator(new MyTranslator())
.optProgress(new ProgressBar)
.build()
criteria.loadModel()
}

// Translator: a class used to do preprocessing and post processing
class MyTranslator extends Translator[Row, Classifications] {

private var classes: java.util.List[String] = new util.ArrayList[String]()
private val pipeline: Pipeline = new Pipeline()
.add(new Resize(224, 224))
.add(new ToTensor())

override def prepare(manager: NDManager, model: Model): Unit = {
classes = Utils.readLines(model.getArtifact("synset.txt").openStream())
}

override def processInput(ctx: TranslatorContext, row: Row): NDList = {

val height = ImageSchema.getHeight(row)
val width = ImageSchema.getWidth(row)
val channel = ImageSchema.getNChannels(row)
var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)
// BGR to RGB
image = image.flip(2)
pipeline.transform(new NDList(image))
}

// Deal with the output. NDList contains the output result, usually one or more NDArray(s).
override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
var probabilitiesNd = list.singletonOrThrow
probabilitiesNd = probabilitiesNd.softmax(0)
new Classifications(classes, probabilitiesNd)
}

override def getBatchifier: Batchifier = Batchifier.STACK
}

def main(args: Array[String]) {

// Spark configuration
val spark = SparkSession.builder()
.master("local[*]")
.appName("Image Classification")
.getOrCreate()

val df = spark.read.format("image").option("dropInvalid", true).load("../../image-classification/images")
df.select("image.origin", "image.width", "image.height").show(truncate=false)

val result = df.select(col("image.*")).mapPartitions(partition => {
val predictor = model.newPredictor()
partition.map(row => {
// image data stored as HWC format
predictor.predict(row).toString
})
})(Encoders.STRING)
println(result.collect().mkString("\n"))
df.select("image.origin", "image.width", "image.height").show(truncate = false)

val predictor = new SparkPredictor()
.setInputCol("image.*")
.setOutputCol("value")
.setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
val outputDf = predictor.predict(df)
println(outputDf.collect().mkString("\n"))
}
}
SparkImageClassificationTranslator.scala
@@ -0,0 +1,57 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.examples

import ai.djl.Model
import ai.djl.modality.Classifications
import ai.djl.modality.cv.transform.{Resize, ToTensor}
import ai.djl.ndarray.{NDList, NDManager}
import ai.djl.ndarray.types.{DataType, Shape}
import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}
import ai.djl.util.Utils
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.Row

import java.util

// Translator: a class used to do preprocessing and post processing
class SparkImageClassificationTranslator extends Translator[Row, Classifications] {

private var classes: java.util.List[String] = new util.ArrayList[String]()
private val pipeline: Pipeline = new Pipeline()
.add(new Resize(224, 224))
.add(new ToTensor())

override def prepare(manager: NDManager, model: Model): Unit = {
classes = Utils.readLines(model.getArtifact("synset.txt").openStream())
}

override def processInput(ctx: TranslatorContext, row: Row): NDList = {
val height = ImageSchema.getHeight(row)
val width = ImageSchema.getWidth(row)
val channel = ImageSchema.getNChannels(row)
var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)
// BGR to RGB
image = image.flip(2)
pipeline.transform(new NDList(image))
}

// Deal with the output. NDList contains the output result, usually one or more NDArray(s).
override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
var probabilitiesNd = list.singletonOrThrow
probabilitiesNd = probabilitiesNd.softmax(0)
new Classifications(classes, probabilitiesNd)
}

override def getBatchifier: Batchifier = Batchifier.STACK
}
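Spark's ImageSchema stores pixel data in BGR channel order, while the ImageNet-trained models expect RGB, hence the image.flip(2) on the channel axis above. A tiny standalone sketch of what the flip does, with illustrative values only:

import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape

val manager = NDManager.newBaseManager()
// A 1x1 "image" with three channels stored as (B, G, R).
val bgr = manager.create(Array(1f, 2f, 3f), new Shape(1, 1, 3))
// Reversing axis 2 turns (B, G, R) into (R, G, B): values become 3, 2, 1.
val rgb = bgr.flip(2)
println(rgb)
manager.close()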
SparkJavaModel.java
@@ -0,0 +1,63 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.examples;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import org.apache.spark.sql.Row;

import java.io.IOException;
import java.io.Serializable;

public class SparkJavaModel implements Serializable {

private static final long serialVersionUID = 123456789L;

private String url;
private ZooModel<Row, Classifications> model;

public SparkJavaModel(String url) {
this.url = url;
}

public String getUrl() {
return url;
}

public void setUrl(String url) {
this.url = url;
}

public static SparkJavaModel load(String url) {
return new SparkJavaModel(url);
}

public Predictor<Row, Classifications> newPredictor()
throws ModelNotFoundException, MalformedModelException, IOException {
if (model == null) {
Criteria<Row, Classifications> criteria = Criteria.builder()
.setTypes(Row.class, Classifications.class)
.optModelUrls(url)
.optTranslator(new SparkImageClassificationTranslator())
.optProgress(new ProgressBar())
.build();
model = criteria.loadModel();
}
return model.newPredictor();
}
}
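A hypothetical Scala counterpart of the removed driver code, using SparkJavaModel inside mapPartitions (assuming a df loaded with the image data source as in the example above):

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.functions.col

val model = SparkJavaModel.load(
  "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
val result = df.select(col("image.*")).mapPartitions(partition => {
  // One predictor per partition; the model loads lazily on the executor.
  val predictor = model.newPredictor()
  partition.map(row => predictor.predict(row).toString)
})(Encoders.STRING)
println(result.collect().mkString("\n"))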
