From f01c27f65a217d5d087f09534d7b022083db4ed5 Mon Sep 17 00:00:00 2001 From: Rahul Batra Date: Wed, 31 Jan 2024 16:51:59 +0000 Subject: [PATCH] [ROCm]: Add ROCm command buffer support for triton kernel --- jaxlib/gpu/triton_kernels.cc | 8 ++++---- jaxlib/gpu/vendor.h | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 88d451a81812..f57e69b168f9 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -524,16 +524,16 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { gpustreamCaptureStatus_t capture_status; GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status)); - bool is_capturing = capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE; + bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE; - gpustreamCaptureMode_t capture_mode = CU_STREAM_CAPTURE_MODE_RELAXED; + gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED; gpuStream_t autotune_stream = stream; if (is_capturing) { + GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); // Need a side stream so as not to interfere with graph capture. - GPU_RETURN_IF_ERROR( - gpuStreamCreate(&autotune_stream, CU_STREAM_NON_BLOCKING)); + GPU_RETURN_IF_ERROR(gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING)); } // If an input aliases with an output, it will get overwritten during the diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 62ef6c4cdfa9..fde0e4bb92b2 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -254,6 +254,10 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT #define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS +#define GPU_STREAM_CAPTURE_STATUS_ACTIVE CU_STREAM_CAPTURE_STATUS_ACTIVE +#define GPU_STREAM_CAPTURE_MODE_RELAXED CU_STREAM_CAPTURE_MODE_RELAXED +#define GPU_STREAM_NON_BLOCKING CU_STREAM_NON_BLOCKING + #define gpuCtxGetDevice cuCtxGetDevice #define gpuCtxPopCurrent cuCtxPopCurrent #define gpuCtxPushCurrent cuCtxPushCurrent @@ -332,6 +336,8 @@ typedef hipsolverFillMode_t gpusolverFillMode_t; typedef hipblasHandle_t gpublasHandle_t; typedef hipblasStatus_t gpublasStatus_t; typedef hipCtx_t gpuContext_t; +typedef hipStreamCaptureMode gpustreamCaptureMode_t; +typedef hipStreamCaptureStatus gpustreamCaptureStatus_t; typedef hipDataType gpuDataType; typedef hipDevice_t gpuDevice_t; typedef hipDeviceptr_t gpuDevicePtr_t; @@ -494,6 +500,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT #define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS +#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive +#define GPU_STREAM_CAPTURE_MODE_RELAXED hipStreamCaptureModeRelaxed +#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking + #define gpuGetLastError hipGetLastError #define gpuGetErrorString hipGetErrorString #define gpuMemcpyAsync hipMemcpyAsync @@ -526,6 +536,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuMemcpyDtoHAsync hipMemcpyDtoHAsync #define gpuMemcpyHtoDAsync hipMemcpyHtoDAsync #define gpuMemsetD8Async hipMemsetD8Async +#define gpuThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode +#define gpuStreamCreate hipStreamCreateWithFlags +#define gpuStreamDestroy hipStreamDestroy +#define gpuStreamIsCapturing hipStreamIsCapturing #define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \ hipDeviceAttributeComputeCapabilityMajor