Skip to content

Commit

Permalink
Integrating ONNX runtime (ORT) in Spark NLP 5.0.0 🎉 (#13857)
Browse files Browse the repository at this point in the history
* Add ONNX Runtime to the dependencies

* Add both CPU and GPU coordinates for onnxruntime

* Implement OnnxSerializeModel

* Implement OnnxWrapper

* Update error message for loading external models

* Add support for ONNX to BertEmbeddings annotator

* Add support for ONNX to BERT backend

* Add support for ONNX to DeBERTa

* Implement ONNX in DeBERTa backend

* Adapt Bert For sentence embeddings with the new backend

* Update unit test for BERT (temp)

* Update unit test for DeBERTa (temp)

* Update onnxruntime and google cloud dependencies

* Seems Apple Silicon and Aarch64 are supported in onnxruntime

* Cleaning up

* Remove bad merge

* Update BERT unit test

* Add fix me to the try

* Making withSafeOnnxModelLoader thread safe

* update onnxruntime

* Revert back to normal unit tests for now [skip test]

* Added ADT for ModelEngine (#13862)

Co-authored-by: Stefano Lori <s.lori@izicap.com>

* Optimize ONNX on CPU

* refactor

* Add ONNX support to DistilBERT

* Add support for ONNX in RoBERTa

* Fix the bad serialization on write

* Fix using the wrong object

---------

Co-authored-by: Stefano Lori <wolliq@users.noreply.github.com>
Co-authored-by: Stefano Lori <s.lori@izicap.com>
  • Loading branch information
3 people authored Jul 1, 2023
1 parent ae688ab commit c2dd80d
Show file tree
Hide file tree
Showing 66 changed files with 1,199 additions and 455 deletions.
11 changes: 11 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ val tensorflowDependencies: Seq[sbt.ModuleID] =
else
Seq(tensorflowCPU)

/** ONNX Runtime artifact for the current build target.
  *
  * Only the CUDA build ships a dedicated Maven artifact (`onnxruntime_gpu`);
  * Apple Silicon, Linux aarch64 and plain x86 CPU targets are all served by
  * the regular `onnxruntime` CPU artifact. The platform-flag checks are kept
  * in their original order so evaluation behavior is unchanged.
  */
val onnxDependencies: Seq[sbt.ModuleID] =
  if (is_gpu.equals("true")) Seq(onnxGPU)
  else if (is_silicon.equals("true") || is_aarch64.equals("true")) Seq(onnxCPU)
  else Seq(onnxCPU)

lazy val mavenProps = settingKey[Unit]("workaround for Maven properties")

lazy val root = (project in file("."))
Expand All @@ -175,6 +185,7 @@ lazy val root = (project in file("."))
testDependencies ++
utilDependencies ++
tensorflowDependencies ++
onnxDependencies ++
typedDependencyParserDependencies,
// TODO potentially improve this?
mavenProps := {
Expand Down
5 changes: 4 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ object Dependencies {
val tensorflowM1 = "com.johnsnowlabs.nlp" %% "tensorflow-m1" % tensorflowVersion
val tensorflowLinuxAarch64 = "com.johnsnowlabs.nlp" %% "tensorflow-aarch64" % tensorflowVersion

val gcpStorageVersion = "2.16.0"
val onnxRuntimeVersion = "1.15.0"
val onnxCPU = "com.microsoft.onnxruntime" % "onnxruntime" % onnxRuntimeVersion
val onnxGPU = "com.microsoft.onnxruntime" % "onnxruntime_gpu" % onnxRuntimeVersion
val gcpStorageVersion = "2.20.1"
val gcpStorage = "com.google.cloud" % "google-cloud-storage" % gcpStorageVersion

/** ------- Dependencies end ------- */
Expand Down
280 changes: 189 additions & 91 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.ModelArch
import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

Expand All @@ -35,6 +37,8 @@ import scala.collection.JavaConverters._
*
* @param tensorflowWrapper
* Bert Model wrapper with TensorFlow Wrapper
* @param onnxWrapper
* Bert Model wrapper with ONNX Wrapper
* @param sentenceStartTokenId
* Id of sentence start Token
* @param sentenceEndTokenId
Expand All @@ -47,7 +51,8 @@ import scala.collection.JavaConverters._
* Source: [[https://github.com/google-research/bert]]
*/
private[johnsnowlabs] class Bert(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
sentenceStartTokenId: Int,
sentenceEndTokenId: Int,
configProtoBytes: Option[Array[Byte]] = None,
Expand All @@ -57,6 +62,10 @@ private[johnsnowlabs] class Bert(
extends Serializable {

val _tfBertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

private def sessionWarmup(): Unit = {
val dummyInput =
Expand All @@ -74,51 +83,99 @@ private[johnsnowlabs] class Bert(
sessionWarmup()

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {

val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val batchLength = batch.length

val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors = tensors,
batch = batch,
maxSentenceLength = maxSentenceLength,
batchLength = batchLength)

val runner = tensorflowWrapper
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.LastHiddenStateV1.key, "missing_sequence_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)
val embeddings = detectedEngine match {

case ONNX.name =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("last_hidden_state")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()
// runner.close()
// env.close()
//
embeddings
} finally if (results != null) results.close()
}
case _ =>
val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors,
batch,
maxSentenceLength,
batchLength)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(
_tfBertSignatures
.getOrElse(
ModelSignatureConstants.LastHiddenStateV1.key,
"missing_sequence_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()

embeddings

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()
}

PrepareEmbeddings.prepareBatchWordEmbeddings(
batch,
Expand All @@ -133,48 +190,91 @@ private[johnsnowlabs] class Bert(
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val batchLength = batch.length

val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors = tensors,
batch = batch,
maxSentenceLength = maxSentenceLength,
batchLength = batchLength)

val runner = tensorflowWrapper
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()
val embeddings = detectedEngine match {
case ONNX.name =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("pooler_output")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()
// runner.close()
// env.close()
//
embeddings
} finally if (results != null) results.close()
}
case _ =>
val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors,
batch,
maxSentenceLength,
batchLength)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()

embeddings

}
val dim = embeddings.length / batchLength
embeddings.grouped(dim).toArray

Expand All @@ -200,17 +300,17 @@ private[johnsnowlabs] class Bert(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0L))
}

val runner = tensorflowWrapper
val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfBertSignatures.getOrElse(
Expand Down Expand Up @@ -257,7 +357,6 @@ private[johnsnowlabs] class Bert(
maxSentenceLength,
sentenceStartTokenId,
sentenceEndTokenId)

val vectors = tag(encoded)

/*Combine tokens and calculated embeddings*/
Expand Down Expand Up @@ -324,7 +423,6 @@ private[johnsnowlabs] class Bert(
maxSentenceLength,
sentenceStartTokenId,
sentenceEndTokenId)

val embeddings = if (isLong) {
tagSequenceSBert(encoded)
} else {
Expand Down
Loading

0 comments on commit c2dd80d

Please sign in to comment.