Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bytelevel normalizer to fix decode when adding tokens to BPE #1555

Merged
merged 9 commits into from
Jul 15, 2024

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Jun 18, 2024

This revert the previous breaking change.

Also add a new ByteLevel normalizer, which replaces the ByteLevel pre_tokenizer.
Checked that we can add chines / Cyrillic tokens which are properly encoded and decoder.

Fixes #1392

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 850 to 863
let tokens = ids
.iter()
.filter_map(|id| {
self.added_vocabulary
.id_to_token(*id, &self.model)
.filter(|token| {
!skip_special_tokens || !self.added_vocabulary.is_special_token(token)
})
})
.collect::<Vec<_>>();

if let Some(decoder) = &self.decoder {
decoder.decode(tokens)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted to what we originally had

@ArthurZucker
Copy link
Collaborator Author

The test passes locally !

tokenizers/src/pre_tokenizers/byte_level.rs Outdated Show resolved Hide resolved
tokenizers/src/tokenizer/added_vocabulary.rs Outdated Show resolved Hide resolved
tokenizers/src/tokenizer/added_vocabulary.rs Outdated Show resolved Hide resolved
tokenizers/src/tokenizer/pre_tokenizer.rs Outdated Show resolved Hide resolved
Comment on lines +33 to +34
/// Strip the normalized string inplace
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't this live in NormalizedString like so:

impl NormalizedString {
  fn normalize_byte_level(&mut self) -> Result<()> { // ... }
}

?
Feels a bit weird to have an empty stateless struct for just a function, but may be due to the structure of tokenizers / python.

tokenizers/src/models/bpe/trainer.rs Outdated Show resolved Hide resolved
@ArthurZucker ArthurZucker changed the title Fix decode Add bytelevel normalizer to fix decode when adding tokens to BPE Jul 15, 2024
@ArthurZucker ArthurZucker merged commit 4ea2f23 into main Jul 15, 2024
13 checks passed
@ArthurZucker ArthurZucker deleted the fix-decode branch July 15, 2024 10:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

added_tokens with bytemap charaters in ByteLevel could not be decoded correctly
3 participants