fix model to save in ppov2 #1776
Conversation
`save_model` is currently saving `self.backup_model`, but it should save `self.model`. `self.backup_model` is only a temporary model used to store the policy and value function together, whereas `self.model` holds just the policy, which is what should be saved.
trl/trainer/ppov2_trainer.py (Outdated)

```diff
@@ -220,7 +220,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
     self.model = self.accelerator.unwrap_model(self.model).policy  # save only the policy
     if output_dir is None:
         output_dir = self.args.output_dir
-    state_dict = self.accelerator.get_state_dict(self.backup_model)
+    state_dict = self.accelerator.get_state_dict(self.model)
```
This is probably incorrect. There are two scenarios:

- We call `trainer.save_model`, in which case the `if not _internal_call` branch gets triggered, and `self.model` becomes the policy.
- We call `trainer.push_to_hub`, in which case `push_to_hub` sets `self.model` to be the policy, and `super().push_to_hub(**kwargs)` calls `save_model(..., _internal_call=True)`, and in that case `self.model` is still the policy.

It's a bit unfortunate that the logic is a bit convoluted... 🫠
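The two scenarios above can be sketched with a toy trainer. This is a hypothetical illustration of the swap/restore pattern, not the real trl code; `ToyTrainer` and `ToyPolicyAndValue` are made-up stand-ins:

```python
class ToyPolicyAndValue:
    """Stand-in for the wrapped policy+value model (hypothetical)."""
    def __init__(self):
        self.policy = "policy-weights"


class ToyTrainer:
    """Toy illustration of the save/restore dance described above."""
    def __init__(self):
        self.model = ToyPolicyAndValue()
        self.saved = None

    def save_model(self, _internal_call=False):
        if not _internal_call:
            # Scenario 1: a direct call swaps self.model to the bare policy.
            self.backup_model = self.model
            self.model = self.model.policy
        # In both scenarios, self.model is already the policy here.
        self.saved = self.model
        if not _internal_call:
            # Restore the full policy+value model after saving.
            self.model = self.backup_model

    def push_to_hub(self):
        # Scenario 2: push_to_hub swaps first, then re-enters save_model
        # with _internal_call=True (as super().push_to_hub would).
        self.backup_model = self.model
        self.model = self.model.policy
        self.save_model(_internal_call=True)
        self.model = self.backup_model
```

In both paths only the policy is saved, and `self.model` is restored afterwards, which is why saving `self.model` (not `self.backup_model`) is correct.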
But how is `self.backup_model` set in the case when `_internal_call` is True? I don't see it being set, and I'm getting an error.
I looked into it a bit more and I don't think there's any need to have a separate `backup_model`.
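That simplification can be sketched as follows. This is a toy sketch under the assumption that the wrapper exposes the policy as a `.policy` attribute; `WrappedModel` and `save_policy_state` are hypothetical names, not trl API:

```python
class WrappedModel:
    """Stand-in for the policy+value wrapper (hypothetical)."""
    def __init__(self):
        self.policy = {"layer.weight": [0.1, 0.2]}  # pretend state dict


def save_policy_state(wrapped):
    # No separate backup_model needed: reach into the wrapper and take
    # the policy's state directly at save time.
    return dict(wrapped.policy)
```

The save path then never has to stash and restore a second model, which avoids the `backup_model`-is-unset error described above.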
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
You can try with:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


def main():
    config = PPOv2Config(output_dir="tmp")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped", padding_side="left")
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
    value_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-1b-deduped", num_labels=1)
    reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-1b-deduped", num_labels=1)
    ref_policy = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b-deduped")
    policy = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b-deduped")

    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
    train_dataset = raw_datasets.select(range(50))

    def tokenize(element):
        outputs = tokenizer(element["prompt"], padding=False)
        return {"input_ids": outputs["input_ids"]}

    train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=train_dataset.column_names)

    trainer = PPOv2Trainer(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
    )
    trainer.save_model(config.output_dir)
    trainer.push_to_hub()


if __name__ == "__main__":
    main()
```
lgtm thanks @mnoukhov!
@vwxyzjn let me know if I'm off base