From 31bfdf14e59f1e6508bcd7531f2ffeef90850fc3 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Fri, 11 Nov 2022 17:04:54 -0800 Subject: [PATCH] Add batch decoding methods for tokenizers --- extensions/tokenizers/rust/src/lib.rs | 41 +++++++++++++++++ .../tokenizers/HuggingFaceTokenizer.java | 21 +++++++++ .../tokenizers/jni/TokenizersLibrary.java | 2 + .../tokenizers/HuggingFaceTokenizerTest.java | 44 +++++++++++++++++++ 4 files changed, 108 insertions(+) diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs index 92038b2226d..0182327b828 100644 --- a/extensions/tokenizers/rust/src/lib.rs +++ b/extensions/tokenizers/rust/src/lib.rs @@ -428,6 +428,47 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ ret } +#[no_mangle] +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchDecode( + env: JNIEnv, + _: JObject, + handle: jlong, + batch_ids: jobjectArray, + skip_special_tokens: jboolean, +) -> jobjectArray { + let tokenizer = cast_handle::<Tokenizer>(handle); + let batch_len = env.get_array_length(batch_ids).unwrap(); + let mut batch_decode_input: Vec<Vec<u32>> = Vec::new(); + for i in 0..batch_len { + let item = env.get_object_array_element(batch_ids, i).unwrap(); + let sequence_ids = env + .get_long_array_elements(*item, ReleaseMode::NoCopyBack) + .unwrap(); + let sequence_ids_ptr = sequence_ids.as_ptr(); + let sequence_len = sequence_ids.size().unwrap() as usize; + let mut decode_ids: Vec<u32> = Vec::new(); + for i in 0..sequence_len { + unsafe { + let val = sequence_ids_ptr.add(i); + decode_ids.push(*val as u32); + } + } + batch_decode_input.push(decode_ids); + } + let decoding: Vec<String> = tokenizer + .decode_batch(batch_decode_input, skip_special_tokens == JNI_TRUE) + .unwrap(); + let ret: jobjectArray = env + .new_object_array(batch_len, "java/lang/String", JObject::null()) + .unwrap(); + for (i, decode) in decoding.iter().enumerate() { + let item: JString = 
env.new_string(&decode).unwrap(); + env.set_object_array_element(ret, i as jsize, item) + .unwrap(); + } + ret +} + #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTruncationStrategy( env: JNIEnv, diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index fd05caea6a0..8881292b820 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -339,6 +339,27 @@ public String decode(long[] ids) { return decode(ids, !addSpecialTokens); } + /** + * Returns the decoded Strings from the input batch ids. + * + * @param batchIds the batch of id sequences to decode + * @param skipSpecialTokens whether to remove special tokens in the decoding + * @return the decoded Strings from the input batch ids + */ + public String[] batchDecode(long[][] batchIds, boolean skipSpecialTokens) { + return TokenizersLibrary.LIB.batchDecode(getHandle(), batchIds, skipSpecialTokens); + } + + /** + * Returns the decoded Strings from the input batch ids. + * + * @param batchIds the batch of id sequences to decode + * @return the decoded Strings from the input batch ids + */ + public String[] batchDecode(long[][] batchIds) { + return batchDecode(batchIds, !addSpecialTokens); + } + /** * Creates a builder to build a {@code HuggingFaceTokenizer}. 
* diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java index 0e775d66c5b..b37947d78c2 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java @@ -35,6 +35,8 @@ public native long encodeDual( public native long[] batchEncode(long tokenizer, String[] inputs, boolean addSpecialTokens); + public native String[] batchDecode(long tokenizer, long[][] batchIds, boolean addSpecialTokens); + public native void deleteEncoding(long encoding); public native long[] getTokenIds(long encoding); diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 2413e24e1a1..033e4afa4d3 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -414,4 +414,48 @@ public void testSpecialTokenHandling() throws IOException { } } } + + @Test + public void testBatchProcessing() throws IOException { + String[] inputs = + new String[] { + "Hello there friend", "How are you today", "Good weather I'd say", "I am Happy!" + }; + String[] outputsWithSpecialTokens = + new String[] { + "[CLS] Hello there friend [SEP]", + "[CLS] How are you today [SEP]", + "[CLS] Good weather I ' d say [SEP]", + "[CLS] I am Happy! [SEP]" + }; + String[] outputsWithoutSpecialTokens = + new String[] { + "Hello there friend", + "How are you today", + "Good weather I ' d say", + "I am Happy!" 
+ }; + try (HuggingFaceTokenizer tokenizer = + HuggingFaceTokenizer.builder().optTokenizerName("bert-base-cased").build()) { + + // default tokenizer with special tokens included + Encoding[] encodings = tokenizer.batchEncode(inputs); + long[][] batchIds = + Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); + String[] outputs = tokenizer.batchDecode(batchIds); + Assert.assertEquals(outputs, outputsWithSpecialTokens); + + // encode with special tokens, decode with special tokens + encodings = tokenizer.batchEncode(inputs, true); + batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); + outputs = tokenizer.batchDecode(batchIds, false); + Assert.assertEquals(outputs, outputsWithSpecialTokens); + + // encode without special tokens, decode without special tokens + encodings = tokenizer.batchEncode(inputs, false); + batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); + outputs = tokenizer.batchDecode(batchIds, true); + Assert.assertEquals(outputs, outputsWithoutSpecialTokens); + } + } }