Skip to content

Commit

Permalink
SPARKNLP-867 Refactors getActivation to take multilabel into account (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
danilojsl authored Jul 18, 2023
1 parent c07a4a7 commit 31e5bde
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.Bert
maxSentenceLength=128,
caseSensitive=True,
coalesceSentences=False,
activation="softmax"
activation="softmax",
multilabel=False
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ trait HasClassifierActivationProperties extends ParamsAndFeaturesWritable {
"multilabel",
"Whether or not the result should be multi-class (the sum of all probabilities is 1.0) or multi-label (each label has a probability between 0.0 to 1.0). Default is False i.e. multi-class")

setDefault(activation -> ActivationFunction.softmax, threshold -> 0.5f, multilabel -> false)

/** @group getParam */
def getActivation: String = $(activation)
def getActivation: String = {
val activation =
if ($(multilabel)) ActivationFunction.sigmoid else ActivationFunction.softmax

if ($(multilabel)) activation else $(this.activation)
}

/** @group setParam */
def setActivation(value: String): this.type = {
Expand Down Expand Up @@ -94,6 +97,8 @@ trait HasClassifierActivationProperties extends ParamsAndFeaturesWritable {
set(this.multilabel, value)
}

setDefault(activation -> ActivationFunction.softmax, threshold -> 0.5f, multilabel -> false)

}

object ActivationFunction {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class BertForZeroShotClassification(override val uid: String)
$(caseSensitive),
$(coalesceSentences),
$$(labels),
$(activation))
getActivation)

} else {
Seq.empty[Annotation]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class DistilBertForZeroShotClassification(override val uid: String)
$(caseSensitive),
$(coalesceSentences),
$$(labels),
$(activation))
getActivation)

} else {
Seq.empty[Annotation]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class RoBertaForZeroShotClassification(override val uid: String)
$(caseSensitive),
$(coalesceSentences),
$$(labels),
$(activation))
getActivation)

} else {
Seq.empty[Annotation]
Expand Down

0 comments on commit 31e5bde

Please sign in to comment.