diff --git a/dl_translate/_translation_model.py b/dl_translate/_translation_model.py index 1e65b16..8803436 100644 --- a/dl_translate/_translation_model.py +++ b/dl_translate/_translation_model.py @@ -177,8 +177,11 @@ def translate( data_loader = torch.utils.data.DataLoader(text, batch_size=batch_size) output_text = [] + tqdm_iterator = data_loader + if verbose is True: + tqdm_iterator = tqdm with torch.no_grad(): - for batch in tqdm(data_loader, disable=not verbose): + for batch in tqdm_iterator: encoded = self._tokenizer(batch, return_tensors="pt", padding=True) encoded.to(self.device)