diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 48bb359713977..7c8d0143b80c3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -151,7 +151,7 @@ def training_step(self, batch, batch_idx): """ - +import copy import inspect from abc import ABC, abstractmethod import warnings @@ -586,7 +586,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens): gpu_id = 0 if isinstance(self.data_parallel_device_ids, list): gpu_id = self.data_parallel_device_ids[0] - batch = self.transfer_batch_to_gpu(batch.copy(), gpu_id) + batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) args[0] = batch output = self.model.training_step(*args)