Skip to content

Commit

Permalink
Added OpenVINO support
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod committed Jul 22, 2024
1 parent d9aac67 commit a9155e5
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 46 deletions.
149 changes: 116 additions & 33 deletions src/main/scala/com/johnsnowlabs/ml/ai/CPM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,29 @@ import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
import com.johnsnowlabs.ml.onnx.OnnxSession
import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.nlp.Annotation

import scala.collection.JavaConverters._
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.intel.openvino.InferRequest
import org.tensorflow.{Session, Tensor}

private[johnsnowlabs] class CPM(
val onnxWrappers: DecoderWrappers,
val onnxWrappers: Option[DecoderWrappers],
val openvinoWrapper: Option[OpenvinoWrapper],
val spp: SentencePieceWrapper,
generationConfig: GenerationConfig)
extends Serializable
with Generate {

private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions
val detectedEngine: String =
if (onnxWrappers.isDefined) ONNX.name
else if (openvinoWrapper.isDefined) Openvino.name
else ONNX.name

private val GenerationConfig(
bosTokenId: Int,
Expand Down Expand Up @@ -71,7 +79,7 @@ private[johnsnowlabs] class CPM(
/** SentencePiece-encodes each annotation's text and prepends the BOS token id.
  *
  * @param sentences
  *   annotations whose `result` field holds the raw text to tokenize
  * @return
  *   one token-id array per input annotation, each starting with `bosTokenId`
  */
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
  sentences.map { s =>
    val sentWithTask = s.result
    // Tokenize once; the decoder expects the sequence to start with BOS.
    Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
  }
}

Expand All @@ -88,8 +96,8 @@ private[johnsnowlabs] class CPM(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions)
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch.toArray
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand All @@ -109,19 +117,23 @@ private[johnsnowlabs] class CPM(
effectiveBatch_mult = 1
}

expandedDecoderInputsVals.zipWithIndex.foreach { case (input, i) =>
expandedDecoderInputsVals(i) = Array(bosTokenId) ++ input
val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) =
detectedEngine match {
case ONNX.name =>
// dummy tensors for decoder encode state and attention mask
val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions)
(
Right(OnnxTensor.createTensor(env, Array(0))),
Right(OnnxTensor.createTensor(env, Array(1))),
Right((env, encoderSession)))
case Openvino.name =>
// not needed
(null, null, null)
}
val ovInferRequest: Option[InferRequest] = detectedEngine match {
case ONNX.name => None
case Openvino.name => Some(openvinoWrapper.get.getCompiledModel().create_infer_request())
}
// Run the prompt through the decoder and get the past
// val decoderOutputs =
// generateGreedyOnnx(
// expandedDecoderInputsVals.toArray,
// (encoderSession, env),
// maxOutputLength)

// dummy tensors for decoder encode state and attention mask
val decoderEncoderStateTensors = Right(OnnxTensor.createTensor(env, Array(0)))
val encoderAttentionMaskTensors = Right(OnnxTensor.createTensor(env, Array(1)))

// output with beam search
val modelOutputs = generate(
Expand All @@ -144,8 +156,10 @@ private[johnsnowlabs] class CPM(
this.paddingTokenId,
randomSeed,
ignoreTokenIdsInt,
Right((env, encoderSession)),
applySoftmax = false)
session,
applySoftmax = true,
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -165,7 +179,8 @@ private[johnsnowlabs] class CPM(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -182,7 +197,8 @@ private[johnsnowlabs] class CPM(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down Expand Up @@ -237,21 +253,88 @@ private[johnsnowlabs] class CPM(
decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
maxLength: Int,
session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = {
session: Either[Session, (OrtEnvironment, OrtSession)],
ovInferRequest: Option[InferRequest]): Array[Array[Float]] = {

session.fold(
tfSession => {
detectedEngine match {
case TensorFlow.name =>
// not implemented yet
Array()
},
onnxSession => {
val (env, decoderSession) = onnxSession
case ONNX.name =>
val (env, decoderSession) = session.right.get
val decoderOutputs =
getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env))
decoderOutputs
})
case Openvino.name =>
val decoderOutputs =
getDecoderOutputsOv(
encoderInputIds.toArray,
decoderInputIds.toArray,
ovInferRequest.get)
decoderOutputs
}
}

/** Runs one decoder step through the OpenVINO infer request and returns, for
  * every batch element, the logits of the final sequence position.
  *
  * On the first pass (prompt and decoder inputs have equal length) the full
  * prompt is fed; on subsequent passes only the newest token per sequence is
  * fed — NOTE(review): this presumes a stateful OpenVINO model that keeps the
  * KV cache across `infer()` calls; confirm against the exported model.
  *
  * @param encoderInputIds original prompt token ids, used only to detect the first pass
  * @param decoderInputIds tokens generated so far, one row per batch element
  * @param inferRequest reusable OpenVINO inference request bound to the compiled model
  * @return per-batch logits over the vocabulary for the last position
  */
private def getDecoderOutputsOv(
    encoderInputIds: Array[Array[Int]],
    decoderInputIds: Array[Array[Int]],
    inferRequest: InferRequest): (Array[Array[Float]]) = {

  val isFirstPass = encoderInputIds.head.length == decoderInputIds.head.length

  // Flattened input ids and their positions, laid out row-major per batch.
  val (flatInputIds, flatPositionIds): (Array[Long], Array[Long]) =
    if (isFirstPass) {
      val ids = decoderInputIds.flatMap(_.map(_.toLong))
      val pos = decoderInputIds.flatMap(row => row.indices.map(_.toLong))
      (ids, pos)
    } else {
      // Only the most recent token of each sequence, at its absolute position.
      val ids = decoderInputIds.map(_.last.toLong)
      val pos = decoderInputIds.map(row => (row.length - 1).toLong)
      (ids, pos)
    }

  val batchSize = decoderInputIds.length
  val stepLength = flatInputIds.length / batchSize
  val tensorShape = Array(batchSize, stepLength)

  // Attention mask always covers the full sequence generated so far.
  val attentionMaskData: Array[Long] =
    decoderInputIds.flatMap(row => Array.fill(row.length)(1L))

  val inputIdsTensor =
    new org.intel.openvino.Tensor(tensorShape, flatInputIds)
  val attentionMaskTensor =
    new org.intel.openvino.Tensor(
      Array(batchSize, decoderInputIds.head.length),
      attentionMaskData)
  val positionIdsTensor =
    new org.intel.openvino.Tensor(tensorShape, flatPositionIds)
  // beam_idx stays all-zero: no beam reordering is requested between steps.
  val beamIdxTensor =
    new org.intel.openvino.Tensor(Array(batchSize), new Array[Int](batchSize))

  inferRequest.set_tensor("input_ids", inputIdsTensor)
  inferRequest.set_tensor("attention_mask", attentionMaskTensor)
  inferRequest.set_tensor("position_ids", positionIdsTensor)
  inferRequest.set_tensor("beam_idx", beamIdxTensor)

  inferRequest.infer()

  // Logits are flat [batchSize * stepLength * vocabSize]; keep only the
  // distribution over each sequence's last position.
  val logitsFlat = inferRequest.get_tensor("logits").data()
  Array.tabulate(batchSize) { b =>
    val lastPosStart = (b * stepLength + (stepLength - 1)) * vocabSize
    logitsFlat.slice(lastPosStart, lastPosStart + vocabSize)
  }
}

private def getDecoderOutputs(
inputIds: Array[Array[Int]],
onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = {
Expand Down Expand Up @@ -283,12 +366,12 @@ private[johnsnowlabs] class CPM(
val sequenceLength = inputIds.head.length
val batchSize = inputIds.length

// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
// inputIdsLongTensor.close()
// decoderPositionIDs.close()
// decoderAttentionMask.close()
// val batchLogits = logits.grouped(vocabSize).toArray
// batchLogits
// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
// inputIdsLongTensor.close()
// decoderPositionIDs.close()
// decoderAttentionMask.close()
// val batchLogits = logits.grouped(vocabSize).toArray
// batchLogits

val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
val decoderOutputs = (0 until batchSize).map(i => {
Expand Down
Loading

0 comments on commit a9155e5

Please sign in to comment.