diff --git a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala index e8cca00d22b533..d37a6ce90e7a11 100644 --- a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala +++ b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala @@ -23,52 +23,47 @@ import java.util.zip.{ZipEntry, ZipFile, ZipOutputStream} import scala.collection.JavaConverters._ object ZipArchiveUtil { - - private def listFiles(file: File, outputFilename: String): List[String] = { + // Recursively lists all files in a given directory, returning a list of their absolute paths + private[util] def listFilesRecursive(file: File): List[File] = { file match { - case file if file.isFile => - if (file.getName != outputFilename) - List(file.getAbsoluteFile.toString) - else - List() + case file if file.isFile => List(new File(file.getAbsoluteFile.toString)) case file if file.isDirectory => val fList = file.list // Add all files in current dir to list and recur on subdirs - fList.foldLeft(List[String]())((pList: List[String], path: String) => - pList ++ listFiles(new File(file, path), outputFilename)) + fList.foldLeft(List[File]())((pList: List[File], path: String) => + pList ++ listFilesRecursive(new File(file, path))) case _ => throw new IOException("Bad path. No file or directory found.") } } - private def addFileToZipEntry( - filename: String, - parentPath: String, - filePathsCount: Int): ZipEntry = { - if (filePathsCount <= 1) - new ZipEntry(new File(filename).getName) - else { - // use relative path to avoid adding absolute path directories - val relative = new File(parentPath).toURI.relativize(new File(filename).toURI).getPath + private[util] def addFileToZipEntry( + filename: File, + parentPath: File, + useRelativePath: Boolean = false): ZipEntry = { + if (!useRelativePath) // use absolute path + new ZipEntry(filename.getName) + else { // use relative path + val relative = parentPath.toURI.relativize(filename.toURI).getPath new ZipEntry(relative) } } - private def createZip( - filePaths: List[String], - outputFilename: String, - parentPath: String): Unit = { + private[util] def createZip( + filePaths: List[File], + outputFilePath: File, + parentPath: File): Unit = { val Buffer = 2 * 1024 val data = new Array[Byte](Buffer) try { - val zipFileOS = new FileOutputStream(outputFilename) + val zipFileOS = new FileOutputStream(outputFilePath) val zip = new ZipOutputStream(zipFileOS) zip.setLevel(0) - filePaths.foreach((name: String) => { - val zipEntry = addFileToZipEntry(name, parentPath, filePaths.size) + filePaths.foreach((file: File) => { + val zipEntry = addFileToZipEntry(file, parentPath, filePaths.size > 1) // add zip entry to output stream zip.putNextEntry(new ZipEntry(zipEntry)) - val in = new BufferedInputStream(new FileInputStream(name), Buffer) + val in = new BufferedInputStream(new FileInputStream(file), Buffer) var b = in.read(data, 0, Buffer) while (b != -1) { zip.write(data, 0, b) @@ -86,10 +81,36 @@ object ZipArchiveUtil { } } - def zip(fileName: String, outputFileName: String): Unit = { - val file = new File(fileName) - val filePaths = listFiles(file, outputFileName) - createZip(filePaths, outputFileName, fileName) + private[util] def zipFile(soureFile: File, outputFilePath: File): Unit = { + createZip(List(soureFile.getAbsoluteFile), outputFilePath, null) + } + + private[util] def zipDir(sourceDir: File, outputFilePath: File): Unit = { + val filePaths = listFilesRecursive(sourceDir) + createZip(filePaths, outputFilePath, sourceDir) + } + + def zip(sourcePath: String, outputFilePath: String): Unit = { + val sourceFile = new File(sourcePath) + val outputFile = new File(outputFilePath) + if (sourceFile.equals(outputFile)) + throw new IllegalArgumentException("source path cannot be identical to target path") + + if (!outputFile.getParentFile().exists) + throw new IOException("the parent directory of output file doesn't exist") + + if (!sourceFile.exists()) + throw new IOException("zip source path must exsit") + + if (outputFile.exists()) + throw new IOException("zip target file exsits") + + if (sourceFile.isDirectory()) + zipDir(sourceFile, outputFile) + else if (sourceFile.isFile()) + zipFile(sourceFile, outputFile) + else + throw new IllegalArgumentException("only folder and file input are valid") } def unzip(file: File, destDirPath: Option[String] = None): String = { diff --git a/src/test/resources/onnx/models/dummy_model.onnx b/src/test/resources/onnx/models/dummy_model.onnx new file mode 100644 index 00000000000000..d1ba43acc9c529 Binary files /dev/null and b/src/test/resources/onnx/models/dummy_model.onnx differ diff --git a/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala b/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala new file mode 100644 index 00000000000000..2c3c6ede10d4fc --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala @@ -0,0 +1,82 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.onnx + +import com.johnsnowlabs.tags.FastTest +import org.scalatest.flatspec.AnyFlatSpec +import java.nio.file.{Files, Paths, Path} +import java.io.File +import com.johnsnowlabs.util.FileHelper +import org.scalatest.BeforeAndAfter +import java.util.UUID + +class OnnxWrapperTestSpec extends AnyFlatSpec with BeforeAndAfter { + /* + * Dummy model was created with the following python script + """ + import torch + import torch.nn as nn + import torch.onnx + + # Define a simple neural network model + class DummyModel(nn.Module): + def __init__(self): + super(DummyModel, self).__init__() + self.linear = nn.Linear(in_features=10, out_features=5) + + def forward(self, x): + return self.linear(x) + + # Create the model and dummy input + model = DummyModel() + dummy_input = torch.randn(1, 10) # batch size of 1, 10 features + + # Export the model to ONNX format + torch.onnx.export(model, dummy_input, "dummy_model.onnx", verbose=True) + """ + * + */ + private val modelPath: String = "src/test/resources/onnx/models/dummy_model.onnx" + + private val tmpDirPath: String = UUID.randomUUID().toString.takeRight(12) + "_onnx" + var tmpFolder: String = _ + + before { + tmpFolder = Files + .createDirectory(Paths.get(tmpDirPath)) + .toAbsolutePath + .toString + } + + after { + FileHelper.delete(tmpFolder) + } + + "a dummy onnx wrapper" should "get session correctly" taggedAs FastTest in { + val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath)) + val dummyOnnxWrapper = new OnnxWrapper(modelBytes) + dummyOnnxWrapper.getSession() + } + + "a dummy onnx wrapper" should "saveToFile correctly" taggedAs FastTest in { + val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath)) + val dummyOnnxWrapper = new OnnxWrapper(modelBytes) + dummyOnnxWrapper.saveToFile(Paths.get(tmpFolder, "modelFromTest.zip").toString) + // verify file existence + assert(new File(tmpFolder, "modelFromTest.zip").exists()) + } +} diff --git a/src/test/scala/com/johnsnowlabs/util/ZipArchiveUtilTestSpec.scala b/src/test/scala/com/johnsnowlabs/util/ZipArchiveUtilTestSpec.scala new file mode 100644 index 00000000000000..90c23c74c53a0a --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/util/ZipArchiveUtilTestSpec.scala @@ -0,0 +1,164 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.util + +import com.johnsnowlabs.tags.FastTest +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers._ +import java.nio.file.{Files, Paths, Path} +import java.io.File +import com.johnsnowlabs.util.FileHelper +import org.scalatest.BeforeAndAfter +import java.util.UUID + +class ZipArchiveUtilTestSpec extends AnyFlatSpec with BeforeAndAfter { + private val tmpDirPath: String = UUID.randomUUID().toString.takeRight(12) + "_onnx" + var tmpFolder: String = _ + + before { + // create temp dir for testing + tmpFolder = Files + .createDirectory(Paths.get(tmpDirPath)) + .toAbsolutePath + .toString + + // create files and dirs for recusive testing + new File(tmpFolder, "fileA").createNewFile() + Files.createDirectory(Paths.get(tmpDirPath, "dir")) + Files.createFile(Paths.get(tmpFolder, "dir", "fileA")) + Files.createFile(Paths.get(tmpFolder, "dir", "fileB")) + Files.createDirectory(Paths.get(tmpDirPath, "dir", "dir2")) + Files.createFile(Paths.get(tmpDirPath, "dir", "dir2", "fileC")) + } + + after { + // delete the temp directory + FileHelper.delete(tmpFolder) + } + + "listFilesRecursive" should "throw exception if the file doesn't exist" taggedAs FastTest in { + val isIOException = + try { + ZipArchiveUtil.listFilesRecursive(new File("a")) + false + } catch { + case e: java.io.IOException => true + case _: Throwable => false + } + assert(isIOException) + } + + "listFilesRecursive" should "return a single item list if give a file" taggedAs FastTest in { + val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "fileA")) + assert(list.length == 1) + assert(list.head.equals(new File(tmpFolder, "fileA"))) + } + + "listFilesRecursive" should "return a single item list if give a file within folder" taggedAs FastTest in { + val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir/fileA")) + assert(list.length == 1) + assert(list.head.equals(new File(tmpFolder, "dir/fileA"))) + } + + "listFilesRecursive" should "return a list with 3 items if give the dir folder" taggedAs FastTest in { + val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir")) + assert(list.length == 3) + + list.toSet should contain theSameElementsAs Set( + new File(tmpFolder, "dir/dir2/fileC"), + new File(tmpFolder, "dir/fileA"), + new File(tmpFolder, "dir/fileB")) + } + + "addFileToZipEntry" should "return zip entry with absolute setting" taggedAs FastTest in { + val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("fileA"), null, false) + assert(zipEntry.getName == "fileA") + } + + "addFileToZipEntry" should "return zip entry with relative setting" taggedAs FastTest in { + val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), true) + assert(zipEntry.getName == "fileA") + } + + "addFileToZipEntry" should "return zip entry full path with absolute setting" taggedAs FastTest in { + val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), false) + assert(zipEntry.getName == "fileA") + } + + "createZip" should "create zip for a single file" taggedAs FastTest in { + ZipArchiveUtil.createZip( + List(new File(tmpFolder, "dir/fileA")), + new File(tmpFolder, "targetA.zip"), + null) + + assert(new File(tmpFolder, "targetA.zip").exists()) + } + + "createZip" should "create zip" taggedAs FastTest in { + ZipArchiveUtil.createZip( + List( + new File(Paths.get(tmpFolder, "dir", "fileA").toString), + new File(Paths.get(tmpFolder, "dir", "fileB").toString)), + new File(tmpFolder, "targetDir.zip"), + new File(Paths.get(tmpFolder).toString)) + + assert(new File(tmpFolder, "targetDir.zip").exists()) + } + + "zipFile" should "zip a single file" taggedAs FastTest in { + ZipArchiveUtil.zipFile( + new File(Paths.get(tmpFolder, "dir", "fileA").toString), + new File(tmpFolder, "targetA.zip")) + assert(new File(tmpFolder, "targetA.zip").exists()) + } + + "zipDir" should "zip a directory" taggedAs FastTest in { + ZipArchiveUtil.zipDir( + new File(Paths.get(tmpFolder, "dir").toString), + new File(tmpFolder, "targetDir.zip")) + assert(new File(tmpFolder, "targetDir.zip").exists()) + } + + "zip" should "zip a single file with String input" taggedAs FastTest in { + ZipArchiveUtil.zip( + Paths.get(tmpFolder, "dir", "fileA").toString, + Paths.get(tmpFolder, "targetA.zip").toString) + assert(new File(tmpFolder, "targetA.zip").exists()) + } + + "zip" should "zip a dir with String input" taggedAs FastTest in { + ZipArchiveUtil.zip( + Paths.get(tmpFolder, "dir").toString, + Paths.get(tmpFolder, "targetDir.zip").toString) + assert(new File(tmpFolder, "targetDir.zip").exists()) + } + + "zip" should "throw exception if the folder not exist since we are not responsible to create folders" taggedAs FastTest in { + val isIOExceptinoCaught = + try { + ZipArchiveUtil.zip( + Paths.get(tmpFolder, "dir").toString, + Paths.get(tmpFolder, "otherdir/targetDir.zip").toString) + false + } catch { + case e: java.io.IOException => true + case _: Throwable => false + } + + assert(isIOExceptinoCaught) + } +}