Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CPU Adam JIT compilation #5780

Merged
merged 8 commits into from
Jul 31, 2024
23 changes: 15 additions & 8 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,25 +772,32 @@ def libraries_args(self):

class TorchCPUOpBuilder(CUDAOpBuilder):

def get_cuda_lib64_path(self):
import torch
if not self.is_rocm_pytorch():
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
if not os.path.exists(CUDA_LIB64):
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
else:
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
return CUDA_LIB64

def extra_ldflags(self):
if self.build_for_cpu:
return ['-fopenmp']

if not self.is_rocm_pytorch():
return ['-lcurand']
ld_flags = ['-lcurand']
if not self.build_for_cpu:
ld_flags.append(f'-L{self.get_cuda_lib64_path()}')
loadams marked this conversation as resolved.
Show resolved Hide resolved
return ld_flags

return []

def cxx_args(self):
import torch
args = []
if not self.build_for_cpu:
if not self.is_rocm_pytorch():
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
if not os.path.exists(CUDA_LIB64):
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
else:
CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
CUDA_LIB64 = self.get_cuda_lib64_path()

args += super().cxx_args()
args += [
Expand Down
Loading