From 31e5bdee1e6a5284d1a359d1f41414acfac8feba Mon Sep 17 00:00:00 2001
From: Danilo Burbano <37355249+danilojsl@users.noreply.github.com>
Date: Tue, 18 Jul 2023 06:18:18 -0500
Subject: [PATCH] SPARKNLP-867 Refactors getActivation to take multilabel into account (#13888)

---
Reviewer notes (this region is ignored by `git am` and is not part of the
commit message):

* This patch was recovered from a copy whose newlines and indentation had
  been collapsed. Line breaks were restored so that every hunk matches the
  old/new line counts in its `@@` header. Leading whitespace and the exact
  position of the blank context line before `} else {` in the three
  *ForZeroShotClassification hunks are reconstructed, not original - verify
  against the target files, or apply with `git apply --ignore-whitespace`.
* Review of the new getActivation: the local `val activation` shadows the
  `activation` Param and its softmax branch is never used, because the local
  is only returned when `multilabel` is true. The method is equivalent to
  `if ($(multilabel)) ActivationFunction.sigmoid else $(activation)`.
  Consequence visible in the code: a value passed to setActivation is
  ignored whenever multilabel is true.

 .../bert_for_zero_shot_classification.py              |  3 ++-
 .../nlp/HasClassifierActivationProperties.scala       | 11 ++++++++---
 .../classifier/dl/BertForZeroShotClassification.scala |  2 +-
 .../dl/DistilBertForZeroShotClassification.scala      |  2 +-
 .../dl/RoBertaForZeroShotClassification.scala         |  2 +-
 5 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/python/sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py b/python/sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py
index 742dcb6cc829f4..24787abc59d7ce 100755
--- a/python/sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py
+++ b/python/sparknlp/annotator/classifier_dl/bert_for_zero_shot_classification.py
@@ -164,7 +164,8 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.Bert
             maxSentenceLength=128,
             caseSensitive=True,
             coalesceSentences=False,
-            activation="softmax"
+            activation="softmax",
+            multilabel=False
         )
 
     @staticmethod
diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasClassifierActivationProperties.scala b/src/main/scala/com/johnsnowlabs/nlp/HasClassifierActivationProperties.scala
index 9d92e7a02482e6..2d88385091bc9e 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/HasClassifierActivationProperties.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/HasClassifierActivationProperties.scala
@@ -52,10 +52,13 @@ trait HasClassifierActivationProperties extends ParamsAndFeaturesWritable {
     "multilabel",
     "Whether or not the result should be multi-class (the sum of all probabilities is 1.0) or multi-label (each label has a probability between 0.0 to 1.0). Default is False i.e. multi-class")
 
-  setDefault(activation -> ActivationFunction.softmax, threshold -> 0.5f, multilabel -> false)
-
   /** @group getParam */
-  def getActivation: String = $(activation)
+  def getActivation: String = {
+    val activation =
+      if ($(multilabel)) ActivationFunction.sigmoid else ActivationFunction.softmax
+
+    if ($(multilabel)) activation else $(this.activation)
+  }
 
   /** @group setParam */
   def setActivation(value: String): this.type = {
@@ -94,6 +97,8 @@ trait HasClassifierActivationProperties extends ParamsAndFeaturesWritable {
     set(this.multilabel, value)
   }
 
+  setDefault(activation -> ActivationFunction.softmax, threshold -> 0.5f, multilabel -> false)
+
 }
 
 object ActivationFunction {
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala
index 6c6ddc35140d1a..0cc57e366a1301 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala
@@ -325,7 +325,7 @@ class BertForZeroShotClassification(override val uid: String)
           $(caseSensitive),
           $(coalesceSentences),
           $$(labels),
-          $(activation))
+          getActivation)
 
       } else {
         Seq.empty[Annotation]
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala
index b1afba431726d2..0726cf35b5ca7f 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala
@@ -325,7 +325,7 @@ class DistilBertForZeroShotClassification(override val uid: String)
           $(caseSensitive),
           $(coalesceSentences),
           $$(labels),
-          $(activation))
+          getActivation)
 
       } else {
         Seq.empty[Annotation]
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala
index 60041627854e15..e66e8b59627804 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala
@@ -339,7 +339,7 @@ class RoBertaForZeroShotClassification(override val uid: String)
           $(caseSensitive),
           $(coalesceSentences),
           $$(labels),
-          $(activation))
+          getActivation)
 
       } else {
         Seq.empty[Annotation]