SFTTrainer add support for IterableDataset #1890

Closed

helloworld1

In trl==0.9.4, SFTTrainer only supports Dataset, which prevents streaming large datasets.
This change adds support for IterableDataset so that large datasets can be streamed.
The setup that triggers the problem:

dataset = datasets.IterableDataset.from_generator(get_training_data(custom_args.data_path), features=datasets.Features({"prompt": datasets.Value("string")}))

dataset_train, dataset_eval = dataset, dataset

trainer = trl.SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    dataset_text_field="prompt",
    max_seq_length=custom_args.max_seq_length,
    peft_config=peft_config,
    args=training_args,
)

This results in the following error:

Traceback (most recent call last):
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 134, in <module>
    main()
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/training_hf_debug.py", line 108, in main
    trainer = trl.SFTTrainer(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 362, in __init__
    train_dataset = self._prepare_dataset(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 508, in _prepare_dataset
    return self._prepare_non_packed_dataloader(
  File "/mnt/archroot/root/home/liberty/mp/pytorch-custom/venv-wsl/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 582, in _prepare_non_packed_dataloader
    tokenized_dataset = dataset.map(
TypeError: IterableDataset.map() got an unexpected keyword argument 'num_proc'

After this PR, training and evaluation can proceed. This fixes #1764.
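
The traceback shows that the failure comes from passing num_proc to a map() method that does not accept it. The sketch below illustrates one way such a call could be guarded; it is not the diff from this PR. The helper name map_with_supported_kwargs and the two stand-in classes are hypothetical and exist only so the example runs without the datasets library installed.

```python
import inspect


def map_with_supported_kwargs(dataset, function, **kwargs):
    """Call dataset.map(function, ...) with only the keyword arguments it accepts.

    Hypothetical helper: keyword arguments missing from the signature of
    dataset.map (such as num_proc for a streaming dataset) are dropped.
    """
    params = inspect.signature(dataset.map).parameters
    # If map() takes **kwargs, every keyword is allowed through.
    accepts_any = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    supported = {k: v for k, v in kwargs.items() if accepts_any or k in params}
    return dataset.map(function, **supported)


# Stand-in classes (assumptions for illustration, not the real datasets API).
class FakeMapDataset:
    def __init__(self, rows):
        self.rows = list(rows)

    def map(self, function, batched=False, num_proc=None):
        return FakeMapDataset(function(row) for row in self.rows)


class FakeIterableDataset:
    def __init__(self, rows):
        self.rows = list(rows)

    # No num_proc parameter, mirroring the TypeError in the traceback.
    def map(self, function, batched=False):
        return FakeIterableDataset(function(row) for row in self.rows)


def upper_prompt(row):
    return {"prompt": row["prompt"].upper()}


streamed = map_with_supported_kwargs(
    FakeIterableDataset([{"prompt": "hello"}]), upper_prompt, num_proc=4
)
regular = map_with_supported_kwargs(
    FakeMapDataset([{"prompt": "hello"}]), upper_prompt, num_proc=4
)
```

Inspecting the signature avoids hard-coding a type check, but an explicit isinstance check against IterableDataset would be an equally reasonable choice inside the trainer.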

@qgallouedec qgallouedec marked this as a duplicate of #1889 Aug 5, 2024
@qgallouedec qgallouedec closed this Aug 5, 2024
@qgallouedec qgallouedec marked this as a duplicate of #1899 Aug 5, 2024
@qgallouedec
Member

Duplicate of #1899

Successfully merging this pull request may close these issues.

Using IterableDataset crashed the SFTTrainer