diff --git a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu index 5d91807b1a..f746d7cc7a 100644 --- a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu +++ b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu @@ -144,7 +144,9 @@ std::tuple dual_gemm_silu_identity_mul_( x.options().dtype(at::ScalarType::Byte)); cutlass::Status status = dual_gemm.can_implement(arguments); TORCH_CHECK( - status == cutlass::Status::kSuccess, "not supported by this kernel"); + status == cutlass::Status::kSuccess, + "`dual_gemm_silu_identity_mul` does not support this input: ", + cutlass::cutlassGetStatusString(status)); status = dual_gemm.initialize(arguments, (uint8_t*)workspace.data_ptr()); TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel initialize failed"); status = dual_gemm(stream); diff --git a/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu b/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu index d9d61ec30e..896e3feba5 100644 --- a/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu +++ b/xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu @@ -186,7 +186,9 @@ void gemm_fused_operand_sum_( a.options().dtype(at::ScalarType::Byte)); cutlass::Status status = gemm_op.can_implement(arguments); TORCH_CHECK( - status == cutlass::Status::kSuccess, "not supported by this kernel"); + status == cutlass::Status::kSuccess, + "`gemm_fused_operand_sum` does not support this input: ", + cutlass::cutlassGetStatusString(status)); status = gemm_op.initialize(arguments, (uint8_t*)workspace.data_ptr()); TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel initialize failed"); status = gemm_op(stream);