From c2dd80db4b88aa5a582dd36fd7297ae94431ad94 Mon Sep 17 00:00:00 2001 From: Maziyar Panahi Date: Sat, 1 Jul 2023 15:09:47 +0200 Subject: [PATCH] =?UTF-8?q?Integrating=20ONNX=20runtime=20(ORT)=20in=20Spa?= =?UTF-8?q?rk=20NLP=205.0.0=20=F0=9F=8E=89=20(#13857)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add ONNX Runtime to the dependencies * Add both CPU and GPU coordinates for onnxruntime * Implement OnnxSerializeModel * Implement OnnxWrapper * Update error message for loading external models * Add support for ONNX to BertEmbeddings annotator * Add support for ONNX to BERT backend * Add support for ONNX to DeBERTa * Implement ONNX in DeBERTa backend * Adapt Bert For sentence embeddings with the new backend * Update unit test for BERT (temp) * Update unit test for DeBERTa (temp) * Update onnxruntime and google cloud dependencies * Seems Apple Silicon and Aarch64 are supported in onnxruntime * Cleaning up * Remove bad merge * Update BERT unit test * Add fix me to the try * Making withSafeOnnxModelLoader thread safe * update onnxruntime * Revert back to normal unit tests for now [ski ptest] * Added ADT for ModelEngine (#13862) Co-authored-by: Stefano Lori * Optimize ONNX on CPU * refactor * Add ONNX support to DistilBERT * Add support for ONNX in RoBERTa * Fix the bad serialization on write * Fix using the wrong object --------- Co-authored-by: Stefano Lori Co-authored-by: Stefano Lori --- build.sbt | 11 + project/Dependencies.scala | 5 +- .../scala/com/johnsnowlabs/ml/ai/Bert.scala | 280 ++++++++++++------ .../com/johnsnowlabs/ml/ai/DeBerta.scala | 143 ++++++--- .../com/johnsnowlabs/ml/ai/DistilBert.scala | 123 +++++--- .../com/johnsnowlabs/ml/ai/RoBerta.scala | 124 +++++--- .../ml/onnx/OnnxSerializeModel.scala | 98 ++++++ .../johnsnowlabs/ml/onnx/OnnxWrapper.scala | 162 ++++++++++ .../ml/util/LoadExternalModel.scala | 28 +- .../johnsnowlabs/ml/util/ModelEngine.scala | 31 +- .../com/johnsnowlabs/nlp/HasEngine.scala | 4 
+- .../nlp/annotators/audio/HubertForCTC.scala | 4 +- .../nlp/annotators/audio/Wav2Vec2ForCTC.scala | 4 +- .../dl/AlbertForQuestionAnswering.scala | 4 +- .../dl/AlbertForSequenceClassification.scala | 4 +- .../dl/AlbertForTokenClassification.scala | 4 +- .../dl/BertForQuestionAnswering.scala | 4 +- .../dl/BertForSequenceClassification.scala | 4 +- .../dl/BertForTokenClassification.scala | 4 +- .../dl/BertForZeroShotClassification.scala | 8 +- .../dl/CamemBertForQuestionAnswering.scala | 4 +- .../CamemBertForSequenceClassification.scala | 4 +- .../dl/CamemBertForTokenClassification.scala | 4 +- .../dl/DeBertaForQuestionAnswering.scala | 4 +- .../dl/DeBertaForSequenceClassification.scala | 4 +- .../dl/DeBertaForTokenClassification.scala | 4 +- .../dl/DistilBertForQuestionAnswering.scala | 4 +- .../DistilBertForSequenceClassification.scala | 4 +- .../dl/DistilBertForTokenClassification.scala | 4 +- .../DistilBertForZeroShotClassification.scala | 4 +- .../dl/LongformerForQuestionAnswering.scala | 4 +- .../LongformerForSequenceClassification.scala | 4 +- .../dl/LongformerForTokenClassification.scala | 4 +- .../dl/RoBertaForQuestionAnswering.scala | 4 +- .../dl/RoBertaForSequenceClassification.scala | 4 +- .../dl/RoBertaForTokenClassification.scala | 4 +- .../dl/RoBertaForZeroShotClassification.scala | 4 +- .../dl/TapasForQuestionAnswering.scala | 4 +- .../dl/XlmRoBertaForQuestionAnswering.scala | 4 +- .../XlmRoBertaForSequenceClassification.scala | 4 +- .../dl/XlmRoBertaForTokenClassification.scala | 4 +- .../dl/XlnetForSequenceClassification.scala | 4 +- .../dl/XlnetForTokenClassification.scala | 4 +- .../annotators/coref/SpanBertCorefModel.scala | 4 +- .../cv/ConvNextForImageClassification.scala | 4 +- .../cv/SwinForImageClassification.scala | 4 +- .../cv/ViTForImageClassification.scala | 6 +- .../annotators/ld/dl/LanguageDetectorDL.scala | 4 +- .../annotators/seq2seq/BartTransformer.scala | 4 +- .../annotators/seq2seq/GPT2Transformer.scala | 4 +- 
.../seq2seq/MarianTransformer.scala | 4 +- .../annotators/seq2seq/T5Transformer.scala | 4 +- .../nlp/embeddings/AlbertEmbeddings.scala | 4 +- .../nlp/embeddings/BertEmbeddings.scala | 91 ++++-- .../embeddings/BertSentenceEmbeddings.scala | 79 +++-- .../nlp/embeddings/CamemBertEmbeddings.scala | 4 +- .../nlp/embeddings/DeBertaEmbeddings.scala | 88 ++++-- .../nlp/embeddings/DistilBertEmbeddings.scala | 93 ++++-- .../nlp/embeddings/ElmoEmbeddings.scala | 4 +- .../nlp/embeddings/LongformerEmbeddings.scala | 15 +- .../nlp/embeddings/RoBertaEmbeddings.scala | 65 +++- .../RoBertaSentenceEmbeddings.scala | 16 +- .../embeddings/UniversalSentenceEncoder.scala | 4 +- .../nlp/embeddings/XlmRoBertaEmbeddings.scala | 4 +- .../XlmRoBertaSentenceEmbeddings.scala | 4 +- .../nlp/embeddings/XlnetEmbeddings.scala | 4 +- 66 files changed, 1199 insertions(+), 455 deletions(-) create mode 100644 src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala create mode 100644 src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala diff --git a/build.sbt b/build.sbt index 7821eecfcc413b..c9e37ecd4a699e 100644 --- a/build.sbt +++ b/build.sbt @@ -165,6 +165,16 @@ val tensorflowDependencies: Seq[sbt.ModuleID] = else Seq(tensorflowCPU) +val onnxDependencies: Seq[sbt.ModuleID] = + if (is_gpu.equals("true")) + Seq(onnxGPU) + else if (is_silicon.equals("true")) + Seq(onnxCPU) + else if (is_aarch64.equals("true")) + Seq(onnxCPU) + else + Seq(onnxCPU) + lazy val mavenProps = settingKey[Unit]("workaround for Maven properties") lazy val root = (project in file(".")) @@ -175,6 +185,7 @@ lazy val root = (project in file(".")) testDependencies ++ utilDependencies ++ tensorflowDependencies ++ + onnxDependencies ++ typedDependencyParserDependencies, // TODO potentially improve this? 
mavenProps := { diff --git a/project/Dependencies.scala b/project/Dependencies.scala index f36d7f528d3c54..99b725d3b51c76 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -108,7 +108,10 @@ object Dependencies { val tensorflowM1 = "com.johnsnowlabs.nlp" %% "tensorflow-m1" % tensorflowVersion val tensorflowLinuxAarch64 = "com.johnsnowlabs.nlp" %% "tensorflow-aarch64" % tensorflowVersion - val gcpStorageVersion = "2.16.0" + val onnxRuntimeVersion = "1.15.0" + val onnxCPU = "com.microsoft.onnxruntime" % "onnxruntime" % onnxRuntimeVersion + val onnxGPU = "com.microsoft.onnxruntime" % "onnxruntime_gpu" % onnxRuntimeVersion + val gcpStorageVersion = "2.20.1" val gcpStorage = "com.google.cloud" % "google-cloud-storage" % gcpStorageVersion /** ------- Dependencies end ------- */ diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala index ed704f8c3ef476..c291b9f23c549d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala @@ -16,10 +16,12 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings +import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.ModelArch +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} @@ -35,6 +37,8 @@ import scala.collection.JavaConverters._ * * @param tensorflowWrapper * Bert Model wrapper with TensorFlow Wrapper + * @param onnxWrapper + * Bert Model wrapper with ONNX Wrapper * @param sentenceStartTokenId * Id of sentence start Token * @param sentenceEndTokenId @@ -47,7 +51,8 @@ import scala.collection.JavaConverters._ * Source: 
[[https://github.com/google-research/bert]] */ private[johnsnowlabs] class Bert( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], sentenceStartTokenId: Int, sentenceEndTokenId: Int, configProtoBytes: Option[Array[Byte]] = None, @@ -57,6 +62,10 @@ private[johnsnowlabs] class Bert( extends Serializable { val _tfBertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name private def sessionWarmup(): Unit = { val dummyInput = @@ -74,51 +83,99 @@ private[johnsnowlabs] class Bert( sessionWarmup() def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { - val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length - val tensors = new TensorResources() - - val (tokenTensors, maskTensors, segmentTensors) = - PrepareEmbeddings.prepareBatchTensorsWithSegment( - tensors = tensors, - batch = batch, - maxSentenceLength = maxSentenceLength, - batchLength = batchLength) - - val runner = tensorflowWrapper - .getTFSessionWithSignature( - configProtoBytes = configProtoBytes, - savedSignatures = signatures, - initAllTables = false) - .runner - - runner - .feed( - _tfBertSignatures.getOrElse( - ModelSignatureConstants.InputIdsV1.key, - "missing_input_id_key"), - tokenTensors) - .feed( - _tfBertSignatures - .getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"), - maskTensors) - .feed( - _tfBertSignatures - .getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"), - segmentTensors) - .fetch(_tfBertSignatures - .getOrElse(ModelSignatureConstants.LastHiddenStateV1.key, "missing_sequence_output_key")) - - val outs = runner.run().asScala - val embeddings = TensorResources.extractFloats(outs.head) + val embeddings = detectedEngine match 
{ + + case ONNX.name => + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("last_hidden_state") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + // runner.close() + // env.close() + // + embeddings + } finally if (results != null) results.close() + } + case _ => + val tensors = new TensorResources() + + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareBatchTensorsWithSegment( + tensors, + batch, + maxSentenceLength, + batchLength) + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + savedSignatures = signatures, + initAllTables = false) + .runner + + runner + .feed( + _tfBertSignatures.getOrElse( + ModelSignatureConstants.InputIdsV1.key, + "missing_input_id_key"), + tokenTensors) + .feed( + _tfBertSignatures + .getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"), + maskTensors) + .feed( + _tfBertSignatures + .getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"), + segmentTensors) + .fetch( + _tfBertSignatures + .getOrElse( + ModelSignatureConstants.LastHiddenStateV1.key, + "missing_sequence_output_key")) + + val outs 
= runner.run().asScala + val embeddings = TensorResources.extractFloats(outs.head) + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + tensors.clearSession(outs) + tensors.clearTensors() + + embeddings - tokenTensors.close() - maskTensors.close() - segmentTensors.close() - tensors.clearSession(outs) - tensors.clearTensors() + } PrepareEmbeddings.prepareBatchWordEmbeddings( batch, @@ -133,48 +190,91 @@ private[johnsnowlabs] class Bert( val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length - val tensors = new TensorResources() - - val (tokenTensors, maskTensors, segmentTensors) = - PrepareEmbeddings.prepareBatchTensorsWithSegment( - tensors = tensors, - batch = batch, - maxSentenceLength = maxSentenceLength, - batchLength = batchLength) - - val runner = tensorflowWrapper - .getTFSessionWithSignature( - configProtoBytes = configProtoBytes, - savedSignatures = signatures, - initAllTables = false) - .runner - - runner - .feed( - _tfBertSignatures.getOrElse( - ModelSignatureConstants.InputIdsV1.key, - "missing_input_id_key"), - tokenTensors) - .feed( - _tfBertSignatures - .getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"), - maskTensors) - .feed( - _tfBertSignatures - .getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"), - segmentTensors) - .fetch(_tfBertSignatures - .getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key")) - - val outs = runner.run().asScala - val embeddings = TensorResources.extractFloats(outs.head) - - tokenTensors.close() - maskTensors.close() - segmentTensors.close() - tensors.clearSession(outs) - tensors.clearTensors() + val embeddings = detectedEngine match { + case ONNX.name => + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + 
OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("pooler_output") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + // runner.close() + // env.close() + // + embeddings + } finally if (results != null) results.close() + } + case _ => + val tensors = new TensorResources() + + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareBatchTensorsWithSegment( + tensors, + batch, + maxSentenceLength, + batchLength) + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + savedSignatures = signatures, + initAllTables = false) + .runner + + runner + .feed( + _tfBertSignatures.getOrElse( + ModelSignatureConstants.InputIdsV1.key, + "missing_input_id_key"), + tokenTensors) + .feed( + _tfBertSignatures + .getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"), + maskTensors) + .feed( + _tfBertSignatures + .getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"), + segmentTensors) + .fetch(_tfBertSignatures + .getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key")) + + val outs = runner.run().asScala + val embeddings = TensorResources.extractFloats(outs.head) + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + tensors.clearSession(outs) + tensors.clearTensors() + + embeddings + } val dim = embeddings.length / batchLength embeddings.grouped(dim).toArray @@ -200,17 +300,17 @@ private[johnsnowlabs] class Bert( 
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0L)) } - val runner = tensorflowWrapper + val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers) + val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers) + val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers) + + val runner = tensorflowWrapper.get .getTFSessionWithSignature( configProtoBytes = configProtoBytes, savedSignatures = signatures, initAllTables = false) .runner - val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers) - val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers) - val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers) - runner .feed( _tfBertSignatures.getOrElse( @@ -257,7 +357,6 @@ private[johnsnowlabs] class Bert( maxSentenceLength, sentenceStartTokenId, sentenceEndTokenId) - val vectors = tag(encoded) /*Combine tokens and calculated embeddings*/ @@ -324,7 +423,6 @@ private[johnsnowlabs] class Bert( maxSentenceLength, sentenceStartTokenId, sentenceEndTokenId) - val embeddings = if (isLong) { tagSequenceSBert(encoded) } else { diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala index 2f2638c5acd65e..bbf4ac83b1862b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala @@ -16,10 +16,13 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} +import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.tensorflow.sentencepiece._ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import 
com.johnsnowlabs.nlp.annotators.common._ import scala.collection.JavaConverters._ @@ -34,7 +37,8 @@ import scala.collection.JavaConverters._ * Configuration for TensorFlow session */ class DeBerta( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], val spp: SentencePieceWrapper, batchSize: Int, configProtoBytes: Option[Array[Byte]] = None, @@ -44,6 +48,11 @@ class DeBerta( val _tfDeBertaSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + // keys representing the input and output tensors of the DeBERTa model private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]") private val SentenceEndTokenId = spp.getSppModel.pieceToId("[SEP]") @@ -51,52 +60,96 @@ class DeBerta( private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁") def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { - + /* Actual size of each sentence to skip padding in the TF model */ val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length - val tensors = new TensorResources() - - val (tokenTensors, maskTensors, segmentTensors) = - PrepareEmbeddings.prepareBatchTensorsWithSegment( - tensors = tensors, - batch = batch, - maxSentenceLength = maxSentenceLength, - batchLength = batchLength, - sentencePadTokenId = SentencePadTokenId) - - val runner = tensorflowWrapper - .getTFSessionWithSignature( - configProtoBytes = configProtoBytes, - savedSignatures = signatures, - initAllTables = false) - .runner - - runner - .feed( - _tfDeBertaSignatures.getOrElse( - ModelSignatureConstants.InputIds.key, - "missing_input_id_key"), - tokenTensors) - .feed( - _tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), - maskTensors) - .feed( 
- _tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"), - segmentTensors) - .fetch(_tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_sequence_output_key")) - - val outs = runner.run().asScala - val embeddings = TensorResources.extractFloats(outs.head) - - tokenTensors.close() - maskTensors.close() - segmentTensors.close() - tensors.clearSession(outs) - tensors.clearTensors() + val embeddings = detectedEngine match { + + case ONNX.name => + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("last_hidden_state") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + // runner.close() + // env.close() + // + embeddings + } finally if (results != null) results.close() + } + case _ => + val tensors = new TensorResources() + + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareBatchTensorsWithSegment( + tensors, + batch, + maxSentenceLength, + batchLength) + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + savedSignatures = signatures, + initAllTables = false) + .runner + + runner + .feed( + _tfDeBertaSignatures.getOrElse( + ModelSignatureConstants.InputIds.key, + "missing_input_id_key"), + tokenTensors) + .feed( + _tfDeBertaSignatures + 
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), + maskTensors) + .feed( + _tfDeBertaSignatures + .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"), + segmentTensors) + .fetch( + _tfDeBertaSignatures + .getOrElse( + ModelSignatureConstants.LastHiddenState.key, + "missing_sequence_output_key")) + + val outs = runner.run().asScala + val embeddings = TensorResources.extractFloats(outs.head) + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + tensors.clearSession(outs) + tensors.clearTensors() + + embeddings + + } PrepareEmbeddings.prepareBatchWordEmbeddings( batch, diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala index 3e0d9a022a52cf..afa6a3b8bb29d5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala @@ -16,10 +16,12 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings +import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.ModelArch +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} @@ -66,7 +68,8 @@ import scala.collection.JavaConverters._ * Configuration for TensorFlow session */ private[johnsnowlabs] class DistilBert( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], sentenceStartTokenId: Int, sentenceEndTokenId: Int, configProtoBytes: Option[Array[Byte]] = None, @@ -75,6 +78,10 @@ private[johnsnowlabs] class DistilBert( extends Serializable { val _tfBertSignatures: Map[String, 
String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name private def sessionWarmup(): Unit = { val dummyInput = @@ -93,46 +100,88 @@ private[johnsnowlabs] class DistilBert( val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length - val tensors = new TensorResources() - - val (tokenTensors, maskTensors) = - PrepareEmbeddings.prepareBatchTensors( - tensors = tensors, - batch = batch, - maxSentenceLength = maxSentenceLength, - batchLength = batchLength) - - val runner = tensorflowWrapper - .getTFSessionWithSignature( - configProtoBytes = configProtoBytes, - savedSignatures = signatures, - initAllTables = false) - .runner - - runner - .feed( - _tfBertSignatures.getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"), - tokenTensors) - .feed( - _tfBertSignatures - .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), - maskTensors) - .fetch(_tfBertSignatures - .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_sequence_output_key")) - - val outs = runner.run().asScala - val embeddings = TensorResources.extractFloats(outs.head) - - tokenTensors.close() - maskTensors.close() - tensors.clearSession(outs) - tensors.clearTensors() + val embeddings = detectedEngine match { + case ONNX.name => + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. 
+ try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("last_hidden_state") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + + embeddings + } finally if (results != null) results.close() + } + case _ => + val tensors = new TensorResources() + + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareBatchTensorsWithSegment( + tensors, + batch, + maxSentenceLength, + batchLength) + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + savedSignatures = signatures, + initAllTables = false) + .runner + + runner + .feed( + _tfBertSignatures.getOrElse( + ModelSignatureConstants.InputIds.key, + "missing_input_id_key"), + tokenTensors) + .feed( + _tfBertSignatures + .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), + maskTensors) + .fetch( + _tfBertSignatures + .getOrElse( + ModelSignatureConstants.LastHiddenState.key, + "missing_sequence_output_key")) + + val outs = runner.run().asScala + val embeddings = TensorResources.extractFloats(outs.head) + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + tensors.clearSession(outs) + tensors.clearTensors() + + embeddings + } PrepareEmbeddings.prepareBatchWordEmbeddings( batch, embeddings, maxSentenceLength, batchLength) + } /** @param batch @@ -154,7 +203,7 @@ private[johnsnowlabs] class DistilBert( maxSentenceLength = maxSentenceLength, batchLength = batchLength) - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get .getTFSessionWithSignature( configProtoBytes = configProtoBytes, savedSignatures = signatures, diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala index b5d0f7c8c51560..1e903ff0d4a345 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala @@ -16,10 
+16,12 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings +import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.ModelArch +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} @@ -39,7 +41,8 @@ import scala.collection.JavaConverters._ * Model's inputs and output(s) signatures */ private[johnsnowlabs] class RoBerta( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], sentenceStartTokenId: Int, sentenceEndTokenId: Int, padTokenId: Int, @@ -50,6 +53,10 @@ private[johnsnowlabs] class RoBerta( val _tfRoBertaSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name private def sessionWarmup(): Unit = { val dummyInput = @@ -68,42 +75,81 @@ private[johnsnowlabs] class RoBerta( val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length - val tensors = new TensorResources() - - val (tokenTensors, maskTensors) = - PrepareEmbeddings.prepareBatchTensors( - tensors = tensors, - batch = batch, - maxSentenceLength = maxSentenceLength, - batchLength = batchLength, - sentencePadTokenId = padTokenId) - - val runner = tensorflowWrapper - .getTFSessionWithSignature( - configProtoBytes = configProtoBytes, - savedSignatures = signatures, - initAllTables = false) - .runner - - runner - .feed( - _tfRoBertaSignatures - .getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"), - tokenTensors) - .feed( - 
_tfRoBertaSignatures - .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), - maskTensors) - .fetch(_tfRoBertaSignatures - .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_sequence_output_key")) - - val outs = runner.run().asScala - val embeddings = TensorResources.extractFloats(outs.head) - - tokenTensors.close() - maskTensors.close() - tensors.clearSession(outs) - tensors.clearTensors() + val embeddings = detectedEngine match { + + case ONNX.name => + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. 
+ try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("last_hidden_state") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + embeddings + + } finally if (results != null) results.close() + } + case _ => + val tensors = new TensorResources() + + val (tokenTensors, maskTensors) = + PrepareEmbeddings.prepareBatchTensors( + tensors = tensors, + batch = batch, + maxSentenceLength = maxSentenceLength, + batchLength = batchLength, + sentencePadTokenId = padTokenId) + + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + savedSignatures = signatures, + initAllTables = false) + .runner + + runner + .feed( + _tfRoBertaSignatures + .getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"), + tokenTensors) + .feed( + _tfRoBertaSignatures + .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), + maskTensors) + .fetch( + _tfRoBertaSignatures + .getOrElse( + ModelSignatureConstants.LastHiddenState.key, + "missing_sequence_output_key")) + + val outs = runner.run().asScala + val embeddings = TensorResources.extractFloats(outs.head) + + tokenTensors.close() + maskTensors.close() + tensors.clearSession(outs) + tensors.clearTensors() + + embeddings + } PrepareEmbeddings.prepareBatchWordEmbeddings( batch, @@ -133,7 +179,7 @@ private[johnsnowlabs] class RoBerta( batchLength = batchLength, sentencePadTokenId = padTokenId) - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get .getTFSessionWithSignature( configProtoBytes = configProtoBytes, savedSignatures = signatures, diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala new file mode 100644 index 00000000000000..b6acafabbb8a3a --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -0,0 +1,98 @@ +/* + * Copyright 
2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.onnx + +import ai.onnxruntime.OrtSession.SessionOptions +import com.johnsnowlabs.util.FileHelper +import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.SparkSession + +import java.io.File +import java.nio.file.{Files, Paths} +import java.util.UUID + +trait WriteOnnxModel { + + def writeOnnxModel( + path: String, + spark: SparkSession, + onnxWrapper: OnnxWrapper, + suffix: String, + fileName: String): Unit = { + val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) + val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) + + // 1. Create tmp folder + val tmpFolder = Files + .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix) + .toAbsolutePath + .toString + + val onnxFile = Paths.get(tmpFolder, fileName).toString + + // 2. Save ONNX model + onnxWrapper.saveToFile(onnxFile) + + // 3. Copy to dest folder + fs.copyFromLocalFile(new Path(onnxFile), new Path(path)) + + // 4.
Remove tmp folder + FileUtils.deleteDirectory(new File(tmpFolder)) + } + +} + +trait ReadOnnxModel { + val onnxFile: String + + def readOnnxModel( + path: String, + spark: SparkSession, + suffix: String, + zipped: Boolean = true, + useBundle: Boolean = false, + sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + + val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) + val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) + + // 1. Create tmp directory + val tmpFolder = Files + .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix) + .toAbsolutePath + .toString + + // 2. Copy to local dir + fs.copyToLocalFile(new Path(path, onnxFile), new Path(tmpFolder)) + + val localPath = new Path(tmpFolder, onnxFile).toString + + // 3. Read ONNX state + val onnxWrapper = OnnxWrapper.read( + localPath, + zipped = zipped, + useBundle = useBundle, + sessionOptions = sessionOptions) + + // 4. Remove tmp folder + FileHelper.delete(tmpFolder) + + onnxWrapper + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala new file mode 100644 index 00000000000000..3e755615521afe --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -0,0 +1,162 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.johnsnowlabs.ml.onnx + +import ai.onnxruntime.OrtSession.SessionOptions +import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel} +import ai.onnxruntime.providers.OrtCUDAProviderOptions +import ai.onnxruntime.{OrtEnvironment, OrtSession} +import com.johnsnowlabs.util.{FileHelper, ZipArchiveUtil} +import org.apache.commons.io.FileUtils +import org.slf4j.{Logger, LoggerFactory} + +import java.io._ +import java.nio.file.{Files, Paths} +import java.util.UUID + +class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { + + /** For Deserialization */ + def this() = { + this(null) + } + + // Important for serialization on non-Kryo serializers + @transient private var m_session: OrtSession = _ + @transient private var m_env: OrtEnvironment = _ + @transient private val logger = LoggerFactory.getLogger("OnnxWrapper") + + def getSession(sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + this.synchronized { + if (m_session == null && m_env == null) { + val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, sessionOptions) + m_env = env + m_session = session + } + (m_session, m_env) + } + + def saveToFile(file: String): Unit = { + // 1. Create tmp directory + val tmpFolder = Files + .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx") + .toAbsolutePath + .toString + + // 2. Save onnx model + val onnxFile = Paths.get(tmpFolder, file).toString + FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel) + + // 3. Zip folder + ZipArchiveUtil.zip(tmpFolder, file) + + // 4.
Remove tmp directory + FileHelper.delete(tmpFolder) + } + +} + +/** Companion object */ +object OnnxWrapper { + private[OnnxWrapper] val logger: Logger = LoggerFactory.getLogger("OnnxWrapper") + + // TODO: make sure this.synchronized is needed or it's not a bottleneck + private def withSafeOnnxModelLoader( + onnxModel: Array[Byte], + sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + this.synchronized { + val env = OrtEnvironment.getEnvironment() + + val opts = + if (sessionOptions.isDefined) sessionOptions.get else new OrtSession.SessionOptions() + + val providers = OrtEnvironment.getAvailableProviders + + if (providers.toArray.map(x => x.toString).contains("CUDA")) { + logger.info("using CUDA") + // it seems there is no easy way to use multiple GPUs + // at least not without using multiple threads + // TODO: add support for multiple GPUs + // TODO: allow user to specify which GPU to use + val gpuDeviceId = 0 // The GPU device ID to execute on + val cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId) + // TODO: incorporate other cuda-related configs + // cudaOpts.add("gpu_mem_limit", "" + (512 * 1024 * 1024)) + // sessOptions.addCUDA(gpuDeviceId) + opts.addCUDA(cudaOpts) + } else { + logger.info("using CPUs") + // TODO: the following configs can be tested for performance + // However, so far, they seem to be slower than the ones used + // opts.setIntraOpNumThreads(Runtime.getRuntime.availableProcessors()) + // opts.setMemoryPatternOptimization(true) + // opts.setCPUArenaAllocator(false) + opts.setIntraOpNumThreads(6) + opts.setOptimizationLevel(OptLevel.ALL_OPT) + opts.setExecutionMode(ExecutionMode.SEQUENTIAL) + } + + val session = env.createSession(onnxModel, opts) + (session, env) + } + + def read( + modelPath: String, + zipped: Boolean = true, + useBundle: Boolean = false, + modelName: String = "model", + sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + + // 1. 
Create tmp folder + val tmpFolder = Files + .createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx") + .toAbsolutePath + .toString + + // 2. Unpack archive + val folder = + if (zipped) + ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder)) + else + modelPath + + // TODO: simplify this logic of useBundle + val (session, env, modelBytes) = + if (useBundle) { + val onnxFile = Paths.get(modelPath, s"$modelName.onnx").toString + val modelFile = new File(onnxFile) + val modelBytes = FileUtils.readFileToByteArray(modelFile) + val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + (session, env, modelBytes) + } else { + val modelFile = new File(folder).list().head + val fullPath = Paths.get(folder, modelFile).toFile + val modelBytes = FileUtils.readFileToByteArray(fullPath) + val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + (session, env, modelBytes) + } + + // 4. Remove tmp folder + FileHelper.delete(tmpFolder) + + val onnxWrapper = new OnnxWrapper(modelBytes) + onnxWrapper.m_session = session + onnxWrapper.m_env = env + onnxWrapper + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala index 8e1f737be3a512..58aff6825f0408 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala @@ -37,12 +37,26 @@ object LoadExternalModel { | ├── variables.data-00000-of-00001 | └── variables.index | + |A typical imported ONNX model has the following structure: + | + |├── assets/ + | ├── your-assets-are-here (vocab, sp model, labels, etc.) + |├── model.onnx + | + |A typical imported ONNX model for Seq2Seq has the following structure: + | + |├── assets/ + | ├── your-assets-are-here (vocab, sp model, labels, etc.) 
+ |├── encoder_model.onnx + |├── decoder_model.onnx + |├── decoder_with_past_model.onnx (not used in this release) + | |Please make sure you follow provided notebooks to import external models into Spark NLP: |https://github.com/JohnSnowLabs/spark-nlp/discussions/5669""".stripMargin } def isTensorFlowModel(modelPath: String): Boolean = { - val tfSavedModel = new File(modelPath, ModelEngine.tensorflowModelName) + val tfSavedModel = new File(modelPath, TensorFlow.modelName) tfSavedModel.exists() } @@ -50,11 +64,11 @@ object LoadExternalModel { def isOnnxModel(modelPath: String, isEncoderDecoder: Boolean = false): Boolean = { if (isEncoderDecoder) { - val onnxEncoderModel = new File(modelPath, ModelEngine.onnxEncoderModel) - val onnxDecoderModel = new File(modelPath, ModelEngine.onnxDecoderModel) + val onnxEncoderModel = new File(modelPath, ONNX.encoderModel) + val onnxDecoderModel = new File(modelPath, ONNX.decoderModel) onnxEncoderModel.exists() && onnxDecoderModel.exists() } else { - val onnxModel = new File(modelPath, ModelEngine.onnxModelName) + val onnxModel = new File(modelPath, ONNX.modelName) onnxModel.exists() } @@ -80,12 +94,12 @@ object LoadExternalModel { val onnxModelExist = isOnnxModel(modelPath, isEncoderDecoder) if (tfSavedModelExist) { - ModelEngine.tensorflow + TensorFlow.name } else if (onnxModelExist) { - ModelEngine.onnx + ONNX.name } else { require(tfSavedModelExist || onnxModelExist, notSupportedEngineError) - ModelEngine.unk + Unknown.name } } diff --git a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala index 9e8b93e9991219..061a42e7caa930 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala @@ -1,5 +1,5 @@ /* - * Copyright 2017-2022 John Snow Labs + * Copyright 2017-2023 John Snow Labs * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in 
compliance with the License. @@ -16,13 +16,24 @@ package com.johnsnowlabs.ml.util -object ModelEngine { - val tensorflow = "tensorflow" - val tensorflowModelName = "saved_model.pb" - val onnx = "onnx" - val onnxModelName = "model.onnx" - val onnxEncoderModel = "encoder_model.onnx" - val onnxDecoderModel = "decoder_model.onnx" - val onnxDecoderWithPastModel = "decoder_with_past_model.onnx" - val unk = "unk" +sealed trait ModelEngine + +final case object TensorFlow extends ModelEngine { + val name = "tensorflow" + val modelName = "saved_model.pb" +} +final case object PyTorch extends ModelEngine { + val name = "pytorch" +} + +final case object ONNX extends ModelEngine { + val name = "onnx" + val modelName = "model.onnx" + val encoderModel = "encoder_model.onnx" + val decoderModel = "decoder_model.onnx" + val decoderWithPastModel = "decoder_with_past_model.onnx" +} + +final case object Unknown extends ModelEngine { + val name = "unk" } diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala b/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala index 541d50b34afee5..39870b3073ce12 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala @@ -16,7 +16,7 @@ package com.johnsnowlabs.nlp -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import org.apache.spark.ml.param.Param trait HasEngine extends ParamsAndFeaturesWritable { @@ -27,7 +27,7 @@ trait HasEngine extends ParamsAndFeaturesWritable { */ val engine = new Param[String](this, "engine", "Deep Learning engine used for this model") - setDefault(engine, ModelEngine.tensorflow) + setDefault(engine, TensorFlow.name) /** @group getParam */ def getEngine: String = $(engine) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala index 989b3ee0634d30..520ffd0bda69ab 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor import org.apache.spark.ml.util.Identifiable @@ -213,7 +213,7 @@ trait ReadHubertForAudioDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala index 4e51a3812f1a25..5927de36c587e9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{AUDIO, DOCUMENT} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor @@ -340,7 +340,7 @@ trait ReadWav2Vec2ForAudioDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 1d2026e7f1bd3a..217fbc6ca25947 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -317,7 +317,7 @@ trait ReadAlbertForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index 8e110c8460ec5a..f0d61bcaade650 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -372,7 +372,7 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case 
TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index 4abbb18a6307f1..89e61223d63097 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala index e8d17348c1b968..d48b40dcb65c08 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import 
com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -325,7 +325,7 @@ trait ReadBertForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala index d873915c1e412e..ff0bb3aeb4676a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala @@ -24,7 +24,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadBertForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala index c9062cd3d99b83..0c287de7d2cd64 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -346,7 +346,7 @@ trait ReadBertForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala index b0149ab660d9c8..6c6ddc35140d1a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala @@ -19,23 +19,19 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.BertClassification import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature -import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable import 
org.apache.spark.sql.SparkSession -import java.io.File - /** BertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural * language inference) tasks. Equivalent of `BertForSequenceClassification` models, but these * models don't require a hardcoded number of potential classes, they can be chosen at runtime. @@ -421,7 +417,7 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala index 784003488a1a83..e55e6adf4b6cb5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadCamemBertForQADLModel extends ReadTensorflowModel with ReadSentencePie annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala index d96d8e59318e1e..9519af01f8a7ac 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -378,7 +378,7 @@ trait ReadCamemBertForSequenceDLModel extends ReadTensorflowModel with ReadSente annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala index 7b440341739223..275cd4bba61238 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -345,7 +345,7 @@ trait ReadCamemBertForTokenDLModel extends ReadTensorflowModel with ReadSentence annotatorModel.set(annotatorModel.engine, detectedEngine) 
detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index 06e9c955d1f0f6..600b85da999a6d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadDeBertaForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index dae903e43e14df..0f025ebca7c367 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import 
com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -374,7 +374,7 @@ trait ReadDeBertaForSequenceDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index b09d9d5298bca9..81b3fdff7def4b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadDeBertaForTokenDLModel extends ReadTensorflowModel with ReadSentencePi annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala index e950099b9e82fc..be3709d19b6279 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -328,7 +328,7 @@ trait ReadDistilBertForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala index 4c2699cf848e28..aee25f66d01640 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadDistilBertForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala index 53690d311104e2..20616a8303e7fc 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -351,7 +351,7 @@ trait ReadDistilBertForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala index 27f34509c867aa..b1afba431726d2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -423,7 +423,7 @@ trait ReadDistilBertForZeroShotDLModel 
extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala index f9cdbeaf323127..453b8ac7e2cb17 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -349,7 +349,7 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala index e6c91330eaf371..6dd293f033515c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } 
-import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -403,7 +403,7 @@ trait ReadLongformerForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala index 1957ccb20b00cb..176fea3d1e19f2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -371,7 +371,7 @@ trait ReadLongformerForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala index 27212384881b1b..35bd006fc4a3e5 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -347,7 +347,7 @@ trait ReadRoBertaForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala index f3c76c0b88f915..5e4b268af48f0d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -398,7 +398,7 @@ trait ReadRoBertaForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, 
useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala index 65c14a953c5b05..742306621bd376 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -369,7 +369,7 @@ trait ReadRoBertaForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala index ff24acd94a7894..60041627854e15 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -441,7 +441,7 @@ trait ReadRoBertaForZeroShotDLModel 
extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala index b9a1e253525054..22b3760cb8fa69 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.base.TableAssembler import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, HasPretrained, ParamsAndFeaturesReadable} import org.apache.spark.broadcast.Broadcast @@ -265,7 +265,7 @@ trait ReadTapasForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala index a42fef9c880aea..01920477d5a672 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, 
notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadXlmRoBertaForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala index eada6953d7bca1..add55d9270b8be 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -375,7 +375,7 @@ trait ReadXlmRoBertaForSequenceDLModel extends ReadTensorflowModel with ReadSent annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala index 38a379d9ae529a..ded252b097d481 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -345,7 +345,7 @@ trait ReadXlmRoBertaForTokenDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala index 593e9d51e37a6f..b9e786c4a869fb 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -371,7 +371,7 @@ trait ReadXlnetForSequenceDLModel extends ReadTensorflowModel with ReadSentenceP annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = 
false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala index 3f9e9f54df57e8..43b1e4dcd46103 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadXlnetForTokenDLModel extends ReadTensorflowModel with ReadSentencePiec annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala index 1b097a76813d03..bb48cbc9d7c575 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala @@ -26,7 +26,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -449,7 +449,7 @@ trait ReadSpanBertCorefTensorflowModel extends ReadTensorflowModel { 
annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala index b87de63bae8e58..3af9710b12ff9b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor import org.apache.spark.broadcast.Broadcast @@ -353,7 +353,7 @@ trait ReadConvNextForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala index 4341fa23cd4bd6..344200c0c2e501 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import 
com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor import org.apache.spark.ml.param.{BooleanParam, DoubleParam} @@ -334,7 +334,7 @@ trait ReadSwinForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala index e786739b6ac718..985fbc041251a5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{CATEGORY, IMAGE} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor @@ -39,8 +39,6 @@ import org.apache.spark.sql.SparkSession import org.json4s._ import org.json4s.jackson.JsonMethods._ -import java.io.File - /** Vision Transformer (ViT) for image classification. 
* * ViT is a transformer based alternative to the convolutional neural networks usually used for @@ -384,7 +382,7 @@ trait ReadViTForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala index 79a14a4b6098e5..5e05d06d6c2bec 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -341,7 +341,7 @@ trait ReadLanguageDetectorDLTensorflowModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala index 66d97181a86e38..aa72c0e4738fbe 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import 
com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -596,7 +596,7 @@ trait ReadBartTransformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala index 246c6e8a6f10ac..29d76fcd0dea17 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, Gpt2Tokenizer} @@ -544,7 +544,7 @@ trait ReadGPT2TransformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala index e15729d3f05a42..ce18cf3ad4f8bd 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import 
com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -458,7 +458,7 @@ trait ReadMarianMTDLModel extends ReadTensorflowModel with ReadSentencePieceMode annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala index ac70e675f3df97..edc691a236191c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala @@ -32,7 +32,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -537,7 +537,7 @@ trait ReadT5TransformerDLModel extends ReadTensorflowModel with ReadSentencePiec val spModel = loadSentencePieceAsset(localModelPath, "spiece.model") detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index 1f4c0e8a2923b3..c8da89256f2b4c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ 
modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -396,7 +396,7 @@ trait ReadAlbertDLModel extends ReadTensorflowModel with ReadSentencePieceModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala index fb47275a624d3e..3717c09bf3b9bc 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Bert +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -157,13 +158,30 @@ class BertEmbeddings(override val uid: String) extends AnnotatorModel[BertEmbeddings] with HasBatchedAnnotate[BertEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties with HasEngine { + /** Annotator reference id. 
Used to identify elements in metadata or to refer to this annotator + * type + */ def this() = this(Identifiable.randomUID("BERT_EMBEDDINGS")) + /** Input Annotator Types: DOCUMENT, TOKEN + * + * @group anno + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) + + /** Output Annotator Types: WORD_EMBEDDINGS + * + * @group anno + */ + override val outputAnnotatorType: AnnotatorType = AnnotatorType.WORD_EMBEDDINGS + /** @group setParam */ def sentenceStartTokenId: Int = { $$(vocabulary)("[CLS]") @@ -241,12 +259,14 @@ class BertEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper): BertEmbeddings = { + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): BertEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new Bert( tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, configProtoBytes = getConfigProtoBytes, @@ -257,7 +277,6 @@ class BertEmbeddings(override val uid: String) this } - /** @group getParam */ def getModelIfNotSet: Bert = _model.get.value /** Set Embeddings dimensions for the BERT model Only possible to set this when the first time @@ -354,22 +373,30 @@ class BertEmbeddings(override val uid: String) wrapEmbeddingsMetadata(dataset.col(getOutputCol), $(dimension), Some($(storageRef)))) } - /** Annotator reference id. 
Used to identify elements in metadata or to refer to this annotator - * type - */ - override val inputAnnotatorTypes: Array[String] = - Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) - override val outputAnnotatorType: AnnotatorType = AnnotatorType.WORD_EMBEDDINGS - override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_bert", - BertEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_bert" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + BertEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + BertEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } } } @@ -391,15 +418,27 @@ trait ReadablePretrainedBertModel super.pretrained(name, lang, remoteLoc) } -trait ReadBertDLModel extends ReadTensorflowModel { +trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[BertEmbeddings] => override val tfFile: String = "bert_tensorflow" + override val onnxFile: String = "bert_onnx" def readModel(instance: BertEmbeddings, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_bert_tf", initAllTables = false) - instance.setModelIfNotSet(spark, tf) + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_bert_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => { + val onnxWrapper = + readOnnxModel(path, spark, "_bert_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + } + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) 
@@ -417,8 +456,8 @@ trait ReadBertDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => - val (wrapper, signatures) = + case TensorFlow.name => + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -431,7 +470,12 @@ trait ReadBertDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) @@ -439,6 +483,7 @@ trait ReadBertDLModel extends ReadTensorflowModel { annotatorModel } + } /** This is the companion object of [[BertEmbeddings]]. Please refer to that class for the diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index 6e7af48bd113b0..c2e36695688a38 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Bert +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} 
@@ -122,7 +123,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} * }}} * * @see - * [[BertEmbeddings]] for token-level embeddings + * [[BertSentenceEmbeddings]] for sentence-level embeddings * @see * [[com.johnsnowlabs.nlp.annotators.classifier.dl.BertForSequenceClassification BertForSequenceClassification]] * for embeddings with a sequence classification layer on top @@ -152,6 +153,7 @@ class BertSentenceEmbeddings(override val uid: String) extends AnnotatorModel[BertSentenceEmbeddings] with HasBatchedAnnotate[BertSentenceEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties @@ -302,13 +304,17 @@ class BertSentenceEmbeddings(override val uid: String) def getModelIfNotSet: Bert = _model.get.value /** @group setParam */ - def setModelIfNotSet(spark: SparkSession, tensorflow: TensorflowWrapper): this.type = { + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): this.type = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new Bert( - tensorflow, + tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, configProtoBytes = getConfigProtoBytes, @@ -391,13 +397,28 @@ class BertSentenceEmbeddings(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_bert_sentence", - BertSentenceEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes) + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + "_bert_sentence", + BertSentenceEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + "_bert_sentence", + BertSentenceEmbeddings.onnxFile) + + case _ 
=> + throw new Exception(notSupportedEngineError) + } + } } @@ -419,15 +440,34 @@ trait ReadablePretrainedBertSentenceModel super.pretrained(name, lang, remoteLoc) } -trait ReadBertSentenceDLModel extends ReadTensorflowModel { +trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[BertSentenceEmbeddings] => override val tfFile: String = "bert_sentence_tensorflow" + override val onnxFile: String = "bert_sentence_onnx" def readModel(instance: BertSentenceEmbeddings, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_bert_sentence_tf", initAllTables = false) - instance.setModelIfNotSet(spark, tf) + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = + readTensorflowModel(path, spark, "_bert_sentence_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => { + val onnxWrapper = + readOnnxModel( + path, + spark, + "_bert_sentence_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + } + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -445,8 +485,8 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => - val (wrapper, signatures) = + case TensorFlow.name => + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -459,7 +499,12 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => 
throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala index 12c2b4d1edaef0..914d9b87b91449 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala @@ -12,7 +12,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -362,7 +362,7 @@ trait ReadCamemBertDLModel extends ReadTensorflowModel with ReadSentencePieceMod annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala index e502484f1f1521..56f57238e3a84e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.DeBerta +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -28,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import 
com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -160,6 +161,7 @@ class DeBertaEmbeddings(override val uid: String) extends AnnotatorModel[DeBertaEmbeddings] with HasBatchedAnnotate[DeBertaEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasEmbeddingsProperties with HasStorageRef @@ -247,7 +249,8 @@ class DeBertaEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): DeBertaEmbeddings = { if (_model.isEmpty) { @@ -255,6 +258,7 @@ class DeBertaEmbeddings(override val uid: String) spark.sparkContext.broadcast( new DeBerta( tensorflowWrapper, + onnxWrapper, spp, batchSize = $(batchSize), configProtoBytes = getConfigProtoBytes, @@ -308,30 +312,41 @@ class DeBertaEmbeddings(override val uid: String) }) } - override def onWrite(path: String, spark: SparkSession): Unit = { - super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_deberta", - DeBertaEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes) - writeSentencePieceModel( - path, - spark, - getModelIfNotSet.spp, - "_deberta", - DeBertaEmbeddings.sppFile) - - } - override protected def afterAnnotate(dataset: DataFrame): DataFrame = { dataset.withColumn( getOutputCol, wrapEmbeddingsMetadata(dataset.col(getOutputCol), $(dimension), Some($(storageRef)))) } + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_deberta" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + DeBertaEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + 
getModelIfNotSet.onnxWrapper.get, + suffix, + DeBertaEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } + + writeSentencePieceModel(path, spark, getModelIfNotSet.spp, suffix, DeBertaEmbeddings.sppFile) + + } + } trait ReadablePretrainedDeBertaModel @@ -351,16 +366,32 @@ trait ReadablePretrainedDeBertaModel super.pretrained(name, lang, remoteLoc) } -trait ReadDeBertaDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadDeBertaDLModel + extends ReadTensorflowModel + with ReadSentencePieceModel + with ReadOnnxModel { this: ParamsAndFeaturesReadable[DeBertaEmbeddings] => override val tfFile: String = "deberta_tensorflow" + override val onnxFile: String = "deberta_onnx" override val sppFile: String = "deberta_spp" def readModel(instance: DeBertaEmbeddings, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_deberta_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_deberta_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_deberta_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + + case ONNX.name => { + val onnxWrapper = + readOnnxModel(path, spark, "_deberta_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + } + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -377,8 +408,8 @@ trait ReadDeBertaDLModel extends ReadTensorflowModel with ReadSentencePieceModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => - val (wrapper, signatures) = + case TensorFlow.name => + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -391,7 +422,12 @@ trait 
ReadDeBertaDLModel extends ReadTensorflowModel with ReadSentencePieceModel */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala index 8bcfbd578a2343..d28ce903c48eb0 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.DistilBert +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -161,6 +162,7 @@ class DistilBertEmbeddings(override val uid: String) extends AnnotatorModel[DistilBertEmbeddings] with HasBatchedAnnotate[DistilBertEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties @@ -171,6 +173,19 @@ class DistilBertEmbeddings(override val uid: String) */ def this() = this(Identifiable.randomUID("DISTILBERT_EMBEDDINGS")) + /** Input Annotator Types: DOCUMENT. 
TOKEN + * + * @group param + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) + + /** Output Annotator Types: WORD_EMBEDDINGS + * + * @group param + */ + override val outputAnnotatorType: AnnotatorType = AnnotatorType.WORD_EMBEDDINGS + def sentenceStartTokenId: Int = { $$(vocabulary)("[CLS]") } @@ -246,12 +261,14 @@ class DistilBertEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper): DistilBertEmbeddings = { + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): DistilBertEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new DistilBert( tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, configProtoBytes = getConfigProtoBytes, @@ -357,28 +374,31 @@ class DistilBertEmbeddings(override val uid: String) wrapEmbeddingsMetadata(dataset.col(getOutputCol), $(dimension), Some($(storageRef)))) } - /** Input Annotator Types: DOCUMENT. 
TOKEN - * - * @group param - */ - override val inputAnnotatorTypes: Array[String] = - Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN) - - /** Output Annotator Types: WORD_EMBEDDINGS - * - * @group param - */ - override val outputAnnotatorType: AnnotatorType = AnnotatorType.WORD_EMBEDDINGS - override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_distilbert", - DistilBertEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_distilbert" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + DistilBertEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + DistilBertEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } + } } @@ -400,15 +420,27 @@ trait ReadablePretrainedDistilBertModel super.pretrained(name, lang, remoteLoc) } -trait ReadDistilBertDLModel extends ReadTensorflowModel { +trait ReadDistilBertDLModel extends ReadTensorflowModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[DistilBertEmbeddings] => override val tfFile: String = "distilbert_tensorflow" + override val onnxFile: String = "distilbert_onnx" def readModel(instance: DistilBertEmbeddings, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_distilbert_tf", initAllTables = false) - instance.setModelIfNotSet(spark, tf) + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_distilbert_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => { + val onnxWrapper = + readOnnxModel(path, spark, "_distilbert_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, 
None, Some(onnxWrapper)) + } + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -426,8 +458,8 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => - val (wrapper, signatures) = + case TensorFlow.name => + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -440,7 +472,12 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala index 647061442198c1..7f12ffc4c89a4d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala @@ -19,7 +19,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Elmo import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{modelSanityCheck, notSupportedEngineError} -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.storage.HasStorageRef @@ -363,7 +363,7 @@ trait ReadElmoDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( 
localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala index 7984096e2ca163..a42c1b334d9d0b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.RoBerta +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -247,12 +248,14 @@ class LongformerEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper): LongformerEmbeddings = { + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): LongformerEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new RoBerta( tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, padTokenId, @@ -381,7 +384,7 @@ class LongformerEmbeddings(override val uid: String) writeTensorflowModelV2( path, spark, - getModelIfNotSet.tensorflowWrapper, + getModelIfNotSet.tensorflowWrapper.get, "_longformer", LongformerEmbeddings.tfFile, configProtoBytes = getConfigProtoBytes) @@ -414,7 +417,7 @@ trait ReadLongformerDLModel extends ReadTensorflowModel { def readModel(instance: LongformerEmbeddings, path: String, spark: SparkSession): Unit = { val tf = readTensorflowModel(path, spark, "_longformer_tf", initAllTables = 
false) - instance.setModelIfNotSet(spark, tf) + instance.setModelIfNotSet(spark, Some(tf), None) } addReader(readModel) @@ -440,7 +443,7 @@ trait ReadLongformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -454,7 +457,7 @@ trait ReadLongformerDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(wrapper), None) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala index dae1369d440a2e..02c06bca1b4e77 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.RoBerta +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -162,6 +163,7 @@ class RoBertaEmbeddings(override val uid: String) extends AnnotatorModel[RoBertaEmbeddings] with HasBatchedAnnotate[RoBertaEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties @@ -260,12 +262,14 @@ class RoBertaEmbeddings(override val uid: 
String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper): RoBertaEmbeddings = { + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): RoBertaEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new RoBerta( tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, padTokenId, @@ -391,15 +395,29 @@ class RoBertaEmbeddings(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_roberta", - RoBertaEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes) - } + val suffix = "_roberta" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + RoBertaEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + RoBertaEmbeddings.onnxFile) + case _ => + throw new Exception(notSupportedEngineError) + } + } } trait ReadablePretrainedRobertaModel @@ -419,15 +437,27 @@ trait ReadablePretrainedRobertaModel super.pretrained(name, lang, remoteLoc) } -trait ReadRobertaDLModel extends ReadTensorflowModel { +trait ReadRobertaDLModel extends ReadTensorflowModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[RoBertaEmbeddings] => override val tfFile: String = "roberta_tensorflow" + override val onnxFile: String = "roberta_onnx" def readModel(instance: RoBertaEmbeddings, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_roberta_tf", initAllTables = false) - instance.setModelIfNotSet(spark, tf) + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_roberta_tf", initAllTables = false) + instance.setModelIfNotSet(spark, 
Some(tfWrapper), None) + + case ONNX.name => { + val onnxWrapper = + readOnnxModel(path, spark, "_roberta_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + } + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -453,7 +483,7 @@ trait ReadRobertaDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -467,7 +497,12 @@ trait ReadRobertaDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(wrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala index 9cff4ea74fbab5..c41c6c91ab2da9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala @@ -17,13 +17,14 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.RoBerta +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import 
com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -257,12 +258,15 @@ class RoBertaSentenceEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper): RoBertaSentenceEmbeddings = { + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): RoBertaSentenceEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( + new RoBerta( tensorflowWrapper, + onnxWrapper, sentenceStartTokenId, sentenceEndTokenId, padTokenId, @@ -368,7 +372,7 @@ class RoBertaSentenceEmbeddings(override val uid: String) writeTensorflowModelV2( path, spark, - getModelIfNotSet.tensorflowWrapper, + getModelIfNotSet.tensorflowWrapper.get, "_roberta", RoBertaSentenceEmbeddings.tfFile, configProtoBytes = getConfigProtoBytes) @@ -403,7 +407,7 @@ trait ReadRobertaSentenceDLModel extends ReadTensorflowModel { def readModel(instance: RoBertaSentenceEmbeddings, path: String, spark: SparkSession): Unit = { val tf = readTensorflowModel(path, spark, "_roberta_tf", initAllTables = false) - instance.setModelIfNotSet(spark, tf) + instance.setModelIfNotSet(spark, Some(tf), None) } addReader(readModel) @@ -429,7 +433,7 @@ trait ReadRobertaSentenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -443,7 +447,7 @@ trait ReadRobertaSentenceDLModel extends ReadTensorflowModel { */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper) + .setModelIfNotSet(spark, Some(wrapper), None) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala index 
2cc0712a615e1d..89bce984f4c388 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.tensorflow.{ WriteTensorflowModel } import com.johnsnowlabs.ml.util.LoadExternalModel.{modelSanityCheck, notSupportedEngineError} -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, SENTENCE_EMBEDDINGS} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common.SentenceSplit @@ -349,7 +349,7 @@ trait ReadUSEDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val wrapper = TensorflowWrapper.readWithSP( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala index 76cbc656235e4e..107da32535a946 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -394,7 +394,7 @@ trait ReadXlmRobertaDLModel extends ReadTensorflowModel with ReadSentencePieceMo annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala index f81836a86e1a75..07df2844768290 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadXlmRobertaSentenceDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala index a5bfac1d55159b..86d2c8b3e1cf97 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -393,7 +393,7 @@ trait ReadXlnetDLModel extends ReadTensorflowModel with ReadSentencePieceModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case 
ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)