[SPARKNLP-1037] Adding addFile changes to replace broadcast in all ONNX based annotators #14236

Merged
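
This PR moves ONNX model distribution away from serialized byte arrays held inside `OnnxWrapper` and onto Spark's file-distribution mechanism. A minimal sketch of that mechanism follows; the paths and file names are illustrative only and do not come from the PR:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkFiles

object AddFileSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("addFile-sketch").getOrCreate()

    // Driver side: register the local ONNX file once. Spark ships it to every executor,
    // so the raw model bytes no longer need to travel inside the serialized annotator.
    spark.sparkContext.addFile("/tmp/1234_onnx/model.onnx")

    // Executor (or driver) side: resolve the node-local copy by file name when a session
    // is actually needed, as OnnxWrapper.getSession does in this PR.
    val localModelPath: String = SparkFiles.get("model.onnx")
    println(localModelPath)

    spark.stop()
  }
}
```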
src/main/scala/com/johnsnowlabs/ml/ai/M2M100.scala (4 changes: 2 additions & 2 deletions)
@@ -23,11 +23,11 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper.EncoderDecoderWithoutPastWrappers
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.nlp.Annotation

import scala.collection.JavaConverters._
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

private[johnsnowlabs] class M2M100(
val onnxWrappers: EncoderDecoderWithoutPastWrappers,
val spp: SentencePieceWrapper,
src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala (5 changes: 4 additions & 1 deletion)
@@ -297,7 +297,10 @@ private[johnsnowlabs] class Whisper(
case TensorFlow.name =>
val session =
tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes, savedSignatures = signatures)
.getTFSessionWithSignature(
configProtoBytes,
savedSignatures = signatures,
initAllTables = false)

val encodedBatchFeatures: Tensor =
encode(featuresBatch, Some(session), None).asInstanceOf[Tensor]
src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala (132 changes: 66 additions & 66 deletions)
@@ -16,8 +16,6 @@

package com.johnsnowlabs.ml.onnx

import ai.onnxruntime.OrtSession.SessionOptions
import com.johnsnowlabs.util.FileHelper
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession
@@ -32,11 +30,10 @@ trait WriteOnnxModel {
path: String,
spark: SparkSession,
onnxWrappersWithNames: Seq[(OnnxWrapper, String)],
suffix: String,
dataFileSuffix: String = "_data"): Unit = {
suffix: String): Unit = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
val fileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

// 1. Create tmp folder
val tmpFolder = Files
@@ -51,15 +48,16 @@
onnxWrapper.saveToFile(onnxFile)

// 3. Copy to dest folder
fs.copyFromLocalFile(new Path(onnxFile), new Path(path))
fileSystem.copyFromLocalFile(new Path(onnxFile), new Path(path))

// 4. check if there is a onnx_data file
if (onnxWrapper.onnxModelPath.isDefined) {
val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix)
if (fs.exists(onnxDataFile)) {
fs.copyFromLocalFile(onnxDataFile, new Path(path))
if (onnxWrapper.dataFileDirectory.isDefined) {
val onnxDataFile = new Path(onnxWrapper.dataFileDirectory.get)
if (fileSystem.exists(onnxDataFile)) {
fileSystem.copyFromLocalFile(onnxDataFile, new Path(path))
}
}

}

// 4. Remove tmp folder
@@ -74,7 +72,6 @@
fileName: String): Unit = {
writeOnnxModels(path, spark, Seq((onnxWrapper, fileName)), suffix)
}

}

trait ReadOnnxModel {
@@ -86,38 +83,61 @@
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false,
sessionOptions: Option[SessionOptions] = None,
dataFileSuffix: String = "_data"): OnnxWrapper = {
modelName: Option[String] = None,
tmpFolder: Option[String] = None,
dataFilePostfix: Option[String] = None): OnnxWrapper = {

// 1. Copy to local tmp dir
val localModelFile = if (modelName.isDefined) modelName.get else onnxFile
val srcPath = new Path(path, localModelFile)
val fileSystem = getFileSystem(path, spark)
val localTmpFolder = if (tmpFolder.isDefined) tmpFolder.get else createTmpDirectory(suffix)
fileSystem.copyToLocalFile(srcPath, new Path(localTmpFolder))

// 2. Copy onnx_data file if exists
val fsPath = new Path(path, localModelFile).toString

val onnxDataFile: Option[String] = if (modelName.isDefined && dataFilePostfix.isDefined) {
Some(fsPath.replaceAll(modelName.get, s"${suffix}_${modelName.get}${dataFilePostfix.get}"))
} else None

if (onnxDataFile.isDefined) {
val onnxDataFilePath = new Path(onnxDataFile.get)
if (fileSystem.exists(onnxDataFilePath)) {
fileSystem.copyToLocalFile(onnxDataFilePath, new Path(localTmpFolder))
}
}

// 3. Read ONNX state
val onnxFileTmpPath = new Path(localTmpFolder, localModelFile).toString
val onnxWrapper =
OnnxWrapper.read(
spark,
onnxFileTmpPath,
zipped = zipped,
useBundle = useBundle,
modelName = if (modelName.isDefined) modelName.get else onnxFile,
onnxFileSuffix = Some(suffix))

onnxWrapper

}

private def getFileSystem(path: String, sparkSession: SparkSession): FileSystem = {
val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
val fileSystem = FileSystem.get(uri, sparkSession.sparkContext.hadoopConfiguration)
fileSystem
}

private def createTmpDirectory(suffix: String): String = {

// 1. Create tmp directory
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
.createTempDirectory(s"${UUID.randomUUID().toString.takeRight(12)}_$suffix")
.toAbsolutePath
.toString

// 2. Copy to local dir
fs.copyToLocalFile(new Path(path, onnxFile), new Path(tmpFolder))

val localPath = new Path(tmpFolder, onnxFile).toString

val fsPath = new Path(path, onnxFile)

// 3. Copy onnx_data file if exists
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}
// 4. Read ONNX state
val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle)

// 5. Remove tmp folder
FileHelper.delete(tmpFolder)

onnxWrapper
tmpFolder
}

def readOnnxModels(
@@ -127,43 +147,23 @@
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false,
dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = {
dataFilePostfix: String = "_data"): Map[String, OnnxWrapper] = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)

// 1. Create tmp directory
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + suffix)
.toAbsolutePath
.toString
val tmpFolder = Some(createTmpDirectory(suffix))

val wrappers = (modelNames map { modelName: String =>
// 2. Copy to local dir
val localModelFile = modelName
fs.copyToLocalFile(new Path(path, localModelFile), new Path(tmpFolder))

val localPath = new Path(tmpFolder, localModelFile).toString

val fsPath = new Path(path, localModelFile).toString

// 3. Copy onnx_data file if exists
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}

// 4. Read ONNX state
val onnxWrapper =
OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle, modelName = modelName)

val onnxWrapper = readOnnxModel(
path,
spark,
suffix,
zipped,
useBundle,
Some(modelName),
tmpFolder,
Option(dataFilePostfix))
(modelName, onnxWrapper)
}).toMap

// 4. Remove tmp folder
FileHelper.delete(tmpFolder)

wrappers
}

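The new `readOnnxModel` above derives the companion `onnx_data` file name from the annotator suffix, the model name, and a postfix. A small illustration of that naming pattern, with made-up values; only the string manipulation mirrors the diff:

```scala
object OnnxDataFileNamingSketch extends App {
  // Hypothetical values; only the naming pattern follows readOnnxModel above.
  val suffix = "_encoder"
  val modelName = "encoder_model.onnx"
  val dataFilePostfix = "_data"

  val fsPath = s"hdfs:///models/m2m100/$modelName"
  val onnxDataFile = fsPath.replaceAll(modelName, s"${suffix}_${modelName}${dataFilePostfix}")

  println(onnxDataFile) // hdfs:///models/m2m100/_encoder_encoder_model.onnx_data
}
```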
src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala (78 changes: 30 additions & 48 deletions)
@@ -21,15 +21,16 @@ import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel}
import ai.onnxruntime.providers.OrtCUDAProviderOptions
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil}
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}
import org.apache.hadoop.fs.{FileSystem, Path}

import java.io._
import java.nio.file.{Files, Paths}
import java.util.UUID
import scala.util.{Failure, Success, Try}

class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] = None)
class OnnxWrapper(var modelFileName: Option[String] = None, var dataFileDirectory: Option[String])
extends Serializable {

/** For Deserialization */
@@ -43,10 +44,15 @@

def getSession(onnxSessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) =
this.synchronized {
// TODO: After testing it works remove the Map.empty
if (ortSession == null && ortEnv == null) {
val modelFilePath = if (modelFileName.isDefined) {
SparkFiles.get(modelFileName.get)
} else {
throw new UnsupportedOperationException("modelFileName not defined")
}

val (session, env) =
OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions, onnxModelPath)
OnnxWrapper.withSafeOnnxModelLoader(onnxSessionOptions, Some(modelFilePath))
ortEnv = env
ortSession = session
}
Expand All @@ -60,17 +66,11 @@ class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String]
.toAbsolutePath
.toString

// 2. Save onnx model
val fileName = Paths.get(file).getFileName.toString
val onnxFile = Paths
.get(tmpFolder, fileName)
.toString

FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel)
// 4. Zip folder
if (zip) ZipArchiveUtil.zip(tmpFolder, file)
val tmpModelFilePath = SparkFiles.get(modelFileName.get)
// 2. Zip folder
if (zip) ZipArchiveUtil.zip(tmpModelFilePath, file)

// 5. Remove tmp directory
// 3. Remove tmp directory
FileHelper.delete(tmpFolder)
}

@@ -82,7 +82,6 @@

// TODO: make sure this.synchronized is needed or it's not a bottleneck
private def withSafeOnnxModelLoader(
onnxModel: Array[Byte],
sessionOptions: Map[String, String],
onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) =
this.synchronized {
Expand All @@ -96,19 +95,18 @@ object OnnxWrapper {
val session = env.createSession(onnxModelPath.get, sessionOptionsObject)
(session, env)
} else {
val session = env.createSession(onnxModel, sessionOptionsObject)
(session, env)
throw new UnsupportedOperationException("onnxModelPath not defined")
}
}

// TODO: the parts related to onnx_data should be refactored once we support addFile()
def read(
sparkSession: SparkSession,
modelPath: String,
zipped: Boolean = true,
useBundle: Boolean = false,
modelName: String = "model",
dataFileSuffix: String = "_data"): OnnxWrapper = {

dataFileSuffix: Option[String] = Some("_data"),
onnxFileSuffix: Option[String] = None): OnnxWrapper = {
// 1. Create tmp folder
val tmpFolder = Files
.createTempDirectory(UUID.randomUUID().toString.takeRight(12) + "_onnx")
@@ -118,11 +116,10 @@
// 2. Unpack archive
val folder =
if (zipped)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder))
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix)
else
modelPath

val sessionOptions = new OnnxSession().getSessionOptions
val onnxFile =
if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString
else Paths.get(folder, new File(folder).list().head).toString
@@ -134,38 +131,23 @@
val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath

val onnxDataFileExist: Boolean = {
onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile
onnxDataFile.exists()
if (onnxFileSuffix.isDefined && dataFileSuffix.isDefined) {
val onnxDataFilePath = s"${onnxFileSuffix.get}_$modelName${dataFileSuffix.get}"
onnxDataFile = Paths.get(parentDir, onnxDataFilePath).toFile
onnxDataFile.exists()
} else false
}

if (onnxDataFileExist) {
val onnxDataFileTmp =
Paths.get(tmpFolder, modelName + dataFileSuffix).toFile
FileUtils.copyFile(onnxDataFile, onnxDataFileTmp)
sparkSession.sparkContext.addFile(onnxDataFile.toString)
}

val modelFile = new File(onnxFile)
val modelBytes = FileUtils.readFileToByteArray(modelFile)
var session: OrtSession = null
var env: OrtEnvironment = null
if (onnxDataFileExist) {
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile))
session = _session
env = _env
} else {
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None)
session = _session
env = _env
sparkSession.sparkContext.addFile(onnxFile)

}
// 4. Remove tmp folder
FileHelper.delete(tmpFolder)
val onnxFileName = Some(new File(onnxFile).getName)
val dataFileDirectory = if (onnxDataFileExist) Some(onnxDataFile.toString) else None
val onnxWrapper = new OnnxWrapper(onnxFileName, dataFileDirectory)

val onnxWrapper =
if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile))
else new OnnxWrapper(modelBytes)
onnxWrapper.ortSession = session
onnxWrapper.ortEnv = env
onnxWrapper
}

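With the broadcast byte array gone, `getSession` above lazily builds the ONNX Runtime session from the executor-local file that `SparkFiles.get` resolves. A hedged sketch of that construction using only public onnxruntime and Spark APIs; the wrapper's caching and synchronization are omitted, and the object and method names here are illustrative:

```scala
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import org.apache.spark.SparkFiles

// Sketch only: the real logic lives in OnnxWrapper.getSession / withSafeOnnxModelLoader above.
object OnnxSessionFromSparkFileSketch {
  def createSession(modelFileName: String): (OrtSession, OrtEnvironment) = {
    // Resolve the node-local copy that sparkContext.addFile distributed earlier.
    val modelFilePath: String = SparkFiles.get(modelFileName)

    val env = OrtEnvironment.getEnvironment()
    val sessionOptions = new OrtSession.SessionOptions()

    // Create the session from a file path instead of an in-memory byte array,
    // which is what this change switches the ONNX-based annotators to.
    val session = env.createSession(modelFilePath, sessionOptions)
    (session, env)
  }
}
```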