From 96c7abbcc6c7efe01b330b91de7a2d1d1afb8169 Mon Sep 17 00:00:00 2001
From: Danilo Burbano
Date: Wed, 7 Aug 2024 08:15:00 -0500
Subject: [PATCH] [SPARKNLP-855] Introducing AlbertForZeroShotClassification

---
 .../annotator/classifier_dl/__init__.py       |   3 +-
 .../albert_for_zero_shot_classification.py    | 211 +++++++++
 python/sparknlp/internal/__init__.py          |   9 +
 ...lbert_for_zero_shot_classification_test.py |  60 +++
 .../ml/ai/AlbertClassification.scala          |  42 +-
 .../com/johnsnowlabs/nlp/annotator.scala      |   7 +
 .../dl/AlbertForZeroShotClassification.scala  | 402 ++++++++++++++++++
 ...ertForZeroShotClassificationTestSpec.scala |  66 +++
 8 files changed, 796 insertions(+), 4 deletions(-)
 create mode 100644 python/sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py
 create mode 100644 python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py
 create mode 100644 src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala
 create mode 100644 src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassificationTestSpec.scala

diff --git a/python/sparknlp/annotator/classifier_dl/__init__.py b/python/sparknlp/annotator/classifier_dl/__init__.py
index 2ecbe83018301c..ed3e7160892349 100644
--- a/python/sparknlp/annotator/classifier_dl/__init__.py
+++ b/python/sparknlp/annotator/classifier_dl/__init__.py
@@ -51,4 +51,5 @@
 from sparknlp.annotator.classifier_dl.deberta_for_zero_shot_classification import *
 from sparknlp.annotator.classifier_dl.mpnet_for_sequence_classification import *
 from sparknlp.annotator.classifier_dl.mpnet_for_question_answering import *
-from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
\ No newline at end of file
+from sparknlp.annotator.classifier_dl.mpnet_for_token_classification import *
+from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import *
\ No newline at end of file
diff --git a/python/sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py b/python/sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py
new file mode 100644
index 00000000000000..4ea47dab83166e
--- /dev/null
+++ b/python/sparknlp/annotator/classifier_dl/albert_for_zero_shot_classification.py
@@ -0,0 +1,211 @@
+# Copyright 2017-2024 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains classes for AlbertForZeroShotClassification."""
+
+from sparknlp.common import *
+
+
+class AlbertForZeroShotClassification(AnnotatorModel,
+                                      HasCaseSensitiveProperties,
+                                      HasBatchedAnnotate,
+                                      HasClassifierActivationProperties,
+                                      HasCandidateLabelsProperties,
+                                      HasEngine,
+                                      HasMaxSentenceLengthLimit):
+    """AlbertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural language
+    inference) tasks. Equivalent of `AlbertForSequenceClassification` models, but these models don't require a
+    hardcoded number of potential classes; they can be chosen at runtime. This usually means the model is slower,
+    but it is much more flexible.
+
+    Note that the model will loop through all provided labels. So the more labels you have, the
+    longer this process will take.
+
+    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
+    pair and passed to the pretrained model.
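+
+    For instance, with candidate labels ``["urgent", "technology"]``, a premise such as
+    ``"I have a problem with my iphone"`` is paired with hypotheses along the lines of
+    ``"This example is urgent."`` and ``"This example is technology."``, and the label
+    whose hypothesis receives the highest entailment score is returned.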
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> sequenceClassifier = AlbertForZeroShotClassification.pretrained() \\
+    ...     .setInputCols(["token", "document"]) \\
+    ...     .setOutputCol("label")
+
+    The default model is ``"albert_zero_shot_classifier_onnx"``, if no name is
+    provided.
+
+    For available pretrained models please see the `Models Hub
+    `__.
+
+    To see which models are compatible and how to import them see
+    `Import Transformers into Spark NLP 🚀
+    `_.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT, TOKEN``    ``CATEGORY``
+    ====================== ======================
+
+    Parameters
+    ----------
+    batchSize
+        Batch size. Large values allow faster processing but require more
+        memory, by default 8
+    caseSensitive
+        Whether to ignore case in tokens for embeddings matching, by default
+        True
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    maxSentenceLength
+        Max sentence length to process, by default 128
+    coalesceSentences
+        Instead of 1 class per sentence (if inputCols is `sentence`) output 1
+        class per document by averaging probabilities in all sentences, by
+        default False
+    activation
+        Whether to calculate logits via Softmax or Sigmoid, by default
+        `"softmax"`.
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("document")
+    >>> tokenizer = Tokenizer() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("token")
+    >>> sequenceClassifier = AlbertForZeroShotClassification.pretrained() \\
+    ...     .setInputCols(["token", "document"]) \\
+    ...     .setOutputCol("label") \\
+    ...     .setCaseSensitive(True) \\
+    ...     .setCandidateLabels(["urgent", "mobile", "technology"])
+    >>> pipeline = Pipeline().setStages([
+    ...     documentAssembler,
+    ...     tokenizer,
+    ...     sequenceClassifier
+    ... ])
+    >>> data = spark.createDataFrame([["I have a problem with my iphone that needs to be resolved asap!!"]]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.select("label.result").show(truncate=False)
+    +---------+
+    |result   |
+    +---------+
+    |[urgent] |
+    +---------+
+    """
+    name = "AlbertForZeroShotClassification"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.TOKEN]
+
+    outputAnnotatorType = AnnotatorType.CATEGORY
+
+    configProtoBytes = Param(Params._dummy(),
+                             "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    coalesceSentences = Param(Params._dummy(), "coalesceSentences",
+                              "Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences.",
+                              TypeConverters.toBoolean)
+
+    def getClasses(self):
+        """
+        Returns labels used to train this model
+        """
+        return self._call_java("getClasses")
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    def setCoalesceSentences(self, value):
+        """Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging
+        probabilities in all sentences. Due to the max sequence length limit in almost all transformer models such as
+        ALBERT (512 tokens), this parameter helps to feed all the sentences into the model and average all the
+        probabilities for the entire document instead of probabilities per sentence. (Default: False)
+
+        Parameters
+        ----------
+        value : bool
+            If the output of all sentences will be averaged to one output
+        """
+        return self._set(coalesceSentences=value)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification",
+                 java_model=None):
+        super(AlbertForZeroShotClassification, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            batchSize=8,
+            maxSentenceLength=128,
+            caseSensitive=True,
+            coalesceSentences=False,
+            activation="softmax"
+        )
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+
+        Returns
+        -------
+        AlbertForZeroShotClassification
+            The restored model
+        """
+        from sparknlp.internal import _AlbertForZeroShotClassificationLoader
+        jModel = _AlbertForZeroShotClassificationLoader(folder, spark_session._jsparkSession)._java_obj
+        return AlbertForZeroShotClassification(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="albert_zero_shot_classifier_onnx", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+
+        Parameters
+        ----------
+        name : str, optional
+            Name of the pretrained model, by default
+            "albert_zero_shot_classifier_onnx"
+        lang : str, optional
+            Language of the pretrained model, by default "en"
+        remote_loc : str, optional
+            Optional remote address of the resource, by default None. Will use
+            Spark NLP's repositories otherwise.
+
+        Returns
+        -------
+        AlbertForZeroShotClassification
+            The restored model
+        """
+        from sparknlp.pretrained import ResourceDownloader
+        return ResourceDownloader.downloadModel(AlbertForZeroShotClassification, name, lang, remote_loc)
\ No newline at end of file
diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py
index deeff9c5189f52..56bf0ef2397e9f 100644
--- a/python/sparknlp/internal/__init__.py
+++ b/python/sparknlp/internal/__init__.py
@@ -58,6 +58,15 @@ def __init__(self, path, jspark):
         )
 
 
+class _AlbertForZeroShotClassificationLoader(ExtendedJavaWrapper):
+    def __init__(self, path, jspark):
+        super(_AlbertForZeroShotClassificationLoader, self).__init__(
+            "com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification.loadSavedModel",
+            path,
+            jspark,
+        )
+
+
 class _BertLoader(ExtendedJavaWrapper):
     def __init__(self, path, jspark, use_openvino=False):
         super(_BertLoader, self).__init__(
diff --git a/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py b/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py
new file mode 100644
index 00000000000000..7afda4131c4470
--- /dev/null
+++ b/python/test/annotator/classifier_dl/albert_for_zero_shot_classification_test.py
@@ -0,0 +1,60 @@
+# Copyright 2017-2024 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+
+from pyspark.ml import Pipeline
+
+from sparknlp.annotator import *
+from sparknlp.base import *
+from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests
+from test.util import SparkContextForTest
+
+
+@pytest.mark.slow
+class AlbertForZeroShotClassificationTestSpec(unittest.TestCase, HasMaxSentenceLengthTests):
+    def setUp(self):
+        self.text = "I have a problem with my iphone that needs to be resolved asap!!"
+        self.data = SparkContextForTest.spark \
+            .createDataFrame([[self.text]]).toDF("text")
+        self.candidate_labels = ["urgent", "mobile", "technology"]
+
+        self.tested_annotator = AlbertForZeroShotClassification \
+            .pretrained() \
+            .setInputCols(["document", "token"]) \
+            .setOutputCol("multi_class") \
+            .setCandidateLabels(self.candidate_labels)
+
+    def test_run(self):
+        document_assembler = DocumentAssembler() \
+            .setInputCol("text") \
+            .setOutputCol("document")
+
+        tokenizer = Tokenizer().setInputCols("document").setOutputCol("token")
+
+        doc_classifier = self.tested_annotator
+
+        pipeline = Pipeline(stages=[
+            document_assembler,
+            tokenizer,
+            doc_classifier
+        ])
+
+        model = pipeline.fit(self.data)
+        model.transform(self.data).show()
+
+        light_pipeline = LightPipeline(model)
+        annotations_result = light_pipeline.fullAnnotate(self.text)
+        multi_class_result = annotations_result[0]["multi_class"][0].result
+        self.assertIn(multi_class_result, self.candidate_labels)
diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
index f1483553faac5d..23487fcbaebdc5 100644
--- a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
+++ b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
@@ -23,9 +23,10 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat
 import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
 import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
 import com.johnsnowlabs.nlp.annotators.common._
+import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.BasicTokenizer
 import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
-import org.tensorflow.ndarray.buffer.IntDataBuffer
 import org.slf4j.{Logger, LoggerFactory}
+import org.tensorflow.ndarray.buffer.IntDataBuffer
 
 import scala.collection.JavaConverters._
 
@@ -88,7 +89,19 @@ private[johnsnowlabs] class AlbertClassification(
   def tokenizeSeqString(
       candidateLabels: Seq[String],
       maxSeqLength: Int,
-      caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = ???
+      caseSensitive: Boolean): Seq[WordpieceTokenizedSentence] = {
+    val basicTokenizer = new BasicTokenizer(caseSensitive)
+    val encoder =
+      new SentencepieceEncoder(spp, caseSensitive, sentencePieceDelimiterId, pieceIdOffset = 1)
+
+    // Wrap each candidate label as a standalone sentence so it can be encoded with the
+    // same SentencePiece vocabulary as the input documents.
+    val labelsToSentences = candidateLabels.map { s => Sentence(s, 0, s.length - 1, 0) }
+
+    labelsToSentences.map(label => {
+      val tokens = basicTokenizer.tokenize(label)
+      val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)).take(maxSeqLength)
+      WordpieceTokenizedSentence(wordpieceTokens)
+    })
+  }
 
   def tokenizeDocument(
       docs: Seq[Annotation],
@@ -262,7 +275,30 @@ private[johnsnowlabs] class AlbertClassification(
   def tagZeroShotSequence(
       batch: Seq[Array[Int]],
       entailmentId: Int,
       contradictionId: Int,
-      activation: String): Array[Array[Float]] = ???
+      activation: String): Array[Array[Float]] = {
+
+    val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
+    val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength))
+    val batchLength = paddedBatch.length
+
+    val rawScores = detectedEngine match {
+      case TensorFlow.name => getRawScoresWithTF(paddedBatch, maxSentenceLength)
+      case ONNX.name => getRawScoresWithOnnx(paddedBatch, maxSentenceLength, sequence = true)
+    }
+
+    // Every sentence contributes the same number of logits; regroup the flat score
+    // array into one row per sentence.
+    val dim = rawScores.length / batchLength
+    rawScores
+      .grouped(dim)
+      .toArray
+  }
+
+  private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = {
+    if (arr.length >= maxLength) {
+      arr
+    } else {
+      arr ++ Array.fill(maxLength - arr.length)(0)
+    }
+  }
 
   def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
     val batchLength = batch.length
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala
index 60655bda2809ec..172b2b03812b4e 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala
@@ -432,6 +432,13 @@ package object annotator {
     extends ReadablePretrainedAlbertForTokenModel
     with ReadAlbertForTokenDLModel
 
+  type AlbertForZeroShotClassification =
+    com.johnsnowlabs.nlp.annotators.classifier.dl.AlbertForZeroShotClassification
+
+  object AlbertForZeroShotClassification
+      extends ReadablePretrainedAlbertForZeroShotModel
+      with ReadAlbertForZeroShotDLModel
+
   type XlnetForTokenClassification =
     com.johnsnowlabs.nlp.annotators.classifier.dl.XlnetForTokenClassification
 
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala
new file mode 100644
index 00000000000000..6cb650237a604d
--- /dev/null
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassification.scala
@@ -0,0 +1,402 @@
+/*
+ * Copyright 2017-2024 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.annotators.classifier.dl
+
+import com.johnsnowlabs.ml.ai.AlbertClassification
+import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
+import com.johnsnowlabs.ml.tensorflow.sentencepiece.{
+  ReadSentencePieceModel,
+  SentencePieceWrapper,
+  WriteSentencePieceModel
+}
+import com.johnsnowlabs.ml.tensorflow.{
+  ReadTensorflowModel,
+  TensorflowWrapper,
+  WriteTensorflowModel
+}
+import com.johnsnowlabs.ml.util.LoadExternalModel.{
+  loadSentencePieceAsset,
+  loadTextAsset,
+  modelSanityCheck,
+  notSupportedEngineError
+}
+import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
+import com.johnsnowlabs.nlp._
+import com.johnsnowlabs.nlp.annotators.common.{SentenceSplit, TokenizedWithSentence}
+import com.johnsnowlabs.nlp.serialization.MapFeature
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.SparkSession
+
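+/** AlbertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI
+  * (natural language inference) tasks. Unlike `AlbertForSequenceClassification`, the set of
+  * candidate labels is not hardcoded and can be chosen at runtime with `setCandidateLabels`;
+  * each label is posed as an NLI hypothesis against the input document.
+  *
+  * Pretrained models can be loaded with `pretrained` of the companion object. The default
+  * model is `"albert_zero_shot_classifier_onnx"`, if no name is provided.
+  *
+  * A minimal usage sketch, mirroring the accompanying test spec:
+  * {{{
+  * val zeroShotClassifier = AlbertForZeroShotClassification
+  *   .pretrained()
+  *   .setInputCols("document", "token")
+  *   .setOutputCol("multi_class")
+  *   .setCandidateLabels(Array("urgent", "mobile", "technology"))
+  * }}}
+  */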
+class AlbertForZeroShotClassification(override val uid: String)
+    extends AnnotatorModel[AlbertForZeroShotClassification]
+    with HasBatchedAnnotate[AlbertForZeroShotClassification]
+    with WriteTensorflowModel
+    with WriteOnnxModel
+    with WriteSentencePieceModel
+    with HasCaseSensitiveProperties
+    with HasClassifierActivationProperties
+    with HasEngine
+    with HasCandidateLabelsProperties {
+
+  /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
+    * type
+    */
+  def this() = this(Identifiable.randomUID("ALBERT_FOR_ZERO_SHOT_CLASSIFICATION"))
+
+  /** Input Annotator Types: DOCUMENT, TOKEN
+    *
+    * @group anno
+    */
+  override val inputAnnotatorTypes: Array[AnnotatorType] =
+    Array(AnnotatorType.DOCUMENT, AnnotatorType.TOKEN)
+
+  /** Output Annotator Types: CATEGORY
+    *
+    * @group anno
+    */
+  override val outputAnnotatorType: AnnotatorType = AnnotatorType.CATEGORY
+
+  /** Labels used to decode predicted IDs back to string tags
+    *
+    * @group param
+    */
+  val labels: MapFeature[String, Int] = new MapFeature(this, "labels").setProtected()
+
+  /** @group setParam */
+  def setLabels(value: Map[String, Int]): this.type = {
+    if (get(labels).isEmpty)
+      set(labels, value)
+    this
+  }
+
+  /** Returns labels used to train this model */
+  def getClasses: Array[String] = {
+    $$(labels).keys.toArray
+  }
+
+  /** Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document
+    * by averaging probabilities in all sentences (Default: `false`).
+    *
+    * Due to the max sequence length limit in almost all transformer models such as ALBERT (512
+    * tokens), this parameter helps to feed all the sentences into the model and average all the
+    * probabilities for the entire document instead of probabilities per sentence.
+    *
+    * @group param
+    */
+  val coalesceSentences = new BooleanParam(
+    this,
+    "coalesceSentences",
+    "If set to true, the output of all sentences will be averaged to one output instead of one output per sentence. Defaults to false.")
+
+  /** @group setParam */
+  def setCoalesceSentences(value: Boolean): this.type = set(coalesceSentences, value)
+
+  /** @group getParam */
+  def getCoalesceSentences: Boolean = $(coalesceSentences)
+
+  /** ConfigProto from tensorflow, serialized into byte array. Get with
+    * `config_proto.SerializeToString()`
+    *
+    * @group param
+    */
+  val configProtoBytes = new IntArrayParam(
+    this,
+    "configProtoBytes",
+    "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")
+
+  /** @group setParam */
+  def setConfigProtoBytes(bytes: Array[Int]): AlbertForZeroShotClassification.this.type =
+    set(this.configProtoBytes, bytes)
+
+  /** @group getParam */
+  def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))
+
+  /** Max sentence length to process (Default: `128`)
+    *
+    * @group param
+    */
+  val maxSentenceLength =
+    new IntParam(this, "maxSentenceLength", "Max sentence length to process")
+
+  /** @group setParam */
+  def setMaxSentenceLength(value: Int): this.type = {
+    require(
+      value <= 512,
+      "ALBERT models do not support sequences longer than 512 because of trainable positional embeddings.")
+    require(value >= 1, "The maxSentenceLength must be at least 1")
+    set(maxSentenceLength, value)
+    this
+  }
+
+  /** @group getParam */
+  def getMaxSentenceLength: Int = $(maxSentenceLength)
+
+  /** It contains TF model signatures for the loaded saved model
+    *
+    * @group param
+    */
+  val signatures =
+    new MapFeature[String, String](model = this, name = "signatures").setProtected()
+
+  /** @group setParam */
+  def setSignatures(value: Map[String, String]): this.type = {
+    set(signatures, value)
+    this
+  }
+
+  /** @group getParam */
+  def getSignatures: Option[Map[String, String]] = get(this.signatures)
+
+  private var _model: Option[Broadcast[AlbertClassification]] = None
+
+  /** @group setParam */
+  def setModelIfNotSet(
+      spark: SparkSession,
+      tensorflowWrapper: Option[TensorflowWrapper],
+      onnxWrapper: Option[OnnxWrapper],
+      spp: SentencePieceWrapper): AlbertForZeroShotClassification = {
+    if (_model.isEmpty) {
+      _model = Some(
+        spark.sparkContext.broadcast(
+          new AlbertClassification(
+            tensorflowWrapper,
+            onnxWrapper,
+            spp,
+            configProtoBytes = getConfigProtoBytes,
+            tags = $$(labels),
+            signatures = getSignatures)))
+    }
+
+    this
+  }
+
+  /** @group getParam */
+  def getModelIfNotSet: AlbertClassification = _model.get.value
+
+  /** Whether to lowercase tokens or not (Default: `true`).
+    *
+    * @group setParam
+    */
+  override def setCaseSensitive(value: Boolean): this.type = {
+    set(this.caseSensitive, value)
+  }
+
+  setDefault(
+    batchSize -> 8,
+    maxSentenceLength -> 128,
+    caseSensitive -> true,
+    coalesceSentences -> false)
+
+  /** Takes a document and annotations and produces new annotations of this annotator's annotation
+    * type
+    *
+    * @param batchedAnnotations
+    *   Annotations in batches that correspond to inputAnnotationCols generated by previous
+    *   annotators if any
+    * @return
+    *   any number of annotations processed for every batch of input annotations. Not necessarily
+    *   a one-to-one relationship
+    *
+    *   IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences
+    *   that belong to the same original row !! (challenging)
+    */
+  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
+    batchedAnnotations.map(annotations => {
+      val sentences = SentenceSplit.unpack(annotations).toArray
+      val tokenizedSentences = TokenizedWithSentence.unpack(annotations).toArray
+
+      if (tokenizedSentences.nonEmpty) {
+        getModelIfNotSet.predictSequenceWithZeroShot(
+          tokenizedSentences,
+          sentences,
+          $(candidateLabels),
+          $(entailmentIdParam),
+          $(contradictionIdParam),
+          $(batchSize),
+          $(maxSentenceLength),
+          $(caseSensitive),
+          $(coalesceSentences),
+          $$(labels),
+          getActivation)
+
+      } else {
+        Seq.empty[Annotation]
+      }
+    })
+  }
+
+  override def onWrite(path: String, spark: SparkSession): Unit = {
+    super.onWrite(path, spark)
+    val suffix = "_albert_classification"
+
+    getEngine match {
+      case TensorFlow.name =>
+        writeTensorflowModelV2(
+          path,
+          spark,
+          getModelIfNotSet.tensorflowWrapper.get,
+          suffix,
+          AlbertForSequenceClassification.tfFile)
+      case ONNX.name =>
+        writeOnnxModel(
+          path,
+          spark,
+          getModelIfNotSet.onnxWrapper.get,
+          suffix,
+          AlbertForSequenceClassification.onnxFile)
+    }
+
+    writeSentencePieceModel(
+      path,
+      spark,
+      getModelIfNotSet.spp,
+      "_albert",
+      AlbertForSequenceClassification.sppFile)
+  }
+
+}
+
+trait ReadablePretrainedAlbertForZeroShotModel
+    extends ParamsAndFeaturesReadable[AlbertForZeroShotClassification]
+    with HasPretrained[AlbertForZeroShotClassification] {
+  override val defaultModelName: Some[String] = Some("albert_zero_shot_classifier_onnx")
+  override val defaultLang: String = "en"
+
+  /** Java compliant-overrides */
+  override def pretrained(): AlbertForZeroShotClassification = super.pretrained()
+
+  override def pretrained(name: String): AlbertForZeroShotClassification =
+    super.pretrained(name)
+
+  override def pretrained(name: String, lang: String): AlbertForZeroShotClassification =
+    super.pretrained(name, lang)
+
+  override def pretrained(
+      name: String,
+      lang: String,
+      remoteLoc: String): AlbertForZeroShotClassification =
+    super.pretrained(name, lang, remoteLoc)
+}
+
+trait ReadAlbertForZeroShotDLModel
+    extends ReadTensorflowModel
+    with ReadOnnxModel
+    with ReadSentencePieceModel {
+  this: ParamsAndFeaturesReadable[AlbertForZeroShotClassification] =>
+
+  override val tfFile: String = "albert_classification_tensorflow"
+  override val onnxFile: String = "albert_classification_onnx"
+  override val sppFile: String = "albert_spp"
+
+  def readModel(
+      instance: AlbertForZeroShotClassification,
+      path: String,
+      spark: SparkSession): Unit = {
+
+    val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile)
+
+    instance.getEngine match {
+      case TensorFlow.name =>
+        val tfWrapper = readTensorflowModel(path, spark, "_albert_classification_tf")
+        instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp)
+      case ONNX.name =>
+        val onnxWrapper =
+          readOnnxModel(
+            path,
+            spark,
+            "albert_zero_classification_onnx",
+            zipped = true,
+            useBundle = false,
+            None)
+        instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp)
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+  }
+
+  addReader(readModel)
+
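+  /** Loads a locally saved model (typically one exported via the "Import Transformers into
+    * Spark NLP" workflow). The model folder is expected to provide a `spiece.model`
+    * SentencePiece asset and a `labels.txt` asset whose NLI labels contain exactly one
+    * entailment label and one contradiction label, as validated below.
+    */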
+  def loadSavedModel(modelPath: String, spark: SparkSession): AlbertForZeroShotClassification = {
+
+    val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)
+
+    val spModel = loadSentencePieceAsset(localModelPath, "spiece.model")
+    val labels = loadTextAsset(localModelPath, "labels.txt").zipWithIndex.toMap
+
+    val entailmentIds = labels.filter(x => x._1.toLowerCase().startsWith("entail")).values.toArray
+    val contradictionIds =
+      labels.filter(x => x._1.toLowerCase().startsWith("contradict")).values.toArray
+
+    require(
+      entailmentIds.length == 1 && contradictionIds.length == 1,
+      s"""This annotator supports classifiers trained on NLI datasets. Your dataset must have at least 2 and at most 3 labels:
+
+        example with 3 labels: 'contradict', 'neutral', 'entailment'
+        example with 2 labels: 'contradict', 'entailment'
+
+        You can modify the assets/labels.txt file to match the above format.
+
+        Current labels: ${labels.keys.mkString(", ")}
+        """)
+
+    val annotatorModel = new AlbertForZeroShotClassification()
+      .setLabels(labels)
+      .setCandidateLabels(labels.keys.toArray)
+
+    /* set the entailment id */
+    annotatorModel.set(annotatorModel.entailmentIdParam, entailmentIds.head)
+    /* set the contradiction id */
+    annotatorModel.set(annotatorModel.contradictionIdParam, contradictionIds.head)
+    /* set the engine */
+    annotatorModel.set(annotatorModel.engine, detectedEngine)
+
+    detectedEngine match {
+      case TensorFlow.name =>
+        val (wrapper, signatures) =
+          TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)
+
+        val _signatures = signatures match {
+          case Some(s) => s
+          case None => throw new Exception("Cannot load signature definitions from model!")
+        }
+
+        /** the order of setSignatures is important if we use getSignatures inside
+          * setModelIfNotSet
+          */
+        annotatorModel
+          .setSignatures(_signatures)
+          .setModelIfNotSet(spark, Some(wrapper), None, spModel)
+      case ONNX.name =>
+        val onnxWrapper =
+          OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
+        annotatorModel.setModelIfNotSet(spark, None, Some(onnxWrapper), spModel)
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+
+    annotatorModel
+  }
+}
+
+/** This is the companion object of [[AlbertForZeroShotClassification]]. Please refer to that
+  * class for the documentation.
+  */
+object AlbertForZeroShotClassification
+    extends ReadablePretrainedAlbertForZeroShotModel
+    with ReadAlbertForZeroShotDLModel
diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassificationTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassificationTestSpec.scala
new file mode 100644
index 00000000000000..06980c5e36eae2
--- /dev/null
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForZeroShotClassificationTestSpec.scala
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2017-2024 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.annotators.classifier.dl
+
+import com.johnsnowlabs.nlp.annotators.Tokenizer
+import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.tags.SlowTest
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.sql.functions.explode
+import org.scalatest.flatspec.AnyFlatSpec
+
+class AlbertForZeroShotClassificationTestSpec extends AnyFlatSpec {
+
+  "AlbertForZeroShotClassification" should "correctly load pretrained ONNX model" taggedAs SlowTest in {
+    import ResourceHelper.spark.implicits._
+
+    val dataDf =
+      Seq("I have a problem with my iphone that needs to be resolved asap!!").toDF("text")
+
+    val document = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val tokenizer = new Tokenizer()
+      .setInputCols(Array("document"))
+      .setOutputCol("token")
+
+    val zeroShotClassifier = AlbertForZeroShotClassification
+      .pretrained()
+      .setInputCols(Array("document", "token"))
+      .setOutputCol("multi_class")
+      .setCaseSensitive(true)
+      .setCoalesceSentences(true)
+      .setCandidateLabels(Array("urgent", "mobile", "technology"))
+
+    val pipeline = new Pipeline().setStages(Array(document, tokenizer, zeroShotClassifier))
+
+    val pipelineModel = pipeline.fit(dataDf)
+    val pipelineDF = pipelineModel.transform(dataDf)
+
+    pipelineDF.select("multi_class").show(false)
+    val totalDocs = pipelineDF.select(explode($"document.result")).count.toInt
+    val totalLabels = pipelineDF.select(explode($"multi_class.result")).count.toInt
+
+    println(s"total docs: $totalDocs")
+    println(s"total labels: $totalLabels")
+
+    assert(totalDocs == totalLabels)
+  }
+
+}