Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

812 implement de berta for zero shot classification annotator #14151

Conversation

ahmedlone127
Copy link
Contributor

adding DeBertaForZeroShotClassification ( note needs to be reviewed to see why predictions( Very different from huggingface predictions and largely incorrect ) are not being made correctly , same as BartForZeroShotClassifications

@maziyarpanahi maziyarpanahi self-assigned this Jan 27, 2024
@maziyarpanahi maziyarpanahi added new-feature Introducing a new feature new model DON'T MERGE Do not merge this PR labels Jan 27, 2024
@maziyarpanahi
Copy link
Member

@DevinTDHa Could you please have a look and see why the predictions are different from HuggingFace?

@DevinTDHa
Copy link
Member

DevinTDHa commented Jan 27, 2024

Hi @maziyarpanahi @ahmedlone127, I encountered this issue before and propose a solution:

Underlying Cause

I believe the issue could be part of the following. Note that this might be specific to implementations that I looked at, namely RoBertaForQA and MPNetForQA. It might differ for different models.

  1. Attention mask tensors are wrongly encoded. They should be an array of ones, but instead we are creating it based on the token ids (token 0 could be a valid ID and not a special token, depending on the model)
    batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
  2. Logits are not pre-processed before calculating probabilities with softmax. In the original implementations, special tokens and question tokens are not included in the start/end probabilities. This is done by setting the log-probability for these tokens to -10000 which is 0 probability.
    def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
    val batchLength = batch.length
    val (startLogits, endLogits) = detectedEngine match {
    case ONNX.name => computeLogitsWithOnnx(batch)
    case _ => computeLogitsWithTF(batch)
    }
    val endDim = endLogits.length / batchLength
    val endScores: Array[Array[Float]] =
    endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray
    val startDim = startLogits.length / batchLength
    val startScores: Array[Array[Float]] =
    startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
    (startScores, endScores)
    }
  3. Final Score is not correctly calculated. it should be the likelihood (startScore * endScore) instead of the mean.
    "score" -> ((startIndex._1 + endIndex._1) / 2).toString)))

Fix

I fixed these issues for RobertaForQA and the upcoming MPNetForQA by overriding the base functions in the specific implementation file (RobertaClassification.scala and MPNetForClassification.scala).

They are included in the respective PR: #14147

I chose not to fix the XXXForClassification.scala file, as I don't know if the other annotators are also affected that use it. It could be model specific.

So going forward, I'd suggest looking at my implementation and checking if the scores are better than before. The implementation is here:

override def predictSpan(
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String = MergeTokenStrategy.vocab,
engine: String = TensorFlow.name): Seq[Annotation] = {
val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)
val wordPieceTokenizedQuestion =
tokenizeDocument(questionAnnot, maxSentenceLength, caseSensitive)
val wordPieceTokenizedContext =
tokenizeDocument(contextAnnot, maxSentenceLength, caseSensitive)
val contextLength = wordPieceTokenizedContext.head.tokens.length
val questionLength = wordPieceTokenizedQuestion.head.tokens.length
val encodedInput =
encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength)
val (rawStartLogits, rawEndLogits) = tagSpan(encodedInput)
val (startScores, endScores) =
processLogits(rawStartLogits.head, rawEndLogits.head, questionLength, contextLength)
// Drop BOS token from valid results
val startIndex = startScores.zipWithIndex.drop(1).maxBy(_._1)
val endIndex = endScores.zipWithIndex.drop(1).maxBy(_._1)
val offsetStartIndex = 3 // 3 added special tokens
val offsetEndIndex = offsetStartIndex - 1
val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer =
allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex)
val content =
mergeTokenStrategy match {
case MergeTokenStrategy.vocab =>
decodedAnswer.filter(_.isWordStart).map(x => x.token).mkString(" ")
case MergeTokenStrategy.sentencePiece =>
val token = ""
decodedAnswer
.map(x =>
if (x.isWordStart) " " + token + x.token
else token + x.token)
.mkString("")
.trim
}
val totalScore = startIndex._1 * endIndex._1

Just as a note, the tolerance for the scores of my implementation are not that low (0.01 float tolerance). Close enough I suppose but I believe there might be some numerical issues for the implementation in general.

It can be tested like so:

implicit val tolerantEq = TolerantNumerics.tolerantFloatEquality(1e-2f)

@maziyarpanahi
Copy link
Member

  • Thanks @DevinTDHa appreciate sharing your experience
  • @ahmedlone127 could you please apply the changes and re-test?

@ahmedlone127
Copy link
Contributor Author

Thanks for the detailed explanation @DevinTDHa !

I will implement these steps and get back to you @maziyarpanahi

@ahmedlone127
Copy link
Contributor Author

Hello @maziyarpanahi , I took another look at the source code and found that by changing the offset value for the tokenizations (from 1 to 0 ) the predictions from the model became a lot more accurate.
image

@maziyarpanahi maziyarpanahi changed the base branch from master to release/530-release-candidate February 6, 2024 12:12
@maziyarpanahi maziyarpanahi merged commit 6566239 into release/530-release-candidate Feb 6, 2024
4 checks passed
@coveralls
Copy link

Pull Request Test Coverage Report for Build 7799480957

Warning: This coverage report may be inaccurate.

We've detected an issue with your CI configuration that might affect the accuracy of this pull request's coverage report.
To ensure accuracy in future PRs, please see these guidelines.
A quick fix for this PR: rebase it; your next report should be accurate.

  • -1 of 1 (0.0%) changed or added relevant line in 1 file are covered.
  • 7 unchanged lines in 5 files lost coverage.
  • Overall coverage decreased (-0.01%) to 62.756%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala 0 1 0.0%
