diff --git a/op_builder/builder.py b/op_builder/builder.py index 9747261b5d66..dda6472baccb 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -700,14 +700,25 @@ def cxx_args(self): import torch args = [] if not self.build_for_cpu: + args += super().cxx_args() + if not self.is_rocm_pytorch(): CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + CUDA_LIB = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") else: CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + CUDA_LIB = None - args += super().cxx_args() args += [ f'-L{CUDA_LIB64}', + ] + + if CUDA_LIB is not None: + args += [ + f'-L{CUDA_LIB}', + ] + + args += [ '-lcudart', '-lcublas', '-g',