Fixed LLAMA generation bug #14320

Merged · 2 commits · Jun 10, 2024
src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala (2 changes: 1 addition & 1 deletion)

@@ -164,7 +164,7 @@ private[johnsnowlabs] class LLAMA2(
         randomSeed,
         ignoreTokenIdsInt,
         session,
-        applySoftmax = false,
+        applySoftmax = true,
         ovInferRequest = ovInferRequest)
 
     modelOutputs
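For context on the flag flip: the sampling path draws the next token with a cumulative-probability scan, which is only meaningful when the scores are normalized to sum to 1; with `applySoftmax = false` the raw logits were treated as probabilities. A minimal, self-contained sketch of the idea (the helper names below are illustrative, not the library's internals):

```scala
// Why the sampler needs softmax: inverse-CDF sampling assumes the
// weights sum to 1. Helper names here are illustrative only.
object SoftmaxSketch {
  def softmax(logits: Array[Float]): Array[Float] = {
    val maxLogit = logits.max // subtract the max for numerical stability
    val exps = logits.map(l => math.exp(l - maxLogit))
    val total = exps.sum
    exps.map(e => (e / total).toFloat)
  }

  // Draws index i with probability probs(i); valid only when probs sums
  // to 1, which is what applySoftmax = true guarantees upstream.
  def sampleToken(probs: Array[Float], rng: scala.util.Random): Int = {
    val r = rng.nextDouble()
    var cum = 0.0
    var i = 0
    while (i < probs.length - 1 && cum + probs(i) < r) { cum += probs(i); i += 1 }
    i
  }
}
```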
@@ -392,10 +392,7 @@ trait Generate {
       seededRandom = new scala.util.Random(seed.get)
     }
     for (i <- 0 until k) {
-      var rand = scala.util.Random.nextDouble()
-      if (seed.isDefined) {
-        rand = new scala.util.Random(seed.get).nextDouble()
-      }
+      val rand = seededRandom.nextDouble()
       var cumProb = 0.0
       var j = 0
       while (j < probabilities.length - i) {
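The deleted branch constructed a fresh `Random` with the same seed on every loop iteration, so every draw returned the identical value and seeded sampling degenerated to picking the same quantile each time; the fix reuses the `seededRandom` already built above the loop, which stays reproducible while actually advancing the random stream. A REPL-style sketch of the difference:

```scala
// Demonstration of the bug the hunk removes.
val seed = 42L

// Old behaviour: re-seeding inside the loop repeats the first draw.
val repeated = Seq.fill(3)(new scala.util.Random(seed).nextDouble())
// e.g. List(0.7275..., 0.7275..., 0.7275...) -- one value, three times

// Fixed behaviour: one seeded generator, advanced across draws.
val seededRandom = new scala.util.Random(seed)
val distinct = Seq.fill(3)(seededRandom.nextDouble())
// three different values, still reproducible from run to run
```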
@@ -20,27 +20,25 @@ import scala.collection.mutable.ArrayBuffer
 class TopKLogitWarper(
     val k: Int,
     val filterValue: Float = Float.NegativeInfinity,
-    val minTokensToKeep: Int = 1)
+    val minTokensToKeep: Int = 100)
     extends LogitWarper {
   override def call(
       inputIds: Seq[Array[Int]],
       scores: Array[Array[Float]],
       currentLength: Int): Array[Array[Float]] = {
-    var logitsUpd = scores
-    val logitsShape = Array(scores.length, scores(0).length)
-    if (k > 0) {
-      val topKup = k.max(minTokensToKeep).min(logitsShape.last) // Safety check
+    val logitsUpd = scores.map(_.clone()) // Deep copy of the scores
 
       /** Remove all tokens with a probability less than the last token of the top-k */
+    if (k > 0) {
+      val logitsShape = Array(scores.length, scores.head.length)
+      val effectiveTopK = k.max(minTokensToKeep).min(logitsShape.last) // Safety check
 
-      val topKLogits = new ArrayBuffer[Array[Float]]()
-      for (logits <- scores) {
-        val topKIndices = getTopKIndices(logits, topKup)
+      for ((logits, i) <- scores.zipWithIndex) {
+        val topKIndices = getTopKIndices(logits, effectiveTopK)
         val maskedValues = maskNotTopKValues(logits, topKIndices)
-        topKLogits += maskedValues
+        logitsUpd(i) = maskedValues
       }
-      topKLogits.toArray
     }
 
     logitsUpd
   }
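The substantive fix here: the old body collected the masked rows into `topKLogits` but discarded that result and returned `logitsUpd`, which still aliased the unfiltered `scores`, so top-k filtering was silently a no-op; the new body writes each masked row back into a deep copy. Note also that `k.max(minTokensToKeep)` with the new default of 100 floors the effective k at 100 (capped at vocabulary size). A standalone sketch of the masking step (the helpers mirror the private methods referenced in the diff):

```scala
// Top-k masking as the fixed warper performs it, in isolation.
def getTopKIndices(logits: Array[Float], k: Int): Array[Int] =
  logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)

def maskNotTopKValues(logits: Array[Float], keep: Array[Int]): Array[Float] = {
  val masked = logits.clone()
  for (i <- logits.indices if !keep.contains(i))
    masked(i) = Float.NegativeInfinity
  masked
}

val logits = Array(1.2f, 3.4f, 0.1f, 2.8f)
val kept = maskNotTopKValues(logits, getTopKIndices(logits, 2))
// kept == Array(-Inf, 3.4, -Inf, 2.8): softmax later assigns the masked
// tokens zero probability, so they can never be sampled
```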
@@ -21,46 +21,34 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit
       inputIds: Seq[Array[Int]],
       scores: Array[Array[Float]],
       currentLength: Int): Array[Array[Float]] = {
-    var scoresUpd = scores
-    val scoresShape = Array(scores.length, scores(0).length)
-    if (this.p < 1.0) {
-      val (sortedscores, sortedIndices) = scores(0).zipWithIndex.sorted.reverse.unzip
-
-      val cumulativeProbs = this.scanLeft(this.softmax(sortedscores))(0.0)(_ + _).drop(1)
-
-      /** Remove tokens with cumulative probability above the threshold (token with 0 are kept) */
-      var sortedIndicesToRemove =
-        for (prob <- cumulativeProbs)
-          yield if (prob > this.p) true else false
-
-      if (minTokensToKeep > 1) {
-
-        /** Keep at least minTokensToKeep (set to minTokensToKeep-1 because we add the first one
-          * below)
-          */
-        sortedIndicesToRemove = List.fill(sortedIndicesToRemove.take(minTokensToKeep).length)(
-          false) ++ sortedIndicesToRemove.drop(minTokensToKeep)
-      }
-
-      /** Shift the indices to the right to keep also the first token above the threshold */
-      sortedIndicesToRemove = sortedIndicesToRemove.takeRight(1) ++ sortedIndicesToRemove
-        .dropRight(1)
-      sortedIndicesToRemove =
-        List.fill(sortedIndicesToRemove.take(1).length)(false) ++ sortedIndicesToRemove
-          .drop(1)
-
-      /** scatter sorted tensors to original indexing */
-      val indicesToRemove =
-        this.scatterValuesOnBatchIndices(sortedIndicesToRemove.toList, sortedIndices)
-      scoresUpd =
-        for ((nextTokenLogit, indexToRemove) <- scores.zip(
-               IndexedSeq.fill(scores.length)(indicesToRemove)))
-          yield setTensorByIndicesToValue(
-            nextTokenLogit,
-            indexToRemove.toIndexedSeq,
-            Float.NegativeInfinity)
-    }
-    scoresUpd
+    val logitsUpd = scores.map(_.clone()) // Deep copy of the scores
+
+    if (p < 1.0) {
+      val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values
+      val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length)
+      val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold
+
+      for ((logits, i) <- scores.zipWithIndex) {
+        val topPIndices = getTopPIndices(logits, topPThreshold)
+        val maskedValues = maskNotTopPValues(logits, topPIndices)
+        logitsUpd(i) = maskedValues
+      }
+    }
+
+    logitsUpd
   }
 
+  private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = {
+    logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)
+  }
+
+  private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = {
+    val maskedValues = logits.clone()
+    for (i <- logits.indices) {
+      if (!topPIndices.contains(i)) {
+        maskedValues(i) = Float.NegativeInfinity
+      }
+    }
+    maskedValues
+  }
+
 }
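One observation for reviewers, readable directly from the hunk: the rewrite keeps a fixed count of tokens, `ceil(p * vocabSize)` highest-scoring ones, rather than the smallest set whose cumulative softmax mass reaches p. That is a rank-count approximation of classic nucleus (top-p) sampling, and it no longer consults `minTokensToKeep`. A small sketch contrasting the two readings (numbers are illustrative):

```scala
// Two readings of "top-p", with an illustrative 5-token vocabulary.
val p = 0.9
val vocabSize = 5

// New warper: fixed rank cutoff, independent of the distribution's shape.
val topPThreshold = math.ceil(p * vocabSize).toInt // keeps 5 of 5 tokens here

// Classic nucleus sampling: smallest prefix of the descending softmax
// whose cumulative mass reaches p.
def nucleusCount(probsDesc: Seq[Double], p: Double): Int =
  probsDesc.scanLeft(0.0)(_ + _).drop(1).indexWhere(_ >= p) + 1

nucleusCount(Seq(0.5, 0.3, 0.15, 0.04, 0.01), 0.9) // keeps only 3 tokens
```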
@@ -227,11 +227,11 @@ class LLAMA2Transformer(override val uid: String)
     minOutputLength -> 0,
     maxOutputLength -> 20,
     doSample -> false,
-    temperature -> 0.6,
-    topK -> 50,
+    temperature -> 0.9,
+    topK -> 100,
     topP -> 0.9,
     repetitionPenalty -> 1.0,
-    noRepeatNgramSize -> 3,
+    noRepeatNgramSize -> 0,
     ignoreTokenIds -> Array(),
     batchSize -> 1,
     beamSize -> 1,
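A quick way to exercise the changed defaults end to end; a hedged sketch against Spark NLP's Scala API, where the `pretrained()` handle and column names are assumptions for illustration:

```scala
// Hedged usage sketch: the pretrained-model handle and column names
// below are placeholders, not part of this PR.
import com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer

val llama2 = LLAMA2Transformer
  .pretrained()                 // assumed default pretrained handle
  .setInputCols("documents")
  .setOutputCol("generation")
  .setDoSample(true)            // routes generation through the fixed sampling path
  .setTemperature(0.9)          // new default
  .setTopK(100)                 // new default
  .setTopP(0.9)                 // unchanged
  .setNoRepeatNgramSize(0)      // new default: n-gram blocking disabled
  .setMaxOutputLength(20)
```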