Add Spark extension example #272
Conversation
(force-pushed from 7eb652f to 3509c8b)
private lazy val criteria = Criteria.builder
  .setTypes(classOf[Row], classOf[Classifications])
  .optModelUrls(url)
  .optTranslator(new SparkImageClassificationTranslator())
Comment: Is this translator serializable? If not, how do we make sure SparkImageClassificationTranslator is used on each executor?

Reply: Yes, it should be serializable.
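For illustration, a minimal sketch of a serializable translator; the pre/postprocessing bodies are placeholders, not the PR's actual implementation:

import ai.djl.modality.Classifications
import ai.djl.ndarray.NDList
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}
import org.apache.spark.sql.Row

// Sketch only: mixing in Serializable lets Spark ship the translator to
// each executor inside the task closure.
class SparkImageClassificationTranslator
  extends Translator[Row, Classifications] with Serializable {

  override def processInput(ctx: TranslatorContext, row: Row): NDList = {
    // Preprocessing: decode the Spark image row into tensors (placeholder).
    new NDList()
  }

  override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {
    // Postprocessing: map raw model output to Classifications (placeholder).
    null
  }

  override def getBatchifier: Batchifier = Batchifier.STACK
}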
@SerialVersionUID(123456789L)
class SparkModel(val url: String) extends Serializable {

  private lazy val criteria = Criteria.builder
Comment: Given we are passing this via SparkPredictor, this doesn't need to be lazy, right?

Reply: OK
  .optTranslator(new SparkImageClassificationTranslator())
  .optProgress(new ProgressBar)
  .build()
private lazy val model = ModelZoo.loadModel(criteria)
Comment: The same applies here.

Reply: Done
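After that feedback, the holder class might look roughly like this; a sketch assuming SparkPredictor constructs the SparkModel on the executor side, so the now-eager vals never need to survive serialization:

import ai.djl.modality.Classifications
import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}
import ai.djl.training.util.ProgressBar
import org.apache.spark.sql.Row

@SerialVersionUID(123456789L)
class SparkModel(val url: String) extends Serializable {

  private val criteria: Criteria[Row, Classifications] = Criteria.builder
    .setTypes(classOf[Row], classOf[Classifications])
    .optModelUrls(url)
    .optTranslator(new SparkImageClassificationTranslator())
    .optProgress(new ProgressBar)
    .build()

  // Loaded once per executor-side instance.
  private val model: ZooModel[Row, Classifications] = ModelZoo.loadModel(criteria)

  // A real implementation would reuse the predictor instead of creating
  // one per call.
  def predict(row: Row): Classifications = model.newPredictor().predict(row)
}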
final val outputCol = new Param[String](this, "outputCol", "The output column")
final val modelUrl = new Param[String](this, "modelUrl", "The model URL")

def setInputCol(value: String): this.type = set(inputCol, value)
Comment: Why use a setter? Can we use the builder pattern to create it?

Comment: Or, since this is Scala, we could use a case class.

Reply: This is a common pattern for classes that extend Transformer; see https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala#L550-L567
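For reference, the Spark ML pattern that link points to boils down to the following; an illustrative sketch with made-up class and column names, not the PR's code:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Params are declared as fields and configured through chained setters,
// which is what Spark's own transformers (e.g. StringIndexer) do.
class ExampleTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("example"))

  final val inputCol = new Param[String](this, "inputCol", "The input column")

  def setInputCol(value: String): this.type = set(inputCol, value)

  override def transform(dataset: Dataset[_]): DataFrame = {
    // Identity transform; a real implementation would map the input column.
    dataset.toDF()
  }

  override def transformSchema(schema: StructType): StructType = schema

  override def copy(extra: ParamMap): ExampleTransformer = defaultCopy(extra)
}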
println(result.collect().mkString("\n"))
df.select("image.origin", "image.width", "image.height").show(truncate = false)

val predictor = new SparkPredictor()
Comment: How does a customer pass in their own translator?

Reply: Changed
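One way this could look, sketched with a hypothetical setTranslator setter; neither that setter nor the chained setters for outputCol and modelUrl are confirmed by these excerpts:

// Hypothetical usage sketch; df is an image DataFrame and url a model URL,
// as in the surrounding example. setTranslator is an assumed API.
val predictor = new SparkPredictor()
  .setInputCol("image")
  .setOutputCol("prediction")
  .setModelUrl(url)
  .setTranslator(new SparkImageClassificationTranslator()) // or any user-supplied Translator[Row, Classifications]
val result = predictor.transform(df)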
(force-pushed from df82490 to 39ac624)
(force-pushed from 0d148fd to 9beeac7)
" image = image.flip(2)\n", | ||
" pipeline.transform(new NDList(image))\n", | ||
" }\n", | ||
"// Translator: a class used to do preprocessing and post processing\n", |
Comment: Should we use the Spark extension in the Jupyter notebook?

Reply: OK
runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto"
implementation "org.apache.spark:spark-core_2.12:${spark_version}"
Comment: The Spark extension should already cover the Spark dependencies.

Reply: OK
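A sketch of the slimmed-down dependency block, assuming the DJL Spark extension is published as ai.djl.spark:spark and, per the comment above, brings the Spark dependencies with it:

dependencies {
    // Assumed coordinates for the DJL Spark extension; per the review
    // comment it covers the Spark dependencies, so spark-core/sql/mllib
    // no longer need to be declared individually.
    implementation "ai.djl.spark:spark:${djl_version}"

    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
    runtimeOnly "ai.djl.pytorch:pytorch-native-auto"
}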
implementation "org.apache.spark:spark-core_2.12:${spark_version}" | ||
implementation "org.apache.spark:spark-sql_2.12:${spark_version}" | ||
implementation "org.apache.spark:spark-mllib_2.12:${spark_version}" | ||
implementation "ai.djl:api:${djl_version}" |
Comment: We should use the BOM.

Reply: OK
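With the DJL BOM imported as a Gradle platform, the individual DJL artifacts no longer carry explicit versions; a minimal sketch:

dependencies {
    // The BOM pins consistent versions for all DJL artifacts.
    implementation platform("ai.djl:bom:${djl_version}")

    implementation "ai.djl:api"                      // version from the BOM
    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"   // version from the BOM
    runtimeOnly "ai.djl.pytorch:pytorch-native-auto" // version from the BOM
}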
}

compileScala {
    scalaCompileOptions.setAdditionalParameters(["-target:jvm-1.8"])
}

application {
    sourceCompatibility = JavaVersion.VERSION_1_8
Comment: Why not use JDK 11?

Reply: Because the Java version on EMR is currently 8.
@@ -1,32 +1,46 @@
plugins {
    id 'scala'
    id 'application'
    id 'com.github.johnrengelman.shadow' version '7.0.0'
Suggested change:
- id 'com.github.johnrengelman.shadow' version '7.0.0'
+ id 'com.github.johnrengelman.shadow' version '7.1.2'

Reply: OK
@@ -7,6 +7,7 @@ scalacOptions += "-target:jvm-1.8"

resolvers += Resolver.jcenterRepo

libraryDependencies += "org.apache.spark" %% "spark-core" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.0.1"
libraryDependencies += "ai.djl" % "api" % "0.12.0"
Comment: Should we upgrade to the latest version?

Reply: OK
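The upgraded sbt dependencies, using the versions pinned elsewhere in this PR (Spark 3.2.2, DJL 0.20.0-SNAPSHOT):

// Snapshot builds need a snapshot repository; the URL is an assumption.
resolvers += "Sonatype Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots/"

libraryDependencies += "org.apache.spark" %% "spark-core" % "3.2.2"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.2.2"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.2.2"
libraryDependencies += "ai.djl" % "api" % "0.20.0-SNAPSHOT"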
gradle.properties:

@@ -14,3 +14,5 @@ systemProp.org.gradle.internal.publish.checksums.insecure=true
commons_cli_version=1.5.0
log4j_slf4j_version=2.18.0
rapis_version=22.04.0
spark_version=3.2.2
djl_version=0.19.0
Comment: This should be 0.20.0-SNAPSHOT.

Reply: OK
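The resulting lines in gradle.properties:

spark_version=3.2.2
djl_version=0.20.0-SNAPSHOT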
(force-pushed from cb1f42a to 3e0fd3f)
Spark extension POC: the user can call

  val outputDf = transformer.transform(df)

to run inference.
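Putting the pieces from this review together, end-to-end usage might look like this; a sketch in which the column names, the model URL, and the chained setters are illustrative rather than confirmed API:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("DJL Spark image classification")
  .getOrCreate()

// Spark's built-in image data source yields an "image" struct column.
val df = spark.read.format("image").load("path/to/images")

val transformer = new SparkPredictor()
  .setInputCol("image")
  .setOutputCol("prediction")
  .setModelUrl("<model url>") // placeholder

val outputDf = transformer.transform(df)
outputDf.select("prediction").show(truncate = false)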