
Add T5 model #145

Merged: 34 commits merged into CarperAI:main on Jan 9, 2023
Conversation

@PhungVanDuy (Collaborator) commented Dec 21, 2022

@PhungVanDuy changed the title from "Add T5 model" to "[WIP] Add T5 model" on Dec 21, 2022
@PhungVanDuy marked this pull request as draft on December 21, 2022 02:06
@Dahoas (Collaborator) left a comment:

  • I see you're adding a bunch of new tasks in your PR (which is great!), but they should probably be separated out into other PRs if possible.
  • Do you have a wandb run you can share?
  • I would suggest not freezing anything at first (on a very small model with a single GPU) to make sure the algorithm is right; see the config sketch below.
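A minimal sketch of that unfrozen sanity check (assuming trlx's TRLConfig.load_yaml and the num_layers_unfrozen model field behave on this branch as they do on main, where -1 leaves every layer trainable; the config path is hypothetical):

from trlx.data.configs import TRLConfig

# Load the PPO config and strip anything that could mask an algorithmic bug.
config = TRLConfig.load_yaml("configs/ppo_config.yml")  # hypothetical path
config.model.model_path = "t5-small"       # tiny model that fits on a single GPU
config.model.num_layers_unfrozen = -1      # -1 leaves every layer trainable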

@@ -0,0 +1,22 @@
{
(Collaborator) You'll probably want to put this under configs when finished.

@@ -0,0 +1,22 @@
{
(Collaborator) Same here (move under configs).

examples/reward_model.py (outdated, resolved)
@@ -0,0 +1,110 @@
import torch
(Collaborator) Eventually we'll want to put this dataset onto Hugging Face.
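A hedged sketch of what that migration could look like (the repo id and column names are hypothetical placeholders, not from this PR):

from datasets import Dataset

# Hypothetical rows standing in for the local data used here.
ds = Dataset.from_dict({
    "prompt": ["Summarize: ..."],
    "label": ["..."],
})
ds.push_to_hub("CarperAI/t5-ppo-dataset")  # hypothetical repo id; needs a logged-in HF token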

@PhungVanDuy (Collaborator, Author) replied: Actually, I used this dataset before it was public on HF, but that was for the RLHF blog post, not for this PR.

@LouisCastricato mentioned this pull request on Dec 31, 2022
@PhungVanDuy marked this pull request as ready for review on January 2, 2023 02:34
@PhungVanDuy changed the title from "[WIP] Add T5 model" to "Add T5 model" on Jan 2, 2023
@PhungVanDuy (Collaborator, Author) commented Jan 2, 2023

@Dahoas @LouisCastricato, this is an example of FlanT5 on the CNN/DailyMail dataset, but some of the other charts look quite weird. Please check when you have time: https://wandb.ai/pvduy/trlx/runs/8q3skf8p

@PhungVanDuy marked this pull request as draft on January 2, 2023 03:43
@PhungVanDuy marked this pull request as ready for review on January 2, 2023 03:55
@jon-tow added this to the v0.4.0 milestone on Jan 2, 2023
@jon-tow (Collaborator) left a comment:

Looks awesome! I've left some feedback to be addressed. Very excited about this 👍

trlx/data/configs.py (outdated, resolved)
trlx/trainer/accelerate_base_trainer.py (outdated, resolved)
trlx/utils/modeling.py (outdated, resolved)
trlx/orchestrator/ppo_orchestrator.py (outdated, resolved)
trlx/trainer/nn/ppo_models.py (outdated, resolved)
trlx/trainer/accelerate_base_trainer.py (outdated, resolved)
trlx/pipeline/offline_pipeline.py (outdated, resolved)
examples/ppo_config_cnn_daily.yml (outdated, resolved)
trlx/orchestrator/ppo_orchestrator.py (outdated, resolved)
@PhungVanDuy (Collaborator, Author) replied:

> Looks awesome! I've left some feedback to be addressed. Very excited about this 👍

Thank you for the great feedback; I will follow up and fix it.

@PhungVanDuy (Collaborator, Author) commented Jan 7, 2023

Fixed PPO for T5 (https://wandb.ai/pvduy/trlx/runs/1n31fb6a). The fix for GPT-J is still running on the OpenAI summarization dataset for verification. Please review, @reciprocated @LouisCastricato @Dahoas.

examples/t5_sentiment_train_lm.py (outdated, resolved)
trlx/trainer/accelerate_base_trainer.py (outdated, resolved)
examples/ppo_sentiments.py (resolved)

meteor = evaluate.load("meteor")

if __name__ == "__main__":
(Contributor) Since this is an example, we should probably have lots of comments.
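For instance, a heavily commented sketch of a METEOR-based reward function for this example (the reward_fn signature and the reference list are illustrative assumptions, not the exact code in this PR):

import evaluate

meteor = evaluate.load("meteor")  # METEOR score in [0, 1]; higher is better

# Hypothetical gold summaries, aligned one-to-one with the generated samples.
val_references = ["A short reference summary."]

def reward_fn(samples):
    # Score each generated summary against its reference with METEOR.
    return [
        meteor.compute(predictions=[sample], references=[ref])["meteor"]
        for sample, ref in zip(samples, val_references)
    ]

print(reward_fn(["A short generated summary."]))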

@@ -40,6 +40,7 @@ def __init__(

        if not hasattr(self.trainer.model, "frozen_head"):
            self.ref_model = self.trainer.get_arch(self.trainer.config)
            self.ref_model.to(self.trainer.accelerator.device)

(Collaborator) Have we verified this works? I recall accelerate freezing up when I started putting multiple models on the GPU (though this could've just been the sentiment pipeline we were using for the sentiments task).

@PhungVanDuy (Collaborator, Author) commented Jan 9, 2023

I agree that with larger models we should distribute the models across multiple GPUs, but for this case I think we should keep the reference model on the GPU rather than the CPU; on the CPU it is super slow.
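Restated as a commented sketch (trainer stands in for the orchestrator's trainer from the diff above; the eval() call is an illustrative addition, not part of the diff):

# Without a frozen head, PPO needs a full reference model for its KL term;
# keep it on the accelerator device, since CPU inference is far slower.
if not hasattr(trainer.model, "frozen_head"):
    ref_model = trainer.get_arch(trainer.config)  # same architecture as the policy
    ref_model.to(trainer.accelerator.device)      # GPU placement, as argued above
    ref_model.eval()                              # the reference model is never updated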

@PhungVanDuy (Collaborator, Author) replied:

> @PhungVanDuy If the best-of-n sampling is not required for T5 support, I think it's best to create a separate PR for it to allow for proper testing (I found some basic issues in the previous commits). Then we could better review that without further complicating this PR.

You are right, @jon-tow; I am removing it from this PR. We can consider merging this PR today, since some bugs in the current main branch should be fixed by it. cc @LouisCastricato

@jon-tow (Collaborator) left a comment:

Leaving some tiny final comments and change requests 🙏

trlx/data/configs.py (resolved)
trlx/utils/modeling.py (outdated, resolved)
trlx/utils/modeling.py (outdated, resolved)
trlx/trainer/nn/ppo_models.py (resolved)
rs[-1] = scores[ix]
rs = rewards[ix]
if len(rs) == 0:
    rs = torch.tensor([0.0])
@maxreciprocate (Collaborator) commented Jan 9, 2023

Should we penalize empty responses? Also, do you know how it's possible to get those, except for when max_new_tokens == 0? 🤔

@PhungVanDuy (Collaborator, Author) commented Jan 9, 2023

Yes, this is an exceptional edge case, but I hit it a few times when I ran PPO sentiments. @jon-tow, you also faced this, right?

(Collaborator) Yeah, Jon has said that he also experienced it. Since its cause is unknown, I wonder whether it may be a symptom of some other bug elsewhere.

@jon-tow (Collaborator) commented Jan 9, 2023

@reciprocated the only case I can think of that'd lead to an empty response is when len(query) is larger than the generate method's min_length arg, which defaults to 10, and the model happens to output the eos_token on its first sample. (Note that with causal models the min_length constraint includes the length of the context (query), meaning it won't actually have an effect on the generations if the min condition is already met by the context size.)

In such cases, I'm okay with penalizing empty responses as they're uninformative, so long as this is not a bug lol
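A sketch reproducing that scenario with plain transformers (gpt2 is just a stand-in checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The prompt alone already satisfies min_length, so the constraint is inert
# and the very first sampled token may be eos_token, giving an empty response.
inputs = tokenizer("A prompt that is already longer than ten tokens overall", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=20, min_length=10)
response = outputs[0, inputs["input_ids"].shape[1]:]  # empty if eos came first
print(len(response), tokenizer.decode(response))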

(Contributor) I do not think it is a serious issue. We can just throw an error if this min_length thing comes up. I've never seen this in practice when I set min_length correctly. (Perhaps we should add an extra parameter called min_new_length...? We should upstream that to HF transformers, though.)
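Until something like that exists upstream, a hedged sketch of the workaround (min_new_length is only the suggested name here, emulated via the existing min_length argument):

def generate_with_min_new_length(model, inputs, min_new_length, **gen_kwargs):
    # For causal models min_length counts the prompt tokens, so offset it by
    # the prompt length to force at least min_new_length generated tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return model.generate(**inputs, min_length=prompt_len + min_new_length, **gen_kwargs)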

@LouisCastricato merged commit 0c5246f into CarperAI:main on Jan 9, 2023
@jon-tow (Collaborator) left a comment:

Leaving some notes for myself to address in the future.

Comment on lines -108 to -110

all_tokens, attention_mask, position_ids = self.trainer.get_model_inputs(
    query_tensors.to(response_tensors.device), response_tensors
)
(Collaborator) This removed the need for:

def get_model_inputs(

We should remove it, if unused, before it becomes stale.

Comment on lines +211 to +214
logprobs = logprobs.cpu()
ref_logprobs = ref_logprobs.cpu()
query_tensors = query_tensors.cpu()
response_tensors = response_tensors.cpu()
(Collaborator) Remove these lines; these vars are already moved to the CPU on the lines right before the if-statement.
