Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizers] Add batch decoding methods for tokenizers #2154

Merged
merged 1 commit into from
Nov 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,47 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
ret
}

/// JNI entry point: decodes a batch of token-id sequences back into strings.
///
/// `batch_ids` is a Java `long[][]`; each inner array holds the token ids of
/// one sequence. `skip_special_tokens` controls whether special tokens (e.g.
/// [CLS]/[SEP]) are dropped from the decoded text. Returns a Java `String[]`
/// with one decoded string per input sequence.
#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_batchDecode(
    env: JNIEnv,
    _: JObject,
    handle: jlong,
    batch_ids: jobjectArray,
    skip_special_tokens: jboolean,
) -> jobjectArray {
    let tokenizer = cast_handle::<Tokenizer>(handle);
    let batch_len = env.get_array_length(batch_ids).unwrap();
    // Pre-size the outer Vec: one entry per sequence in the batch.
    let mut batch_decode_input: Vec<Vec<u32>> = Vec::with_capacity(batch_len as usize);
    for i in 0..batch_len {
        let item = env.get_object_array_element(batch_ids, i).unwrap();
        let sequence_ids = env
            .get_long_array_elements(*item, ReleaseMode::NoCopyBack)
            .unwrap();
        let sequence_len = sequence_ids.size().unwrap() as usize;
        // SAFETY: pointer and length come from the same live JNI long array,
        // which stays pinned until `sequence_ids` is dropped at end of scope;
        // the slice does not outlive that scope.
        let ids: &[jlong] =
            unsafe { std::slice::from_raw_parts(sequence_ids.as_ptr(), sequence_len) };
        // Java has no unsigned types, so token ids arrive as jlong; the
        // tokenizer expects u32 — assumes ids are non-negative and fit in u32.
        let decode_ids: Vec<u32> = ids.iter().map(|&id| id as u32).collect();
        batch_decode_input.push(decode_ids);
    }
    let decoding: Vec<String> = tokenizer
        .decode_batch(batch_decode_input, skip_special_tokens == JNI_TRUE)
        .unwrap();
    let ret: jobjectArray = env
        .new_object_array(batch_len, "java/lang/String", JObject::null())
        .unwrap();
    for (i, decode) in decoding.iter().enumerate() {
        let item: JString = env.new_string(decode).unwrap();
        env.set_object_array_element(ret, i as jsize, item)
            .unwrap();
    }
    ret
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTruncationStrategy(
env: JNIEnv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ public String decode(long[] ids) {
return decode(ids, !addSpecialTokens);
}

/**
 * Decodes each id sequence in the given batch into its corresponding String.
 *
 * @param batchIds the batch of id sequences to decode
 * @param skipSpecialTokens whether to remove special tokens in the decoding
 * @return the decoded Strings from the input batch ids
 */
public String[] batchDecode(long[][] batchIds, boolean skipSpecialTokens) {
    long handle = getHandle();
    return TokenizersLibrary.LIB.batchDecode(handle, batchIds, skipSpecialTokens);
}

/**
 * Decodes each id sequence in the given batch into its corresponding String,
 * deriving special-token handling from the tokenizer's configuration.
 *
 * @param batchIds the batch of id sequences to decode
 * @return the decoded Strings from the input batch ids
 */
public String[] batchDecode(long[][] batchIds) {
    // Special tokens are skipped exactly when they were not added at encode time.
    boolean skipSpecialTokens = !addSpecialTokens;
    return batchDecode(batchIds, skipSpecialTokens);
}

/**
* Creates a builder to build a {@code HuggingFaceTokenizer}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public native long encodeDual(

// Encodes a batch of input strings; returns one native Encoding handle per input.
public native long[] batchEncode(long tokenizer, String[] inputs, boolean addSpecialTokens);

// Decodes a batch of id sequences into Strings. NOTE: the flag is a
// *skip*-special-tokens flag — the native implementation and the
// HuggingFaceTokenizer caller both treat `true` as "drop special tokens",
// so the former parameter name `addSpecialTokens` was inverted/misleading.
public native String[] batchDecode(long tokenizer, long[][] batchIds, boolean skipSpecialTokens);

// Releases the native Encoding object behind the given handle.
public native void deleteEncoding(long encoding);

// Returns the token ids stored in the native Encoding behind the given handle.
public native long[] getTokenIds(long encoding);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,48 @@ public void testSpecialTokenHandling() throws IOException {
}
}
}

@Test
public void testBatchProcessing() throws IOException {
    String[] sentences =
            new String[] {
                "Hello there friend", "How are you today", "Good weather I'd say", "I am Happy!"
            };
    String[] withSpecial =
            new String[] {
                "[CLS] Hello there friend [SEP]",
                "[CLS] How are you today [SEP]",
                "[CLS] Good weather I ' d say [SEP]",
                "[CLS] I am Happy! [SEP]"
            };
    String[] withoutSpecial =
            new String[] {
                "Hello there friend",
                "How are you today",
                "Good weather I ' d say",
                "I am Happy!"
            };
    try (HuggingFaceTokenizer tokenizer =
            HuggingFaceTokenizer.builder().optTokenizerName("bert-base-cased").build()) {

        // Case 1: builder defaults — special tokens added on encode, kept on decode.
        Encoding[] encoded = tokenizer.batchEncode(sentences);
        long[][] ids = Arrays.stream(encoded).map(Encoding::getIds).toArray(long[][]::new);
        String[] decoded = tokenizer.batchDecode(ids);
        Assert.assertEquals(decoded, withSpecial);

        // Case 2: explicitly add special tokens on encode, keep them on decode.
        encoded = tokenizer.batchEncode(sentences, true);
        ids = Arrays.stream(encoded).map(Encoding::getIds).toArray(long[][]::new);
        decoded = tokenizer.batchDecode(ids, false);
        Assert.assertEquals(decoded, withSpecial);

        // Case 3: no special tokens on encode, skip them on decode.
        encoded = tokenizer.batchEncode(sentences, false);
        ids = Arrays.stream(encoded).map(Encoding::getIds).toArray(long[][]::new);
        decoded = tokenizer.batchDecode(ids, true);
        Assert.assertEquals(decoded, withoutSpecial);
    }
}
}