From f1e1ba6c1fa578101a73b1baa585761c39ba8c6b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 5 Mar 2024 23:22:20 -0500 Subject: [PATCH] pt: non-blocking copying training data to device Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e5a7632ac4..5889b943bb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -954,9 +954,11 @@ def get_data(self, is_train=True, task_key="Default"): continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: - batch_data[key] = batch_data[key].to(DEVICE) + batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True) else: - batch_data[key] = [item.to(DEVICE) for item in batch_data[key]] + batch_data[key] = [ + item.to(DEVICE, non_blocking=True) for item in batch_data[key] + ] # we may need a better way to classify which are inputs and which are labels # now wrapper only supports the following inputs: input_keys = [