Commit: Add Spark extension example

xyang16 committed Nov 24, 2022
1 parent d92d33e · commit 9beeac7
Showing 7 changed files with 161 additions and 132 deletions.
170 changes: 118 additions & 52 deletions apache-spark/notebook/Image_Classification_Spark.ipynb
@@ -47,18 +47,25 @@
"outputs": [],
"source": [
"import java.util\n",
"import ai.djl.inference.Predictor\n",
"import ai.djl.Model\n",
"import ai.djl.modality.Classifications\n",
"import ai.djl.repository.zoo.{Criteria, ZooModel}\n",
"import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
"import ai.djl.ndarray.types.{DataType, Shape}\n",
"import ai.djl.ndarray.{NDList, NDManager}\n",
"import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}\n",
"import ai.djl.training.util.ProgressBar\n",
"import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}\n",
"import ai.djl.util.Utils\n",
"import java.io.Serializable\n",
"import org.apache.spark.ml.image.ImageSchema\n",
"import org.apache.spark.ml.param.{Param, ParamMap}\n",
"import org.apache.spark.ml.Transformer\n",
"import org.apache.spark.ml.util.Identifiable\n",
"import org.apache.spark.sql.functions.col\n",
"import org.apache.spark.sql.{Encoders, Row, NotebookSparkSession}\n",
"import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row, NotebookSparkSession}\n",
"import org.apache.spark.sql.types.StructType\n",
"import org.apache.log4j.{Level, Logger}\n",
"Logger.getLogger(\"org\").setLevel(Level.OFF) // avoid too much message popping out\n",
"Logger.getLogger(\"ai\").setLevel(Level.OFF) // avoid too much message popping out"
@@ -82,38 +89,38 @@
"metadata": {},
"outputs": [],
"source": [
" // Translator: a class used to do preprocessing and post processing\n",
" class MyTranslator extends Translator[Row, Classifications] {\n",
"\n",
" private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
" private val pipeline: Pipeline = new Pipeline()\n",
" .add(new Resize(224, 224))\n",
" .add(new ToTensor())\n",
"\n",
" override def prepare(manager: NDManager, model: Model): Unit = {\n",
" classes = Utils.readLines(model.getArtifact(\"synset.txt\").openStream())\n",
" }\n",
"\n",
" override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
"\n",
" val height = ImageSchema.getHeight(row)\n",
" val width = ImageSchema.getWidth(row)\n",
" val channel = ImageSchema.getNChannels(row)\n",
" var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)\n",
" // BGR to RGB\n",
" image = image.flip(2)\n",
" pipeline.transform(new NDList(image))\n",
" }\n",
"// Translator: a class used to do preprocessing and post processing\n",
"@SerialVersionUID(234567891L)\n",
"class SparkImageClassificationTranslator extends Translator[Row, Classifications] with Serializable {\n",
"\n",
" // Deal with the output.,NDList contains output result, usually one or more NDArray(s).\n",
" override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {\n",
" var probabilitiesNd = list.singletonOrThrow\n",
" probabilitiesNd = probabilitiesNd.softmax(0)\n",
" new Classifications(classes, probabilitiesNd)\n",
" }\n",
" private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
" private lazy val pipeline: Pipeline = new Pipeline()\n",
" .add(new Resize(224, 224))\n",
" .add(new ToTensor())\n",
"\n",
" override def prepare(manager: NDManager, model: Model): Unit = {\n",
" classes = Utils.readLines(model.getArtifact(\"synset.txt\").openStream())\n",
" }\n",
"\n",
" override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
" val height = ImageSchema.getHeight(row)\n",
" val width = ImageSchema.getWidth(row)\n",
" val channel = ImageSchema.getNChannels(row)\n",
" var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)\n",
" // BGR to RGB\n",
" image = image.flip(2)\n",
" pipeline.transform(new NDList(image))\n",
" }\n",
"\n",
" // Deal with the output.,NDList contains output result, usually one or more NDArray(s).\n",
" override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {\n",
" var probabilitiesNd = list.singletonOrThrow\n",
" probabilitiesNd = probabilitiesNd.softmax(0)\n",
" new Classifications(classes, probabilitiesNd)\n",
" }\n",
"\n",
" override def getBatchifier: Batchifier = Batchifier.STACK\n",
" }"
" override def getBatchifier: Batchifier = Batchifier.STACK\n",
"}"
]
},
{
@@ -129,30 +136,90 @@
"If you are using MXNet as the backend engine, plase uncomment the mxnet model url."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Spark application\n",
"\n",
"We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
]
},
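The creation of the session itself sits in a cell that is collapsed out of this diff. With the almond-spark integration it is typically a minimal builder call, sketched here (the `local[*]` master URL is an assumption):

    import org.apache.spark.sql.NotebookSparkSession

    // Builds a SparkSession wired into the Almond kernel; master URL assumed.
    val spark = NotebookSparkSession.builder()
      .master("local[*]")
      .getOrCreate()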
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
"// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
"lazy val criteria = Criteria.builder\n",
" .setTypes(classOf[Row], classOf[Classifications])\n",
" .optModelUrls(modelUrl)\n",
" .optTranslator(new MyTranslator())\n",
" .optProgress(new ProgressBar)\n",
" .build()\n",
"lazy val model = ModelZoo.loadModel(criteria)"
"@SerialVersionUID(123456789L)\n",
"class SparkModel(var url: String, val translator: Translator[Row, Classifications]) extends Serializable {\n",
"\n",
" private var model: ZooModel[Row, Classifications] = null\n",
"\n",
" def newPredictor(): Predictor[Row, Classifications] = {\n",
" if (model == null) {\n",
" val criteria = Criteria.builder\n",
" .setTypes(classOf[Row], classOf[Classifications])\n",
" .optModelUrls(url)\n",
" .optTranslator(translator)\n",
" .optProgress(new ProgressBar)\n",
" .build\n",
" model = ModelZoo.loadModel(criteria)\n",
" }\n",
" model.newPredictor\n",
" }\n",
"}\n",
"\n",
"@SerialVersionUID(123456789L)\n",
"object SparkModel {\n",
" def load(url: String, translator: Translator[Row, Classifications]) = new SparkModel(url, translator)\n",
"}"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Start Spark application\n",
"class SparkPredictor(override val uid: String) extends Transformer {\n",
"\n",
"We can create a `NotebookSparkSession` through the Almond Spark plugin. It will internally apply all necessary jars to each of the worker node."
" def this() = this(Identifiable.randomUID(\"SparkPredictor\"))\n",
"\n",
" final val inputCol = new Param[String](this, \"inputCol\", \"The input column\")\n",
" final val outputCol = new Param[String](this, \"outputCol\", \"The output column\")\n",
" final val modelUrl = new Param[String](this, \"modelUrl\", \"The model URL\")\n",
" final val translator = new Param[Translator[Row, Classifications]](this, \"translator\", \"The translator\")\n",
"\n",
" def setInputCol(value: String): this.type = set(inputCol, value)\n",
"\n",
" def setOutputCol(value: String): this.type = set(outputCol, value)\n",
"\n",
" def setModelUrl(value: String): this.type = set(modelUrl, value)\n",
"\n",
" def setTranslator(value: Translator[Row, Classifications]): this.type = set(translator, value)\n",
"\n",
" def predict(dataset: Dataset[_]): DataFrame = {\n",
" transform(dataset)\n",
" }\n",
"\n",
" override def transform(dataset: Dataset[_]): DataFrame = {\n",
" val outputSchema = transformSchema(dataset.schema)\n",
" val model = SparkModel.load($(modelUrl), $(translator))\n",
" val outputDf = dataset.select($(inputCol)).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
" })(Encoders.STRING)\n",
" outputDf.select($(outputCol))\n",
" }\n",
"\n",
" override def transformSchema(schema: StructType) = schema\n",
"\n",
" override def copy(paramMap: ParamMap) = this\n",
"}"
]
},
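Because `SparkPredictor` extends Spark ML's `Transformer`, it can also be staged in a standard `Pipeline`. This composition is not part of the commit; a sketch, reusing the column names, translator, and model URL from the cell below:

    import org.apache.spark.ml.Pipeline

    val stage = new SparkPredictor()
      .setInputCol("image.*")
      .setOutputCol("value")
      .setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
      .setTranslator(new SparkImageClassificationTranslator())

    // fit() passes Transformer stages through unchanged; transform() runs inference
    val outputDf = new Pipeline().setStages(Array(stage)).fit(df).transform(df)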
{
@@ -199,14 +266,13 @@
"metadata": {},
"outputs": [],
"source": [
"val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
"})(Encoders.STRING)\n",
"println(result.collect().mkString(\"\\n\"))"
"val predictor = new SparkPredictor()\n",
" .setInputCol(\"image.*\")\n",
" .setOutputCol(\"value\")\n",
" .setModelUrl(\"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\")\n",
" .setTranslator(new SparkImageClassificationTranslator())\n",
"val outputDf = predictor.predict(df)\n",
"println(outputDf.collect().mkString(\"\\n\"))"
]
}
],
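The input DataFrame `df` is built in a cell collapsed out of this diff. With Spark's built-in image data source it would typically be read as sketched below (the path is hypothetical). That source yields a struct column named `image` whose `height`, `width`, `nChannels`, and BGR-ordered `data` fields are exactly what the translator consumes, which is why the predictor selects `image.*` and the translator flips the channel axis from BGR to RGB:

    // Hypothetical path; dropInvalid skips unreadable files.
    val df = spark.read.format("image")
      .option("dropInvalid", true)
      .load("path/to/images")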
@@ -222,7 +288,7 @@
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
"version": "2.12.11"
}
},
"nbformat": 4,
30 changes: 22 additions & 8 deletions apache-spark/spark3.0/image-classification/build.gradle
@@ -1,32 +1,46 @@
plugins {
id 'scala'
id 'application'
id 'com.github.johnrengelman.shadow' version '7.0.0'
}

group "com.example"
version "1.0-SNAPSHOT"

repositories {
mavenCentral()
mavenLocal()
}

dependencies {
implementation platform("ai.djl:bom:0.12.0")
implementation "org.apache.spark:spark-sql_2.12:3.0.1"
implementation "org.apache.spark:spark-mllib_2.12:3.0.1"
implementation "ai.djl:api"


runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto"
implementation "org.apache.spark:spark-core_2.12:${spark_version}"
implementation "org.apache.spark:spark-sql_2.12:${spark_version}"
implementation "org.apache.spark:spark-mllib_2.12:${spark_version}"
implementation "ai.djl:api:${djl_version}"
implementation "ai.djl.spark:spark:0.20.0-SNAPSHOT"

runtimeOnly "ai.djl.pytorch:pytorch-engine:${djl_version}"
runtimeOnly "ai.djl.pytorch:pytorch-native-cpu-precxx11:1.12.1"
}

compileScala {
scalaCompileOptions.setAdditionalParameters(["-target:jvm-1.8"])
}

application {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8

getMainClass().set(System.getProperty("main", "com.examples.ImageClassificationExample"))
}

shadowJar {
zip64 true
mergeServiceFiles()
exclude "META-INF/*.SF"
exclude 'META-INF/*.DSA'
exclude 'META-INF/*.RSA'
exclude "LICENSE*"
}

tasks.distTar.enabled = false
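The rewritten script reads `spark_version` and `djl_version` from Gradle properties that this diff does not show. A plausible `gradle.properties`, with both values inferred (not confirmed) from the Spark version in build.sbt and the DJL artifacts referenced above:

    # Assumed versions; the actual gradle.properties is not part of this diff.
    spark_version=3.0.1
    djl_version=0.20.0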
1 change: 1 addition & 0 deletions apache-spark/spark3.0/image-classification/build.sbt
@@ -7,6 +7,7 @@ scalacOptions += "-target:jvm-1.8"

resolvers += Resolver.jcenterRepo

libraryDependencies += "org.apache.spark" %% "spark-core" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.0.1"
libraryDependencies += "ai.djl" % "api" % "0.12.0"
1 change: 1 addition & 0 deletions apache-spark/spark3.0/image-classification/gradle
1 change: 1 addition & 0 deletions apache-spark/spark3.0/image-classification/gradlew.bat
@@ -0,0 +1 @@
call ..\..\..\gradlew.bat