From 69b44b35615c66590509b96d04cb22f959fa7c62 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Sun, 9 Jun 2024 11:49:17 +0000 Subject: [PATCH 1/2] fixed LLAMA generation bug --- .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala | 2 +- .../ml/ai/util/Generation/Generate.scala | 5 +- .../Logit/LogitWarper/TopKLogitWarper.scala | 20 +++---- .../Logit/LogitWarper/TopPLogitWarper.scala | 58 ++++++++----------- .../seq2seq/LLAMA2Transformer.scala | 2 +- 5 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala index 13968ce48cab3a..ed3444a3059ee2 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala @@ -164,7 +164,7 @@ private[johnsnowlabs] class LLAMA2( randomSeed, ignoreTokenIdsInt, session, - applySoftmax = false, + applySoftmax = true, ovInferRequest = ovInferRequest) modelOutputs diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala index b983d8565df3fe..4e4140f7735ab2 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala @@ -392,10 +392,7 @@ trait Generate { seededRandom = new scala.util.Random(seed.get) } for (i <- 0 until k) { - var rand = scala.util.Random.nextDouble() - if (seed.isDefined) { - rand = new scala.util.Random(seed.get).nextDouble() - } + val rand = seededRandom.nextDouble() var cumProb = 0.0 var j = 0 while (j < probabilities.length - i) { diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala index 6ec064de2b7333..4d60a0e1684eda 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopKLogitWarper.scala @@ -20,27 +20,25 @@ import scala.collection.mutable.ArrayBuffer class TopKLogitWarper( val k: Int, val filterValue: Float = Float.NegativeInfinity, - val minTokensToKeep: Int = 1) + val minTokensToKeep: Int = 100) extends LogitWarper { override def call( inputIds: Seq[Array[Int]], scores: Array[Array[Float]], currentLength: Int): Array[Array[Float]] = { - var logitsUpd = scores - val logitsShape = Array(scores.length, scores(0).length) - if (k > 0) { - val topKup = k.max(minTokensToKeep).min(logitsShape.last) // Safety check + val logitsUpd = scores.map(_.clone()) // Deep copy of the scores - /** Remove all tokens with a probability less than the last token of the top-k */ + if (k > 0) { + val logitsShape = Array(scores.length, scores.head.length) + val effectiveTopK = k.max(minTokensToKeep).min(logitsShape.last) // Safety check - val topKLogits = new ArrayBuffer[Array[Float]]() - for (logits <- scores) { - val topKIndices = getTopKIndices(logits, topKup) + for ((logits, i) <- scores.zipWithIndex) { + val topKIndices = getTopKIndices(logits, effectiveTopK) val maskedValues = maskNotTopKValues(logits, topKIndices) - topKLogits += maskedValues + logitsUpd(i) = maskedValues } - topKLogits.toArray } + logitsUpd } diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala index f96c87c11eefdc..85e0dcf0e2893a 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitWarper/TopPLogitWarper.scala @@ -21,46 +21,34 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit inputIds: Seq[Array[Int]], scores: Array[Array[Float]], currentLength: Int): Array[Array[Float]] = { - var scoresUpd = scores - val scoresShape = Array(scores.length, scores(0).length) - if (this.p < 1.0) { - val (sortedscores, sortedIndices) = scores(0).zipWithIndex.sorted.reverse.unzip + val logitsUpd = scores.map(_.clone()) // Deep copy of the scores - val cumulativeProbs = this.scanLeft(this.softmax(sortedscores))(0.0)(_ + _).drop(1) + if (p < 1.0) { + val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values + val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length) + val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold - /** Remove tokens with cumulative probability above the threshold (token with 0 are kept) */ - var sortedIndicesToRemove = - for (prob <- cumulativeProbs) - yield if (prob > this.p) true else false - - if (minTokensToKeep > 1) { - - /** Keep at least minTokensToKeep (set to minTokensToKeep-1 because we add the first one - * below) - */ - sortedIndicesToRemove = List.fill(sortedIndicesToRemove.take(minTokensToKeep).length)( - false) ++ sortedIndicesToRemove.drop(minTokensToKeep) + for ((logits, i) <- scores.zipWithIndex) { + val topPIndices = getTopPIndices(logits, topPThreshold) + val maskedValues = maskNotTopPValues(logits, topPIndices) + logitsUpd(i) = maskedValues } + } + + logitsUpd + } - /** Shift the indices to the right to keep also the first token above the threshold */ - sortedIndicesToRemove = sortedIndicesToRemove.takeRight(1) ++ sortedIndicesToRemove - .dropRight(1) - sortedIndicesToRemove = - List.fill(sortedIndicesToRemove.take(1).length)(false) ++ sortedIndicesToRemove - .drop(1) + private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = { + logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2) + } - /** scatter sorted tensors to original indexing */ - val indicesToRemove = - this.scatterValuesOnBatchIndices(sortedIndicesToRemove.toList, sortedIndices) - scoresUpd = - for ((nextTokenLogit, indexToRemove) <- scores.zip( - IndexedSeq.fill(scores.length)(indicesToRemove))) - yield setTensorByIndicesToValue( - nextTokenLogit, - indexToRemove.toIndexedSeq, - Float.NegativeInfinity) + private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = { + val maskedValues = logits.clone() + for (i <- logits.indices) { + if (!topPIndices.contains(i)) { + maskedValues(i) = Float.NegativeInfinity + } } - scoresUpd + maskedValues } - } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 9095b7aacdd617..f1f0fd62c906d2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -227,7 +227,7 @@ class LLAMA2Transformer(override val uid: String) minOutputLength -> 0, maxOutputLength -> 20, doSample -> false, - temperature -> 0.6, + temperature -> 0.9, topK -> 50, topP -> 0.9, repetitionPenalty -> 1.0, From bf16ddb254a64071898686b1dc6af928d8a63612 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Sun, 9 Jun 2024 13:21:13 +0000 Subject: [PATCH 2/2] update params --- .../nlp/annotators/seq2seq/LLAMA2Transformer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index f1f0fd62c906d2..b9c114ea62de5f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -228,10 +228,10 @@ class LLAMA2Transformer(override val uid: String) maxOutputLength -> 20, doSample -> false, temperature -> 0.9, - topK -> 50, + topK -> 100, topP -> 0.9, repetitionPenalty -> 1.0, - noRepeatNgramSize -> 3, + noRepeatNgramSize -> 0, ignoreTokenIds -> Array(), batchSize -> 1, beamSize -> 1,