diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index a5b2b61168..ce789a15dd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -11,7 +11,6 @@ import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; -import ai.djl.util.ZipUtils; import lombok.extern.log4j.Log4j2; import org.apache.commons.io.FileUtils; import org.opensearch.ml.common.FunctionName; @@ -27,9 +26,9 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; +import org.opensearch.ml.engine.utils.ZipUtils; import java.io.File; -import java.io.FileInputStream; import java.nio.file.Path; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -185,9 +184,7 @@ private void loadModel(File modelZipFile, String modelId, String modelName, Stri if (pathFile.exists()) { FileUtils.deleteDirectory(pathFile); } - try (FileInputStream fileInputStream = new FileInputStream(modelZipFile)) { - ZipUtils.unzip(fileInputStream, modelPath); - } + ZipUtils.unzip(modelZipFile, modelPath); boolean findModelFile = false; for (File file : pathFile.listFiles()) { String name = file.getName(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ZipUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ZipUtils.java new file mode 100644 index 0000000000..974e397fb9 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/ZipUtils.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.utils; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Enumeration; + +import org.apache.commons.compress.archivers.zip.ZipArchiveEntry; +import org.apache.commons.compress.archivers.zip.ZipFile; +import lombok.extern.log4j.Log4j2; + +/** + * A util class contains zip file related operations. + */ +@Log4j2 +public class ZipUtils { + + /** + * Uncompressed a zip file. + * @param zipFile zip file to be uncompressed + * @param dest the destination path of this uncompress + */ + public static void unzip(File zipFile, Path dest) { + try { + ZipFile unzipFile = new ZipFile(zipFile); + Enumeration en = unzipFile.getEntries(); + ZipArchiveEntry zipEntry; + while (en.hasMoreElements()) { + zipEntry = en.nextElement(); + String name = zipEntry.getName(); + Path file = dest.resolve(name).toAbsolutePath(); + if (!file.normalize().startsWith(dest.toAbsolutePath())) + throw new RuntimeException("Bad zip entry"); + if (zipEntry.isDirectory()) { + Files.createDirectories(file); + } else { + Path parentFile = file.getParent(); + if (parentFile == null) { + throw new AssertionError( + "Parent path should never be null: " + file); + } + Files.createDirectories(parentFile); + InputStream inputStream = unzipFile.getInputStream(zipEntry); + Files.copy(inputStream, file, StandardCopyOption.REPLACE_EXISTING); + inputStream.close(); + } + } + } catch (IOException e) { + throw new IllegalArgumentException("Wrong input file", e); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ZipUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ZipUtilsTest.java new file mode 100644 index 0000000000..88b958fd36 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ZipUtilsTest.java @@ -0,0 +1,39 @@ +package org.opensearch.ml.engine.utils; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Objects; + +public class ZipUtilsTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void testEmptyZipFile() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + Path path = Paths.get("build/empty.zip"); + File file = new File(path.toUri()); + Path output = Paths.get("build/output"); + Files.createDirectories(output); + ZipUtils.unzip(file, output); + } + + @Test + public void testUnzipFile() throws IOException, URISyntaxException { + File testZipFile = new File(Objects.requireNonNull(getClass().getResource("foo.zip")).toURI()); + Path output = Paths.get("build/output"); + Files.createDirectories(output); + ZipUtils.unzip(testZipFile, output); + Path testOutputPath = Paths.get("build/output/foo"); + Assert.assertTrue(Files.exists(testOutputPath)); + } +} diff --git a/ml-algorithms/src/test/resources/org/opensearch/ml/engine/utils/foo.zip b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/utils/foo.zip new file mode 100644 index 0000000000..8ccdd0a2c8 Binary files /dev/null and b/ml-algorithms/src/test/resources/org/opensearch/ml/engine/utils/foo.zip differ