Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added custom stop token id support #14344

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ private[johnsnowlabs] class LLAMA2(
*/
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
sentences.map(s => {
val sentWithTask = s.result
spp.getSppModel.encodeAsIds(sentWithTask)
val sentWithTask = "_" + s.result
Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
})
}

Expand All @@ -97,7 +97,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -165,7 +166,8 @@ private[johnsnowlabs] class LLAMA2(
ignoreTokenIdsInt,
session,
applySoftmax = true,
ovInferRequest = ovInferRequest)
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

modelOutputs
}
Expand All @@ -184,7 +186,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -201,7 +204,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
18 changes: 11 additions & 7 deletions src/main/scala/com/johnsnowlabs/ml/ai/Mistral.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ private[johnsnowlabs] class Mistral(
*/
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
sentences.map(s => {
val sentWithTask = s.result
spp.getSppModel.encodeAsIds(sentWithTask)
val sentWithTask = "_" + s.result
Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
})
}

Expand All @@ -96,7 +96,8 @@ private[johnsnowlabs] class Mistral(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -162,8 +163,9 @@ private[johnsnowlabs] class Mistral(
randomSeed,
ignoreTokenIdsInt,
session,
applySoftmax = false,
ovInferRequest = ovInferRequest)
applySoftmax = true,
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -183,7 +185,8 @@ private[johnsnowlabs] class Mistral(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -200,7 +203,8 @@ private[johnsnowlabs] class Mistral(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
12 changes: 8 additions & 4 deletions src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ private[johnsnowlabs] class Phi2(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -169,7 +170,8 @@ private[johnsnowlabs] class Phi2(
ignoreTokenIdsInt,
session,
applySoftmax = false,
ovInferRequest = ovInferRequest)
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -189,7 +191,8 @@ private[johnsnowlabs] class Phi2(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -206,7 +209,8 @@ private[johnsnowlabs] class Phi2(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ trait Generate {
ignoreTokenIds: Array[Int] = Array(),
session: Either[Session, (OrtEnvironment, OrtSession)],
applySoftmax: Boolean = true,
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
ovInferRequest: Option[InferRequest] = None,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {

// TODO: Add support for ignoreTokenIds

Expand All @@ -117,8 +118,8 @@ trait Generate {
noRepeatNgramSize = noRepeatNgramSize,
vocabSize = vocabSize))

logitProcessorList.addProcess(
new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))
// logitProcessorList.addProcess(
// new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))

logitProcessorList.addProcess(new TemperatureLogitWarper(temperature))

Expand Down Expand Up @@ -148,7 +149,8 @@ trait Generate {
randomSeed,
session,
applySoftmax,
ovInferRequest)
ovInferRequest,
stopTokenIds)
}

/** Beam Search for text generation
Expand Down Expand Up @@ -193,7 +195,8 @@ trait Generate {
randomSeed: Option[Long],
session: Either[Session, (OrtEnvironment, OrtSession)],
applySoftmax: Boolean,
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
ovInferRequest: Option[InferRequest] = None,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
val inputIds = inputIdsVal
val batchSize = beamScorer.getBeamHypothesesSeq.length
val numBeams = beamScorer.getNumBeams
Expand Down Expand Up @@ -227,21 +230,22 @@ trait Generate {
// Optionally Apply log softmax to model outputs
var nextTokenScores =
if (applySoftmax) nextTokenLogits.map(logSoftmax) else nextTokenLogits

// Process the logits by defined logit processors
val nextTokenScoresProcessed =
logitProcessor.process(expandedInputs, nextTokenScores, currentLength)

// Process the logits by defined logit warpers
if (doSample) {
nextTokenScores =
logitProcessor.warp(expandedInputs, nextTokenScoresProcessed, currentLength)
}
// Add previous beam scores to the output
nextTokenScores = nextTokenScoresProcessed.zipWithIndex.map { case (x, ind1) =>
nextTokenScores = nextTokenScores.zipWithIndex.map { case (x, ind1) =>
x.zipWithIndex.map { case (y, _) =>
y + beamScores(ind1)
}
}
// Process the logits by defined logit warpers
if (doSample) {
nextTokenScores = logitProcessor.warp(expandedInputs, nextTokenScores, currentLength)
}

// Reshape next token score to (batchSize, vocabSize * numBeams)
val vocabSize = nextTokenScores.head.length
val reshapedNextTokenScores =
Expand Down Expand Up @@ -290,7 +294,8 @@ trait Generate {
padTokenId,
eosTokenId,
beamIndices,
currentLength)
currentLength,
stopTokenIds)
val newBeamScores = beamOutputs._1.flatMap(_.toList)
val beamNextTokens = beamOutputs._2.flatMap(_.toList)
val beamIdx = beamOutputs._3.flatMap(_.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ class TopKLogitWarper(
}

private def getTopKIndices(logits: Array[Float], k: Int): Array[Int] = {
logits.indices.sortBy(logits(_)).reverse.take(k).toArray
// ignore float.NegativeInfinity values
val topKIndices = new ArrayBuffer[Int]()
val sortedLogits = logits.zipWithIndex.filter(_._1 != filterValue).sortBy(-_._1)
for ((_, i) <- sortedLogits.take(k)) {
topKIndices += i
}
topKIndices.toArray
}

private def maskNotTopKValues(logits: Array[Float], topKIndices: Array[Int]): Array[Float] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,40 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit
val logitsUpd = scores.map(_.clone()) // Deep copy of the scores

if (p < 1.0) {
val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values
val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length)
val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold
val scoresFiltered = scores // Filter out infinite values
val scoresSoftmaxed = scoresFiltered.map(softmax) // Softmax the scores

for ((logits, i) <- scores.zipWithIndex) {
val topPIndices = getTopPIndices(logits, topPThreshold)
val maskedValues = maskNotTopPValues(logits, topPIndices)
for ((logits, i) <- scoresSoftmaxed.zipWithIndex) {
val topPIndices = getTopPIndices(logits, p)
// Mask the values that are not in the top-p
val maskedValues = maskNotTopPValues(logitsUpd(i), topPIndices)
logitsUpd(i) = maskedValues
}
}

logitsUpd
}

private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = {
logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)
private def getTopPIndices(logits: Array[Float], p: Double): Array[Int] = {
// sort the logits in descending order
var sortedLogits = logits.zipWithIndex.sortBy(-_._1)

// filter out the negative infinity values
sortedLogits = sortedLogits.filter(_._1 > 0.0)

// cumulative sum of the probabilities
val cumSum = sortedLogits.map(_._1).scanLeft(0.0)(_ + _)

// find the index of the last element that is less than p
val lastIdx = cumSum.indexWhere(_ >= p)
// if the last index is less than the minimum tokens to keep, return the top p tokens

if (lastIdx < minTokensToKeep) {
sortedLogits.take(math.ceil(p * logits.length).toInt).map(_._2)
} else {
sortedLogits.take(lastIdx).map(_._2)
}

}

private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ abstract class BeamScorer() {
padTokenId: Int,
eosTokenId: Int,
beamIndices: Seq[Array[Int]],
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])
currentLength: Int,
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])

def finalize(
inputIds: Seq[Array[Int]],
Expand All @@ -40,4 +41,5 @@ abstract class BeamScorer() {
def getBeamHypothesesSeq: Seq[BeamHypotheses]
def getNumBeams: Int
def isDone: Boolean
def getDone: Array[Boolean]
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class BeamSearchScorer(
override def getNumBeams: Int = numBeams
private val done: Array[Boolean] = Array.fill(batchSize)(false)

override def getDone: Array[Boolean] = done

override def process(
inputIds: Seq[Array[Int]],
nextScores: Seq[Array[Float]],
Expand All @@ -51,7 +53,8 @@ class BeamSearchScorer(
padTokenId: Int,
eosTokenId: Int,
beamIndices: Seq[Array[Int]],
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
currentLength: Int,
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
// val currentLength = inputIds.length
val batchSize = this.beamHypothesesSeq.length
val nextBeamScores = Array.ofDim[Float](batchSize, this.beamSize)
Expand All @@ -75,7 +78,8 @@ class BeamSearchScorer(
val nextIndex = nextIndices(batchIdx)(beamTokenRank)
val batchBeamIdx = batchIdx * this.beamSize + nextIndex

if (eosTokenId == nextToken) {
// either eos token or stop tokens are found
if (eosTokenId == nextToken || stopTokenIds.contains(nextToken)) {
if (beamTokenRank >= this.beamSize) {
break
}
Expand Down
15 changes: 15 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasGeneratorProperties.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,19 @@ trait HasGeneratorProperties {

/** @group getParam */
def getNReturnSequences: Int = $(nReturnSequences)

/** Stop tokens to terminate the generation
*
* @group param
*/
var stopTokenIds =
new IntArrayParam(this, "stopTokens", "Stop tokens to terminate the generation")

/** @group setParam */
def setStopTokenIds(value: Array[Int]): this.type = {
set(stopTokenIds, value)
}

/** @group getParam */
def getStopTokenIds: Array[Int] = $(stopTokenIds)
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ class LLAMA2Transformer(override val uid: String)
ignoreTokenIds -> Array(),
batchSize -> 1,
beamSize -> 1,
maxInputLength -> 4096)
maxInputLength -> 4096,
stopTokenIds -> Array())

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
Expand Down Expand Up @@ -269,7 +270,8 @@ class LLAMA2Transformer(override val uid: String)
randomSeed = this.randomSeed,
ignoreTokenIds = $(ignoreTokenIds),
beamSize = $(beamSize),
maxInputLength = $(maxInputLength))
maxInputLength = $(maxInputLength),
stopTokenIds = $(stopTokenIds))
} else {
Seq()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ class MistralTransformer(override val uid: String)
ignoreTokenIds -> Array(),
batchSize -> 1,
beamSize -> 1,
maxInputLength -> 4096)
maxInputLength -> 4096,
stopTokenIds -> Array())

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
Expand Down Expand Up @@ -277,7 +278,8 @@ class MistralTransformer(override val uid: String)
randomSeed = this.randomSeed,
ignoreTokenIds = $(ignoreTokenIds),
beamSize = $(beamSize),
maxInputLength = $(maxInputLength))
maxInputLength = $(maxInputLength),
stopTokenIds = $(stopTokenIds))
} else {
Seq()
}
Expand Down
Loading
Loading