Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for building CUDA extension on Windows #396

Merged
merged 7 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/benchmark_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@


def benchmark(m: int, k: int, n: int):
fp6_weight = torch.randint(256, size=(n, k // 4 * 3), dtype=torch.uint8, device="cuda")
fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_linear = Fp6LlmLinear(fp6_weight.view(torch.int32), scales)
fp6_linear = Fp6LlmLinear(fp6_weight, scales)

fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight.view(-1), n, k, dtype=torch.half) * scales[:, None]
fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
fp6_output = fp6_linear(fp16_act)
Expand Down
50 changes: 36 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read_version(file_path="version.txt"):
CUDAExtension,
BuildExtension,
CUDA_HOME,
IS_WINDOWS
)


Expand All @@ -52,20 +53,41 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
]
}
if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])
if not IS_WINDOWS:
extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}

if debug_mode:
extra_compile_args["cxx"].append("-g")
extra_compile_args["nvcc"].append("-g")
extra_link_args.extend(["-O0", "-g"])

else:
extra_link_args = []
extra_compile_args = {
"cxx": [
"/O2" if not debug_mode else "/Od",
"/permissive-"
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
}

if debug_mode:
extra_compile_args["cxx"].append("/ZI")
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
Expand Down
9 changes: 5 additions & 4 deletions torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh

#include "configs.h"
#include "utils_gmem.cuh"
Expand Down Expand Up @@ -133,11 +133,12 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
// Triple-Buffer for B Tile
half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
// MODIFICATION NOTE: to support MSVC, `half __restrict__ (*read_SPTR)` is changed to the form below; the same change is applied to read2_SPTR and write_SPTR.
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#ifdef PIPELINE_LEVEL_SMEM
half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#endif
half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
//
bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;
// Copying A tile from Global to Register, Bypassing L1, using double-buffer
Expand Down
15 changes: 10 additions & 5 deletions torchao/csrc/cuda/fp6_llm/ptx_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh

/***************************************************************************
* Copyright 2023 The FLash-LLM Authors. All rights reserved.
Expand All @@ -36,11 +36,14 @@
#include <assert.h>
#include "configs.h"

// MODIFICATION NOTE: to support MSVC
// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
#ifdef PIPELINE_LEVEL_SMEM
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4],
half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4],
half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
int slice_id) {
#ifdef DEBUG_MODE
static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
#endif
Expand Down Expand Up @@ -112,8 +115,10 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[
}
#endif

// MODIFICATION NOTE: to support MSVC, the function signature is changed from
// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
__device__ __forceinline__ void
MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)
MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
{
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
Expand Down
8 changes: 5 additions & 3 deletions torchao/csrc/cuda/fp6_llm/utils_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh

#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH
Expand All @@ -35,12 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u
}
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig>
__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4],
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales)
{
// Writing registers
Expand All @@ -53,13 +54,14 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t (
B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0); // Loading B from shared to registers
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
template <typename TilingConfig>
__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16],
uint32_t (*a)[4],
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales,
int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
Expand Down
13 changes: 7 additions & 6 deletions torchao/csrc/cuda/fp6_llm/utils_gmem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh

#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH
Expand Down Expand Up @@ -57,17 +57,18 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];
}

// MODIFICATION NOTE: to support MSVC, half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
/*
* (1) Copying X rows * 64 columns of FP16 values, originally in row major
* (2) Copying 64 rows * X columns of FP16 values, originally in column major
* 16 Bytes per thread -> 512 Bytes per WARP = 4 lines per WARP = 1 line per 8 Threads
*/
template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
bool Pred = true) {
// static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
const int NumOfGroups = NumOfThreads / 8;
Expand Down
30 changes: 17 additions & 13 deletions torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
// To support MSVC, all instances of u_int32_t are changed to uint32_t.

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH
Expand All @@ -26,7 +27,7 @@
* Outputs: R1, R2
* Note: Simplified Exponent calculation is applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) {
*R2 = *R1 & 0x80808080;
*R1 = *R1 >> 2;
*R1 = *R1 & 0x1f1f1f1f;
Expand All @@ -41,7 +42,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2)
* Outputs: R1, R2
* Note: Simplified Exponent calculation is NOT applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) {
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) {
//*R2 = *R1 & 0x80808080;
*R2 = *R1 & 0xc0c0c0c0;
*R1 = *R1 >> 2;
Expand All @@ -63,7 +64,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_
//*R2 = 0x3c003c00;
}

__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
half* FP16_2 = FP16_1 + 1;
uint32_t output;
Expand All @@ -73,16 +74,19 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
return output;
}

__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
u_int32_t __restrict__ *read_RPTR_Frag1,
u_int32_t __restrict__ *read_RPTR_Frag2,
u_int32_t *Scales) {
u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*> (Reg);
u_int32_t *Frag1_PTR = read_RPTR_Frag1;
u_int32_t *Frag2_PTR = read_RPTR_Frag2;
// MODIFICATION NOTE: to support MSVC
// - `u_int32_t __restrict__ Reg[][4]` is changed to the form below.
// - `u_int32_t __restrict__ *read_RPTR_Frag1` is changed to the form below; the same change is applied to read_RPTR_Frag2.
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4],
uint32_t * __restrict__ read_RPTR_Frag1,
uint32_t * __restrict__ read_RPTR_Frag2,
uint32_t * Scales) {
uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
uint32_t *Frag1_PTR = read_RPTR_Frag1;
uint32_t *Frag2_PTR = read_RPTR_Frag2;
half *Scale_RPTR = reinterpret_cast<half*>(Scales);
u_int32_t Packed_FP6 = 0;
u_int32_t tmp = 0;
uint32_t Packed_FP6 = 0;
uint32_t tmp = 0;
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
#pragma unroll(8)
for(int i=0; i<8; i++) {
Expand Down
Loading