Add num_dataloader_workers arg to dreambooth script (#1107)
This is especially important for Windows users, who may have to set the
number of workers to 0.
lukaskuhn-lku authored Nov 10, 2023
1 parent 3af469e commit 49ddefa
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docs/source/task_guides/dreambooth_lora.mdx
@@ -83,6 +83,7 @@ accelerate launch train_dreambooth.py \
--output_dir=$OUTPUT_DIR \
--train_text_encoder \
--with_prior_preservation --prior_loss_weight=1.0 \
+  --num_dataloader_workers=1 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
@@ -101,6 +102,8 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```

+If you are running this script on Windows, you may need to set `--num_dataloader_workers` to 0.

## Inference with a single adapter

To run inference with the fine-tuned model, first specify the base model with which the fine-tuned LoRA weights will be combined:
6 changes: 5 additions & 1 deletion examples/lora_dreambooth/train_dreambooth.py
@@ -213,6 +213,10 @@ def parse_args(input_args=None):
help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
)

+    parser.add_argument(
+        "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader."
+    )

parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
@@ -799,7 +803,7 @@ def main(args):
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=1,
+        num_workers=args.num_dataloader_workers,
)

# Scheduler and math around the number of training steps.
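The change is small but complete: a single integer flag flows from `argparse` into the `num_workers` argument of `torch.utils.data.DataLoader`. A minimal standalone sketch of the argument-parsing side (torch is omitted so the snippet runs anywhere; the flag name and default match the diff, while `parse_args` here is a stripped-down stand-in for the script's much larger function):

```python
import argparse

def parse_args(argv=None):
    # Mirrors the argument added in this commit; the real script
    # defines many more options around it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_dataloader_workers",
        type=int,
        default=1,
        help="Num of workers for the training dataloader.",
    )
    return parser.parse_args(argv)

# What a Windows user would pass to keep data loading in the main process:
args = parse_args(["--num_dataloader_workers", "0"])
print(args.num_dataloader_workers)  # prints 0
```

With `num_workers=0`, the DataLoader loads batches in the main process instead of spawning worker processes. One likely reason Windows users need this: Windows starts workers via `spawn`, which pickles the DataLoader's arguments, and the `lambda` passed as `collate_fn` in the diff above cannot be pickled (this is a plausible explanation, not one stated in the commit).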
