From c09baa00c7dcded5fa7d6c58cfac8a4adeed5035 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 3 Feb 2022 11:19:56 -0800 Subject: [PATCH] [sentencepiece] Fixes #1149, fix NPE bug Change-Id: Ib86efb38def0bcb21862c5493683395071eb3592 --- .../ai/djl/sentencepiece/SpTokenizer.java | 8 ++--- .../ai/djl/sentencepiece/SpTokenizerTest.java | 34 +++++++++++++++---- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java index 613fcced47f..37b39bba0e7 100644 --- a/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java +++ b/extensions/sentencepiece/src/main/java/ai/djl/sentencepiece/SpTokenizer.java @@ -93,13 +93,13 @@ private void loadModel(Path modelPath, String prefix) throws IOException { "Model path doesn't exist: " + modelPath.toAbsolutePath()); } Path modelDir = modelPath.toAbsolutePath(); + if (prefix == null || prefix.isEmpty()) { + prefix = modelDir.toFile().getName(); + } Path modelFile = findModelFile(modelDir, prefix); if (modelFile == null) { // TODO: support proto and IOStream model - modelFile = findModelFile(modelDir, modelDir.toFile().getName()); - if (modelFile == null) { - throw new FileNotFoundException("No .model found in : " + modelPath); - } + throw new FileNotFoundException("No .model found in : " + modelPath); } String modelFilePath = modelFile.toString(); diff --git a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java index 155290e371f..9838db19422 100644 --- a/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java +++ b/extensions/sentencepiece/src/test/java/ai/djl/sentencepiece/SpTokenizerTest.java @@ -29,17 +29,17 @@ public class SpTokenizerTest { @BeforeTest public void downloadModel() throws IOException { - Path modelFile = Paths.get("build/test/models/sententpiece_test_model.model"); + Path modelFile = Paths.get("build/test/sp_model/sp_model.model"); if (Files.notExists(modelFile)) { DownloadUtils.download( "https://resources.djl.ai/test-models/sententpiece_test_model.model", - "build/test/models/sententpiece_test_model.model"); + "build/test/sp_model/sp_model.model"); } } @Test public void testLoadFromBytes() throws IOException { - Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + Path modelPath = Paths.get("build/test/sp_model/sp_model.model"); byte[] bytes = Files.readAllBytes(modelPath); try (SpTokenizer tokenizer = new SpTokenizer(bytes)) { String original = "Hello World"; @@ -55,7 +55,7 @@ public void testLoadFromBytes() throws IOException { public void testTokenize() throws IOException { TestRequirements.notWindows(); - Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + Path modelPath = Paths.get("build/test/sp_model"); try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) { String original = "Hello World"; List tokens = tokenizer.tokenize(original); @@ -71,7 +71,7 @@ public void testTokenize() throws IOException { public void testUtf16Tokenize() throws IOException { TestRequirements.notWindows(); - Path modelPath = Paths.get("build/test/models/sententpiece_test_model.model"); + Path modelPath = Paths.get("build/test/sp_model/sp_model.model"); try (SpTokenizer tokenizer = new SpTokenizer(modelPath)) { String original = "\uD83D\uDC4B\uD83D\uDC4B"; List tokens = tokenizer.tokenize(original); @@ -84,8 +84,8 @@ public void testUtf16Tokenize() throws IOException { public void testEncodeDecode() throws IOException { TestRequirements.notWindows(); - Path modelPath = Paths.get("build/test/models"); - String prefix = "sententpiece_test_model"; + Path modelPath = Paths.get("build/test/sp_model"); + String prefix = "sp_model"; try (SpTokenizer tokenizer = new SpTokenizer(modelPath, prefix)) { String original = "Hello World"; SpProcessor processor = tokenizer.getProcessor(); @@ -96,4 +96,24 @@ public void testEncodeDecode() throws IOException { Assert.assertEquals(recovered, original); } } + + @Test + public void testModelNotFound() throws IOException { + TestRequirements.notWindows(); + + Assert.assertThrows( + () -> { + new SpTokenizer(Paths.get("build/test/non-exists")); + }); + + Assert.assertThrows( + () -> { + new SpTokenizer(Paths.get("build/test/sp_model"), "non-exists.model"); + }); + + Assert.assertThrows( + () -> { + new SpTokenizer(Paths.get("build/test/sp_model"), "non-exists"); + }); + } }