Mlm adaptation #287

Merged · 180 commits · Jun 27, 2022
Conversation

lintangsutawika (Contributor):

This is the MLM adaptation part.
@thomasw21

Comment on lines +381 to +387
@property
def bos_token_id(self):
    raise NotImplementedError("Missing <bos>")

@property
def eos_token_id(self):
    raise NotImplementedError("Missing <eos>")
Member:

It's quite annoying, but we use HF's as-is already, so we shouldn't collapse tokens IMO.

)

sample = train_ds[0]
self.assertEqual(len(sample["input_tokens"]) + len(sample["target_tokens"]), args.seq_length)
Member:

Very basic test ... couldn't think of a more robust test.

Comment on lines +326 to +329
hf_tokenizer_kwargs = {}
if vocab_extra_ids > 0:
    # TODO @thomasw21 we might need to concatenate to a pre-existing list?
    hf_tokenizer_kwargs["additional_special_tokens"] = [f"<extra_id_{_id}>" for _id in range(vocab_extra_ids)]
Member:

I think we'll want something cleaner here, but since we're not using additional_special_tokens in our tokenizer, I'd say it's okay to override that value. I think at some point we'll push another tokenizer to the Hub with the additional tokens. cc @SaulLu
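For reference, a minimal sketch of the behaviour being discussed (hypothetical tokenizer name and count, for illustration only): passing additional_special_tokens through from_pretrained adds the new tokens to the vocabulary but replaces whatever list the pretrained tokenizer previously stored in that property.

from transformers import AutoTokenizer

# Hypothetical tokenizer name and extra-id count, for illustration only.
vocab_extra_ids = 100
tokenizer = AutoTokenizer.from_pretrained(
    "bigscience/tokenizer",
    additional_special_tokens=[f"<extra_id_{i}>" for i in range(vocab_extra_ids)],
)
# The sentinels are now in the vocabulary, but additional_special_tokens lists
# only the <extra_id_*> tokens; any previously registered extras are no longer
# listed in that property (fine here, since none are used).
print(len(tokenizer.additional_special_tokens))  # 100 with the settings above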

Collaborator:

Yeah, this will override the previous special tokens if they exist, right? Does it not work to add them later with self.tokenizer.add_special_tokens?

Collaborator:

Not very familiar with this part of the tokenizer code, but I would also have expected it to expand the vocabulary instead of re-using existing extra tokens :)

Collaborator:

OK, yes, if the embedding is padded later then this is the right way.

@SaulLu (Collaborator), Jun 23, 2022:

I just checked: whether they are added in the from_pretrained method or with the add_special_tokens method, in both cases the tokens will be added, but the value of the additional_special_tokens property will be overwritten.

If we want to keep the previously added tokens in the additional_special_tokens property, we need to do:

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
new_special_tokens = {
    "additional_special_tokens": self.tokenizer.additional_special_tokens + [
        f"<extra_id_{_id}>"
        for _id in range(vocab_extra_ids)
        if f"<extra_id_{_id}>" not in self.tokenizer.additional_special_tokens
    ]
}
self.tokenizer.add_special_tokens(new_special_tokens)

Member:

Hmm, let's not over-engineer this... we're not using any right now. I can add a warning saying we're going to overwrite the additional tokens (otherwise I have to switch the logic around a bit for no reason).

Collaborator:

Fine with a warning, but is there anything wrong with @lucile's solution? In any case I don't think this is essential, I was just curious; as long as it is documented I don't think we should spend too much time on it.

@thomasw21 (Member), Jun 23, 2022:

The issue is that the MLMDataset needs to be able to query the sentinel tokens. Right now I assume that all additional_special_tokens are sentinel tokens. So we would need to build a tokenizer that has specific MLM tokens; I can try and do that, I just didn't want to do it out of laziness :D
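A sketch of the assumption described above (hypothetical helper name, not the PR's exact code): the dataset treats every additional special token as a sentinel and resolves the whole list to ids in one go.

from transformers import PreTrainedTokenizerBase

def get_sentinel_token_ids(tokenizer: PreTrainedTokenizerBase) -> list:
    # Assumption from the comment above: every additional special token is a
    # sentinel token, so the sentinel id list is exactly this property.
    return list(tokenizer.additional_special_tokens_ids)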

@thomasw21 marked this pull request as ready for review on June 23, 2022, 13:10.
megatron/data/mlm_dataset.py (outdated; conversation resolved)
megatron/data/mlm_dataset.py (conversation resolved)
spans_start[1:], np.full((1,), len(sample), dtype=np.int32)]
)

sentinel_token_ids = all_sentinel_token_ids[:num_noise_spans]
Collaborator:

Given num_noise_spans is always the same, it might be slightly faster to store sentinel_token_ids as a class attribute of MLMDataset and feed it as an argument to the function.

I also wonder if it wouldn't be better to make num_noise_spans probabilistic instead of deterministic.
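A sketch of that suggestion (hypothetical names; the real MLMDataset carries more state): slice the sentinel ids once in the constructor rather than inside the per-sample function.

import numpy as np

class MLMDatasetSketch:
    def __init__(self, tokenizer, num_noise_spans: int):
        # num_noise_spans is fixed for the whole dataset, so the sentinel ids
        # can be sliced once here and later passed as an argument to the
        # sample-building function instead of being recomputed per sample.
        self.num_noise_spans = num_noise_spans
        self.sentinel_token_ids = np.array(
            tokenizer.additional_special_tokens_ids[:num_noise_spans]
        )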

Member:

I also have a strong intuition that we'll want to change those values. But the idea is to have T5-style MLM here and rely on their numbers.

megatron/data/mlm_dataset.py (outdated; conversation resolved)
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
up to num_items
"""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
# TODO @thomasw21 handle random state correctly, ie synchronized across TP.
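For context, the mask_indices line above comes from the T5-style random segmentation step. A self-contained sketch of that helper (assumed from the surrounding code, not the PR's exact function) looks roughly like this:

import numpy as np

def random_segmentation(num_items: int, num_segments: int, rng: np.random.RandomState) -> np.ndarray:
    # Partition num_items into num_segments non-empty segments, uniformly at
    # random, returning the segment lengths (they sum to num_items).
    # Mark num_segments - 1 "cut" positions among the first num_items - 1 slots.
    mask_indices = np.arange(num_items - 1) < (num_segments - 1)
    rng.shuffle(mask_indices)
    first_in_segment = np.pad(mask_indices, [[1, 0]])
    segment_id = np.cumsum(first_in_segment)
    return np.bincount(segment_id, minlength=num_segments)

In the T5 scheme this is called twice, once for noise spans and once for non-noise spans, and the sentinel tokens are then attached to the noise spans.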
@TevenLeScao (Collaborator), Jun 26, 2022:

This scares me a bit because TP random-state things are hard to debug, but tbh we should just test ASAP to see if the loss goes down at the expected rate.

Member:

Ah yes, I need to double-check that. I can have a go at it; I had forgotten about this TODO.
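One possible way to address that TODO (a hedged sketch with a hypothetical helper and argument names, not the PR's implementation): broadcast a seed from one rank of each tensor-parallel group so that every TP rank builds the same NumPy RNG and therefore masks the same positions.

import numpy as np
import torch
import torch.distributed as dist

def tp_synchronized_rng(base_seed: int, tp_group, tp_src_rank: int) -> np.random.RandomState:
    # tp_src_rank is the global rank of the source within the TP group.
    # Assumes an NCCL process group, hence a CUDA tensor for the broadcast.
    seed = torch.tensor([base_seed], dtype=torch.int64, device=torch.cuda.current_device())
    dist.broadcast(seed, src=tp_src_rank, group=tp_group)
    # Every rank in the group now holds the same seed and builds an identical RNG.
    return np.random.RandomState(int(seed.item()))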

@TevenLeScao (Collaborator) left a review comment:

Read through the PR and didn't catch anything worrying. Let's just test it ASAP.

@thomasw21 (Member):

@Muennighoff Waiting for your approval.

@Muennighoff (Collaborator) left a review comment:

Nice job! Is the plan roughly as follows?
Merge this -> finish & merge MTF (& figure out a plan for multilingual retention) -> try out MLM+MTF on a small bloom model -> try out Enc-Dec+MLM+MTF on a small bloom model -> try out the best option on bloom176B

@thomasw21 (Member):

Not exactly. In priority order (if we have idle compute, we move on to the next item):

  • merge and finish MTF
  • run MTF on small models
  • run MTF on big models (pending proof that MTF works for smaller models)
  • run MLM + MTF on small models
  • run MLM + MTF on big models (pending proof that MLM + MTF works for smaller models)

@thomasw21 merged commit 9d26431 into bigscience-workshop:main on Jun 27, 2022.
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request on Dec 18, 2023:
Modify universal checkpoint parameter patterns based on the specific model
configuration. This commit adds support for llama family of models.

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>