diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 5dae47cddb..15fa297ff0 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1613,14 +1613,24 @@ def get_data_loader( # get samplers sampler = self.get_sampler(config, dataset, num_gpus) - - loader = DataLoader( - dataset, - batch_sampler=sampler, - collate_fn=dataset.collate_fn, - num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, - pin_memory=False, - ) + if sampler is None: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) return loader def get_optimizer(self) -> List: