diff --git a/apache-spark/notebook/Image_Classification_Spark.ipynb b/apache-spark/notebook/Image_Classification_Spark.ipynb
index 37039e00..688d6ac6 100644
--- a/apache-spark/notebook/Image_Classification_Spark.ipynb
+++ b/apache-spark/notebook/Image_Classification_Spark.ipynb
@@ -49,6 +49,7 @@
     "import java.util\n",
     "import ai.djl.Model\n",
     "import ai.djl.modality.Classifications\n",
+    "import ai.djl.repository.zoo.ZooModel\n",
     "import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
     "import ai.djl.ndarray.types.{DataType, Shape}\n",
     "import ai.djl.ndarray.{NDList, NDManager}\n",
@@ -83,7 +84,7 @@
    "outputs": [],
    "source": [
     " // Translator: a class used to do preprocessing and post processing\n",
-    " class MyTranslator extends Translator[Row, Classifications] {\n",
+    " class SparkImageClassificationTranslator extends Translator[Row, Classifications] {\n",
     "\n",
     " private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
     " private val pipeline: Pipeline = new Pipeline()\n",
@@ -95,7 +96,6 @@
     " }\n",
     "\n",
     " override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
-    "\n",
     " val height = ImageSchema.getHeight(row)\n",
     " val width = ImageSchema.getWidth(row)\n",
     " val channel = ImageSchema.getNChannels(row)\n",
@@ -129,30 +129,93 @@
     "If you are using MXNet as the backend engine, plase uncomment the mxnet model url."
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Start Spark application\n",
+    "\n",
+    "We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker nodes."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
-    "// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
-    "lazy val criteria = Criteria.builder\n",
-    " .setTypes(classOf[Row], classOf[Classifications])\n",
-    " .optModelUrls(modelUrl)\n",
-    " .optTranslator(new MyTranslator())\n",
-    " .optProgress(new ProgressBar)\n",
-    " .build()\n",
-    "lazy val model = ModelZoo.loadModel(criteria)"
+    "import ai.djl.inference.Predictor\n",
+    "import ai.djl.modality.Classifications\n",
+    "import ai.djl.repository.zoo.{Criteria, ModelZoo}\n",
+    "import ai.djl.training.util.ProgressBar\n",
+    "import org.apache.spark.sql.Row\n",
+    "\n",
+    "@SerialVersionUID(123456789L)\n",
+    "class SparkModel(val url : String) extends Serializable {\n",
+    "\n",
+    " private lazy val criteria = Criteria.builder\n",
+    " .setTypes(classOf[Row], classOf[Classifications])\n",
+    " .optModelUrls(url)\n",
+    " .optTranslator(new SparkImageClassificationTranslator())\n",
+    " .optProgress(new ProgressBar)\n",
+    " .build()\n",
+    " private lazy val model = ModelZoo.loadModel(criteria)\n",
+    "\n",
+    " // def load(url: String) = new SparkModel(url)\n",
+    "\n",
+    " def newPredictor(): Predictor[Row, Classifications] = {\n",
+    " model.newPredictor()\n",
+    " }\n",
+    "}"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "## Start Spark application\n",
+    "import org.apache.spark.ml.Transformer\n",
+    "import org.apache.spark.ml.param.{Param, ParamMap}\n",
+    "import org.apache.spark.ml.util.Identifiable\n",
+    "import org.apache.spark.sql.types.StructType\n",
+    "import org.apache.spark.sql.{DataFrame, Dataset, Encoders}\n",
     "\n",
-    "We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
+    "class SparkPredictor(override val uid: String) extends Transformer {\n",
+    "\n",
+    " def this() = this(Identifiable.randomUID(\"SparkPredictor\"))\n",
+    "\n",
+    " final val inputCol = new Param[String](this, \"inputCol\", \"The input column\")\n",
+    " final val outputCol = new Param[String](this, \"outputCol\", \"The output column\")\n",
+    " final val modelUrl = new Param[String](this, \"modelUrl\", \"The model URL\")\n",
+    "\n",
+    " def setInputCol(value: String): this.type = set(inputCol, value)\n",
+    "\n",
+    " def setOutputCol(value: String): this.type = set(outputCol, value)\n",
+    "\n",
+    " def setModelUrl(value: String): this.type = set(modelUrl, value)\n",
+    "\n",
+    " def predict(dataset: Dataset[_]): DataFrame = {\n",
+    " transform(dataset)\n",
+    " }\n",
+    "\n",
+    " override def transform(dataset: Dataset[_]): DataFrame = {\n",
+    " val outputSchema = transformSchema(dataset.schema)\n",
+    " val model = new SparkModel($(modelUrl))\n",
+    " val outputDf = dataset.select($(inputCol)).mapPartitions(partition => {\n",
+    " val predictor = model.newPredictor()\n",
+    " partition.map(row => {\n",
+    " // image data stored as HWC format\n",
+    " predictor.predict(row).toString\n",
+    " })\n",
+    " })(Encoders.STRING)\n",
+    " outputDf.select($(outputCol))\n",
+    " }\n",
+    "\n",
+    " override def transformSchema(schema: StructType) = schema\n",
+    "\n",
+    " override def copy(paramMap: ParamMap) = this\n",
+    "}"
    ]
   },
   {
@@ -199,14 +262,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
-    " val predictor = model.newPredictor()\n",
-    " partition.map(row => {\n",
-    " // image data stored as HWC format\n",
-    " predictor.predict(row).toString\n",
-    " })\n",
-    "})(Encoders.STRING)\n",
-    "println(result.collect().mkString(\"\\n\"))"
+    "val predictor = new SparkPredictor()\n",
+    " .setInputCol(\"image.*\")\n",
+    " .setOutputCol(\"value\")\n",
+    " .setModelUrl(\"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\")\n",
+    "val outputDf = predictor.predict(df)\n",
+    "println(outputDf.collect().mkString(\"\\n\"))"
    ]
   }
  ],
@@ -222,7 +283,7 @@
    "mimetype": "text/x-scala",
    "name": "scala",
    "nbconvert_exporter": "script",
-   "version": "2.12.12"
+   "version": "2.12.11"
   }
  },
  "nbformat": 4,
diff --git a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/ImageClassificationExample.scala b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/ImageClassificationExample.scala
index e53cf214..38c28f21 100644
--- a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/ImageClassificationExample.scala
+++ b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/ImageClassificationExample.scala
@@ -12,75 +12,14 @@
  */
 package com.examples
 
-import java.util
-
-import ai.djl.Model
-import ai.djl.modality.Classifications
-import ai.djl.modality.cv.transform.{ Resize, ToTensor}
-import ai.djl.ndarray.types.{DataType, Shape}
-import ai.djl.ndarray.{NDList, NDManager}
-import ai.djl.repository.zoo.{Criteria, ZooModel}
-import ai.djl.training.util.ProgressBar
-import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}
-import ai.djl.util.Utils
-import org.apache.spark.ml.image.ImageSchema
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.{Encoders, Row, SparkSession}
-
+import org.apache.spark.sql.SparkSession
 
 /**
  * Example to run image classification on Spark.
  */
 object ImageClassificationExample {
 
-  private lazy val model = loadModel()
-
-  def loadModel(): ZooModel[Row, Classifications] = {
-    val modelUrl = "https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18"
-    val criteria = Criteria.builder
-      .setTypes(classOf[Row], classOf[Classifications])
-      .optModelUrls(modelUrl)
-      .optTranslator(new MyTranslator())
-      .optProgress(new ProgressBar)
-      .build()
-    criteria.loadModel()
-  }
-
-  // Translator: a class used to do preprocessing and post processing
-  class MyTranslator extends Translator[Row, Classifications] {
-
-    private var classes: java.util.List[String] = new util.ArrayList[String]()
-    private val pipeline: Pipeline = new Pipeline()
-      .add(new Resize(224, 224))
-      .add(new ToTensor())
-
-    override def prepare(manager: NDManager, model: Model): Unit = {
-      classes = Utils.readLines(model.getArtifact("synset.txt").openStream())
-    }
-
-    override def processInput(ctx: TranslatorContext, row: Row): NDList = {
-
-      val height = ImageSchema.getHeight(row)
-      val width = ImageSchema.getWidth(row)
-      val channel = ImageSchema.getNChannels(row)
-      var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)
-      // BGR to RGB
-      image = image.flip(2)
-      pipeline.transform(new NDList(image))
-    }
-
-    // Deal with the output.,NDList contains output result, usually one or more NDArray(s).
-    override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
-      var probabilitiesNd = list.singletonOrThrow
-      probabilitiesNd = probabilitiesNd.softmax(0)
-      new Classifications(classes, probabilitiesNd)
-    }
-
-    override def getBatchifier: Batchifier = Batchifier.STACK
-  }
-
   def main(args: Array[String]) {
-
     // Spark configuration
     val spark = SparkSession.builder()
       .master("local[*]")
@@ -88,15 +27,13 @@ object ImageClassificationExample {
       .getOrCreate()
 
     val df = spark.read.format("image").option("dropInvalid", true).load("../../image-classification/images")
-    println(df.select("image.origin", "image.width", "image.height").show(truncate=false))
-
-    val result = df.select(col("image.*")).mapPartitions(partition => {
-      val predictor = model.newPredictor()
-      partition.map(row => {
-        // image data stored as HWC format
-        predictor.predict(row).toString
-      })
-    })(Encoders.STRING)
-    println(result.collect().mkString("\n"))
+    println(df.select("image.origin", "image.width", "image.height").show(truncate = false))
+
+    val predictor = new SparkPredictor()
+      .setInputCol("image.*")
+      .setOutputCol("value")
+      .setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
+    val outputDf = predictor.predict(df)
+    println(outputDf.collect().mkString("\n"))
   }
 }
diff --git a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkImageClassificationTranslator.scala b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkImageClassificationTranslator.scala
new file mode 100644
index 00000000..c9580ed7
--- /dev/null
+++ b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkImageClassificationTranslator.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package com.examples
+
+import ai.djl.Model
+import ai.djl.modality.Classifications
+import ai.djl.modality.cv.transform.{Resize, ToTensor}
+import ai.djl.ndarray.{NDList, NDManager}
+import ai.djl.ndarray.types.{DataType, Shape}
+import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}
+import ai.djl.util.Utils
+import org.apache.spark.ml.image.ImageSchema
+import org.apache.spark.sql.Row
+
+import java.util
+
+// Translator: a class used to do preprocessing and post processing
+class SparkImageClassificationTranslator extends Translator[Row, Classifications] {
+
+  private var classes: java.util.List[String] = new util.ArrayList[String]()
+  private val pipeline: Pipeline = new Pipeline()
+    .add(new Resize(224, 224))
+    .add(new ToTensor())
+
+  override def prepare(manager: NDManager, model: Model): Unit = {
+    classes = Utils.readLines(model.getArtifact("synset.txt").openStream())
+  }
+
+  override def processInput(ctx: TranslatorContext, row: Row): NDList = {
+    val height = ImageSchema.getHeight(row)
+    val width = ImageSchema.getWidth(row)
+    val channel = ImageSchema.getNChannels(row)
+    var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)
+    // BGR to RGB
+    image = image.flip(2)
+    pipeline.transform(new NDList(image))
+  }
+
+  // Deal with the output. NDList contains the output result, usually one or more NDArray(s).
+  override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
+    var probabilitiesNd = list.singletonOrThrow
+    probabilitiesNd = probabilitiesNd.softmax(0)
+    new Classifications(classes, probabilitiesNd)
+  }
+
+  override def getBatchifier: Batchifier = Batchifier.STACK
+}
diff --git a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkJavaModel.java b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkJavaModel.java
new file mode 100644
index 00000000..b8e823fd
--- /dev/null
+++ b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkJavaModel.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package com.examples;
+
+import ai.djl.MalformedModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.Classifications;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.training.util.ProgressBar;
+import org.apache.spark.sql.Row;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+public class SparkJavaModel implements Serializable {
+
+    private static final long serialVersionUID = 123456789L;
+
+    private String url;
+    private ZooModel model;
+
+    public SparkJavaModel(String url) {
+        this.url = url;
+    }
+
+    public String getUrl() {
+        return url;
+    }
+
+    public void setUrl(String url) {
+        this.url = url;
+    }
+
+    public static SparkJavaModel load(String url) {
+        return new SparkJavaModel(url);
+    }
+
+    public Predictor newPredictor()
+            throws ModelNotFoundException, MalformedModelException, IOException {
+        if (model == null) {
+            Criteria criteria = Criteria.builder()
+                    .setTypes(Row.class, Classifications.class)
+                    .optModelUrls(url)
+                    .optTranslator(new SparkImageClassificationTranslator())
+                    .optProgress(new ProgressBar())
+                    .build();
+            model = criteria.loadModel();
+        }
+        return model.newPredictor();
+    }
+}
diff --git a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkModel.scala b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkModel.scala
new file mode 100644
index 00000000..e9686a21
--- /dev/null
+++ b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkModel.scala
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package com.examples
+
+import ai.djl.inference.Predictor
+import ai.djl.modality.Classifications
+import ai.djl.repository.zoo.{Criteria, ModelZoo}
+import ai.djl.training.util.ProgressBar
+import org.apache.spark.sql.Row
+
+@SerialVersionUID(123456789L)
+class SparkModel(val url : String) extends Serializable {
+
+  private lazy val criteria = Criteria.builder
+    .setTypes(classOf[Row], classOf[Classifications])
+    .optModelUrls(url)
+    .optTranslator(new SparkImageClassificationTranslator())
+    .optProgress(new ProgressBar)
+    .build()
+  private lazy val model = ModelZoo.loadModel(criteria)
+
+  // def load(url: String) = new SparkModel(url)
+
+  def newPredictor(): Predictor[Row, Classifications] = {
+    model.newPredictor()
+  }
+}
diff --git a/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkPredictor.scala b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkPredictor.scala
new file mode 100644
index 00000000..8b7d5321
--- /dev/null
+++ b/apache-spark/spark3.0/image-classification/src/main/scala/com/examples/SparkPredictor.scala
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package com.examples
+
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders}
+
+class SparkPredictor(override val uid: String) extends Transformer {
+
+  def this() = this(Identifiable.randomUID("SparkPredictor"))
+
+  final val inputCol = new Param[String](this, "inputCol", "The input column")
+  final val outputCol = new Param[String](this, "outputCol", "The output column")
+  final val modelUrl = new Param[String](this, "modelUrl", "The model URL")
+
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  def setModelUrl(value: String): this.type = set(modelUrl, value)
+
+  def predict(dataset: Dataset[_]): DataFrame = {
+    transform(dataset)
+  }
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema)
+    val model = new SparkModel($(modelUrl))
+    val outputDf = dataset.select($(inputCol)).mapPartitions(partition => {
+      val predictor = model.newPredictor()
+      partition.map(row => {
+        // image data stored as HWC format
+        predictor.predict(row).toString
+      })
+    })(Encoders.STRING)
+    outputDf.select($(outputCol))
+  }
+
+  override def transformSchema(schema: StructType) = schema
+
+  override def copy(paramMap: ParamMap) = this
+}
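
Usage note (not part of the diff): the pieces above compose as SparkPredictor -> SparkModel -> SparkImageClassificationTranslator. The following is a minimal driver sketch; it reuses the image directory and model URL from ImageClassificationExample, assumes the DJL and Spark 3.0 dependencies of this module are on the classpath, and the object name SparkPredictorSketch is hypothetical.

package com.examples

import org.apache.spark.sql.SparkSession

// Illustrative sketch only; mirrors ImageClassificationExample.main.
object SparkPredictorSketch {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("Image Classification")
      .getOrCreate()

    // Spark's built-in image data source produces the struct that ImageSchema reads.
    val df = spark.read.format("image")
      .option("dropInvalid", true)
      .load("../../image-classification/images")

    // SparkPredictor ships a SparkModel to the executors; each partition lazily
    // loads the model once and reuses a single Predictor for all of its rows.
    val predictor = new SparkPredictor()
      .setInputCol("image.*")
      .setOutputCol("value")
      .setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")

    predictor.predict(df).show(truncate = false)

    spark.stop()
  }
}

Because SparkModel keeps its Criteria and ZooModel in lazy vals that are only forced inside newPredictor(), only the model URL travels with the serialized closure; the heavyweight ZooModel is instantiated on the executors, never on the driver.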