Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve zip util code and add tests for both ZipArchiveUtil ane OnnxW… #14056

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,47 @@ import java.util.zip.{ZipEntry, ZipFile, ZipOutputStream}
import scala.collection.JavaConverters._

object ZipArchiveUtil {

private def listFiles(file: File, outputFilename: String): List[String] = {
// Recursively lists all files in a given directory, returning a list of their absolute paths
private[util] def listFilesRecursive(file: File): List[File] = {
file match {
case file if file.isFile =>
if (file.getName != outputFilename)
List(file.getAbsoluteFile.toString)
else
List()
case file if file.isFile => List(new File(file.getAbsoluteFile.toString))
case file if file.isDirectory =>
val fList = file.list
// Add all files in current dir to list and recur on subdirs
fList.foldLeft(List[String]())((pList: List[String], path: String) =>
pList ++ listFiles(new File(file, path), outputFilename))
fList.foldLeft(List[File]())((pList: List[File], path: String) =>
pList ++ listFilesRecursive(new File(file, path)))
case _ => throw new IOException("Bad path. No file or directory found.")
}
}

private def addFileToZipEntry(
filename: String,
parentPath: String,
filePathsCount: Int): ZipEntry = {
if (filePathsCount <= 1)
new ZipEntry(new File(filename).getName)
else {
// use relative path to avoid adding absolute path directories
val relative = new File(parentPath).toURI.relativize(new File(filename).toURI).getPath
private[util] def addFileToZipEntry(
filename: File,
parentPath: File,
useRelativePath: Boolean = false): ZipEntry = {
if (!useRelativePath) // use absolute path
new ZipEntry(filename.getName)
else { // use relative path
val relative = parentPath.toURI.relativize(filename.toURI).getPath
new ZipEntry(relative)
}
}

private def createZip(
filePaths: List[String],
outputFilename: String,
parentPath: String): Unit = {
private[util] def createZip(
filePaths: List[File],
outputFilePath: File,
parentPath: File): Unit = {

val Buffer = 2 * 1024
val data = new Array[Byte](Buffer)
try {
val zipFileOS = new FileOutputStream(outputFilename)
val zipFileOS = new FileOutputStream(outputFilePath)
val zip = new ZipOutputStream(zipFileOS)
zip.setLevel(0)
filePaths.foreach((name: String) => {
val zipEntry = addFileToZipEntry(name, parentPath, filePaths.size)
filePaths.foreach((file: File) => {
val zipEntry = addFileToZipEntry(file, parentPath, filePaths.size > 1)
// add zip entry to output stream
zip.putNextEntry(new ZipEntry(zipEntry))
val in = new BufferedInputStream(new FileInputStream(name), Buffer)
val in = new BufferedInputStream(new FileInputStream(file), Buffer)
var b = in.read(data, 0, Buffer)
while (b != -1) {
zip.write(data, 0, b)
Expand All @@ -86,10 +81,36 @@ object ZipArchiveUtil {
}
}

def zip(fileName: String, outputFileName: String): Unit = {
val file = new File(fileName)
val filePaths = listFiles(file, outputFileName)
createZip(filePaths, outputFileName, fileName)
private[util] def zipFile(soureFile: File, outputFilePath: File): Unit = {
createZip(List(soureFile.getAbsoluteFile), outputFilePath, null)
}

private[util] def zipDir(sourceDir: File, outputFilePath: File): Unit = {
val filePaths = listFilesRecursive(sourceDir)
createZip(filePaths, outputFilePath, sourceDir)
}

def zip(sourcePath: String, outputFilePath: String): Unit = {
val sourceFile = new File(sourcePath)
val outputFile = new File(outputFilePath)
if (sourceFile.equals(outputFile))
throw new IllegalArgumentException("source path cannot be identical to target path")

if (!outputFile.getParentFile().exists)
throw new IOException("the parent directory of output file doesn't exist")

if (!sourceFile.exists())
throw new IOException("zip source path must exsit")

if (outputFile.exists())
throw new IOException("zip target file exsits")

if (sourceFile.isDirectory())
zipDir(sourceFile, outputFile)
else if (sourceFile.isFile())
zipFile(sourceFile, outputFile)
else
throw new IllegalArgumentException("only folder and file input are valid")
}

def unzip(file: File, destDirPath: Option[String] = None): String = {
Expand Down
Binary file added src/test/resources/onnx/models/dummy_model.onnx
Binary file not shown.
82 changes: 82 additions & 0 deletions src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.johnsnowlabs.ml.onnx

import com.johnsnowlabs.tags.FastTest
import org.scalatest.flatspec.AnyFlatSpec
import java.nio.file.{Files, Paths, Path}
import java.io.File
import com.johnsnowlabs.util.FileHelper
import org.scalatest.BeforeAndAfter
import java.util.UUID

class OnnxWrapperTestSpec extends AnyFlatSpec with BeforeAndAfter {
/*
* Dummy model was created with the following python script
"""
import torch
import torch.nn as nn
import torch.onnx

# Define a simple neural network model
class DummyModel(nn.Module):
def __init__(self):
super(DummyModel, self).__init__()
self.linear = nn.Linear(in_features=10, out_features=5)

def forward(self, x):
return self.linear(x)

# Create the model and dummy input
model = DummyModel()
dummy_input = torch.randn(1, 10) # batch size of 1, 10 features

# Export the model to ONNX format
torch.onnx.export(model, dummy_input, "dummy_model.onnx", verbose=True)
"""
*
*/
private val modelPath: String = "src/test/resources/onnx/models/dummy_model.onnx"

private val tmpDirPath: String = UUID.randomUUID().toString.takeRight(12) + "_onnx"
var tmpFolder: String = _

before {
tmpFolder = Files
.createDirectory(Paths.get(tmpDirPath))
.toAbsolutePath
.toString
}

after {
FileHelper.delete(tmpFolder)
}

"a dummy onnx wrapper" should "get session correctly" taggedAs FastTest in {
val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath))
val dummyOnnxWrapper = new OnnxWrapper(modelBytes)
dummyOnnxWrapper.getSession()
}

"a dummy onnx wrapper" should "saveToFile correctly" taggedAs FastTest in {
val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath))
val dummyOnnxWrapper = new OnnxWrapper(modelBytes)
dummyOnnxWrapper.saveToFile(Paths.get(tmpFolder, "modelFromTest.zip").toString)
// verify file existence
assert(new File(tmpFolder, "modelFromTest.zip").exists())
}
}
164 changes: 164 additions & 0 deletions src/test/scala/com/johnsnowlabs/util/ZipArchiveUtilTestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.johnsnowlabs.util

import com.johnsnowlabs.tags.FastTest
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers._
import java.nio.file.{Files, Paths, Path}
import java.io.File
import com.johnsnowlabs.util.FileHelper
import org.scalatest.BeforeAndAfter
import java.util.UUID

class ZipArchiveUtilTestSpec extends AnyFlatSpec with BeforeAndAfter {
private val tmpDirPath: String = UUID.randomUUID().toString.takeRight(12) + "_onnx"
var tmpFolder: String = _

before {
// create temp dir for testing
tmpFolder = Files
.createDirectory(Paths.get(tmpDirPath))
.toAbsolutePath
.toString

// create files and dirs for recusive testing
new File(tmpFolder, "fileA").createNewFile()
Files.createDirectory(Paths.get(tmpDirPath, "dir"))
Files.createFile(Paths.get(tmpFolder, "dir", "fileA"))
Files.createFile(Paths.get(tmpFolder, "dir", "fileB"))
Files.createDirectory(Paths.get(tmpDirPath, "dir", "dir2"))
Files.createFile(Paths.get(tmpDirPath, "dir", "dir2", "fileC"))
}

after {
// delete the temp directory
FileHelper.delete(tmpFolder)
}

"listFilesRecursive" should "throw exception if the file doesn't exist" taggedAs FastTest in {
val isIOException =
try {
ZipArchiveUtil.listFilesRecursive(new File("a"))
false
} catch {
case e: java.io.IOException => true
case _: Throwable => false
}
assert(isIOException)
}

"listFilesRecursive" should "return a single item list if give a file" taggedAs FastTest in {
val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "fileA"))
assert(list.length == 1)
assert(list.head.equals(new File(tmpFolder, "fileA")))
}

"listFilesRecursive" should "return a single item list if give a file within folder" taggedAs FastTest in {
val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir/fileA"))
assert(list.length == 1)
assert(list.head.equals(new File(tmpFolder, "dir/fileA")))
}

"listFilesRecursive" should "return a list with 3 items if give the dir folder" taggedAs FastTest in {
val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir"))
assert(list.length == 3)

list.toSet should contain theSameElementsAs Set(
new File(tmpFolder, "dir/dir2/fileC"),
new File(tmpFolder, "dir/fileA"),
new File(tmpFolder, "dir/fileB"))
}

"addFileToZipEntry" should "return zip entry with absolute setting" taggedAs FastTest in {
val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("fileA"), null, false)
assert(zipEntry.getName == "fileA")
}

"addFileToZipEntry" should "return zip entry with relative setting" taggedAs FastTest in {
val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), true)
assert(zipEntry.getName == "fileA")
}

"addFileToZipEntry" should "return zip entry full path with absolute setting" taggedAs FastTest in {
val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), false)
assert(zipEntry.getName == "fileA")
}

"createZip" should "create zip for a single file" taggedAs FastTest in {
ZipArchiveUtil.createZip(
List(new File(tmpFolder, "dir/fileA")),
new File(tmpFolder, "targetA.zip"),
null)

assert(new File(tmpFolder, "targetA.zip").exists())
}

"createZip" should "create zip" taggedAs FastTest in {
ZipArchiveUtil.createZip(
List(
new File(Paths.get(tmpFolder, "dir", "fileA").toString),
new File(Paths.get(tmpFolder, "dir", "fileB").toString)),
new File(tmpFolder, "targetDir.zip"),
new File(Paths.get(tmpFolder).toString))

assert(new File(tmpFolder, "targetDir.zip").exists())
}

"zipFile" should "zip a single file" taggedAs FastTest in {
ZipArchiveUtil.zipFile(
new File(Paths.get(tmpFolder, "dir", "fileA").toString),
new File(tmpFolder, "targetA.zip"))
assert(new File(tmpFolder, "targetA.zip").exists())
}

"zipDir" should "zip a directory" taggedAs FastTest in {
ZipArchiveUtil.zipDir(
new File(Paths.get(tmpFolder, "dir").toString),
new File(tmpFolder, "targetDir.zip"))
assert(new File(tmpFolder, "targetDir.zip").exists())
}

"zip" should "zip a single file with String input" taggedAs FastTest in {
ZipArchiveUtil.zip(
Paths.get(tmpFolder, "dir", "fileA").toString,
Paths.get(tmpFolder, "targetA.zip").toString)
assert(new File(tmpFolder, "targetA.zip").exists())
}

"zip" should "zip a dir with String input" taggedAs FastTest in {
ZipArchiveUtil.zip(
Paths.get(tmpFolder, "dir").toString,
Paths.get(tmpFolder, "targetDir.zip").toString)
assert(new File(tmpFolder, "targetDir.zip").exists())
}

"zip" should "throw exception if the folder not exist since we are not responsible to create folders" taggedAs FastTest in {
val isIOExceptinoCaught =
try {
ZipArchiveUtil.zip(
Paths.get(tmpFolder, "dir").toString,
Paths.get(tmpFolder, "otherdir/targetDir.zip").toString)
false
} catch {
case e: java.io.IOException => true
case _: Throwable => false
}

assert(isIOExceptinoCaught)
}
}
Loading