diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 87d7fa5faecac..941025b36c0ac 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -21,6 +21,7 @@
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
+import numpy
 
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -78,6 +79,7 @@ def distributed_sampler_kwargs(self):
 
     def setup(self, model):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"
 
         # pass in a state q
         smp = mp.get_context("spawn")
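
A minimal standalone sketch of the same workaround, for context (assumptions: the error being avoided is the mkl-service "MKL_THREADING_LAYER=INTEL is incompatible with libgomp" clash that can occur when numpy/MKL is initialised inside a spawned child, and the variable must be exported in the parent before the children start; the worker below is a hypothetical placeholder, not the plugin's training loop):

import os

# Must be set in the parent before any "spawn" children are created,
# so that the child processes inherit it in their environment.
os.environ["MKL_SERVICE_FORCE_INTEL"] = "1"

import torch.multiprocessing as mp


def _worker(rank: int) -> None:
    # Hypothetical worker body; importing numpy inside a spawned child is
    # what can trigger the MKL threading-layer error without the env var above.
    import numpy as np
    print(rank, np.zeros(1))


if __name__ == "__main__":
    smp = mp.get_context("spawn")
    procs = [smp.Process(target=_worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()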