Files with Coverage Reduction New Missed Lines %
src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala 1 49.12%
src/main/scala/com/johnsnowlabs/util/Benchmark.scala 1 70.59%
src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala 1 87.67%
src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala 2 47.02%
src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala 2 59.18%
Totals Coverage Status
Change from base Build 7799297990: -0.01%
Covered Lines: 8954
Relevant Lines: 14268

💛 - Coveralls

maziyarpanahi added a commit that referenced this pull request Feb 27, 2024
…date

* fixed all sbt warnings

* remove file system url prefix (#14132)

* SPARKNLP-942: MPNet Classifiers (#14147)

* SPARKNLP-942: MPNetForSequenceClassification

* SPARKNLP-942: MPNetForQuestionAnswering

* SPARKNLP-942: MPNet Classifiers Documentation

* Restore RobertaforQA bugfix

* adding import notebook + changing default model + adding onnx support (#14158)

* Sparknlp 876: Introducing LLAMA2  (#14148)

* introducing LLAMA2

* Added option to read model from model path to onnx wrapper

* Added option to read model from model path to onnx wrapper

* updated text description

* LLAMA2 python API

* added method to save onnx_data

* added position ids

* - updated Generate.scala to accept onnx tensors
- added beam search support for LLAMA2

* updated max input length

* updated python default params
changed test to slow test

* fixed serialization bug

* Doc sim rank as retriever (#14149)

* Added retrieval interface to the doc sim rank approach

* Added Python interface as retriever in doc sim ranker

---------

Co-authored-by: Stefano Lori <s.lori@izicap.com>

* 812 implement de berta for zero shot classification annotator (#14151)

* adding code

* adding notebook for import

---------

Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>

* Add notebook for fine tuning sbert (#14152)

* [SPARKNLP-986] Fixing optional input col validations (#14153)

* [SPARKNLP-984] Fixing Deberta notebooks URIs (#14154)

* SparkNLP 933: Introducing M2M100 : multilingual translation model (#14155)

* introducing LLAMA2

* Added option to read model from model path to onnx wrapper

* Added option to read model from model path to onnx wrapper

* updated text description

* LLAMA2 python API

* added method to save onnx_data

* added position ids

* - updated Generate.scala to accept onnx tensors
- added beam search support for LLAMA2

* updated max input length

* updated python default params
changed test to slow test

* fixed serialization bug

* Added Scala code for M2M100

* Documentation for scala code

* Python API for M2M100

* added more tests for scala

* added tests for python

* added pretrained

* rewording

* fixed serialization bug

* fixed serialization bug

---------

Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>

* SPARKNLP-985: Add flexible naming for onnx_data (#14165)

Some annotators might have different naming schemes
for their files. Added a parameter to control this.

* Add LLAMA2Transformer and M2M100Transformer to annotator

* Add LLAMA2Transformer and M2M100Transformer to ResourceDownloader

* bump version to 5.3.0 [skip test]

* SPARKNLP-999: Fix remote model loading for some onnx models

* used filesystem to check for the onnx_data file (#14169)

* [SPARKNLP-940] Adding changes to correctly copy cluster index storage… (#14167)

* [SPARKNLP-940] Adding changes to correctly copy cluster index storage when defined

* [SPARKNLP-940] Moving local mode control to its right place

* [SPARKNLP-940] Refactoring sentToCLuster method

* [SPARKNLP-988] Updating EntityRuler documentation (#14168)

* [SPARKNLP-940] Adding changes to support storage temp directory (cluster_tmp_dir)

* SPARKNLP-1000: Disable init_all_tables for GPT2 (#14177)

Fixes `java.lang.IllegalArgumentException: No Operation named [init_all_tables] in the Graph` when model needs to be deserialized.
The deserialization is skipped when the modelis already loaded (so it will only appear on the worker nodes and not the driver)

GPT2 does not contain tables and so does not require this command.

* fixes python documentation (#14172)

* revert MarianTransformer.scala

* revert HasBatchedAnnotate.scala

* revert Preprocessor.scala

* Revert ViTClassifier.scala

* disable hard exception

* Replace hard exception with soft logs (#14179)

This reverts commit eb91fde.

* move the example from root to examples/ [skip test]

* Cleanup some code [skip test]

* Update onnxruntime to 1.17.0 [skip test]

* Fix M2M100 default model's name [skip test]

* Update docs [run doc]

* Update Scala and Python APIs

---------

Co-authored-by: ahmedlone127 <ahmedlone127@gmail.com>
Co-authored-by: Jiamao Zheng <jiamaozheng@users.noreply.github.com>
Co-authored-by: Devin Ha <33089471+DevinTDHa@users.noreply.github.com>
Co-authored-by: Prabod Rathnayaka <prabod@rathnayaka.me>
Co-authored-by: Stefano Lori <wolliq@users.noreply.github.com>
Co-authored-by: Stefano Lori <s.lori@izicap.com>
Co-authored-by: Danilo Burbano <37355249+danilojsl@users.noreply.github.com>
Co-authored-by: Devin Ha <t.ha@tu-berlin.de>
Co-authored-by: Danilo Burbano <danilo@johnsnowlabs.com>
Co-authored-by: github-actions <action@github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
DON'T MERGE Do not merge this PR new model new-feature Introducing a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants