diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index 2b3117d849..fe876a65a5 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -7,8 +7,12 @@ ) import numpy as np +from packaging.version import ( + Version, +) from deepmd.env import ( + TF_VERSION, tf, ) from deepmd.utils.errors import ( @@ -59,7 +63,10 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1 else: self.maximum_working_batch_size = initial_batch_size - if tf.test.is_gpu_available(): + if ( + Version(TF_VERSION) >= Version("1.14") + and tf.config.experimental.get_visible_devices("GPU") + ) or tf.test.is_gpu_available(): self.minimal_not_working_batch_size = 2**31 else: self.minimal_not_working_batch_size = (