[CI/Build] Suppress divide-by-zero and missing return statement warnings
tlrmchlsmth authored Aug 5, 2024
1 parent 8571ac4 commit 6e4852c
Showing 4 changed files with 24 additions and 8 deletions.
8 changes: 8 additions & 0 deletions csrc/attention/dtype_bfloat16.cuh
@@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#else
return __bfloat1622float2(val);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
@@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#else
return __bfloat162bfloat162(val);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

// Vector addition.
@@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
return __hadd(a, b);
#endif
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
@@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#else
return __hadd2(a, b);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
@@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#else
return __hmul(a, b);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

template <>
@@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#else
return __hmul2(a, b);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

template <>
@@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
#else
return __hfma2(a, b, c);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
@@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
#else
return __hfma2(bf162bf162(a), b, c);
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
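
Note on the fix: each patched function returns on every live branch of its architecture #if/#else, but the compiler still sees a control path that can fall off the end of a non-void function and warns about a missing return statement. __builtin_unreachable() marks that trailing path as impossible. A minimal sketch of the pattern, with a hypothetical function name and includes added for self-containment (not code from this diff):

#include <cassert>
#include <cuda_bf16.h>

// Every reachable branch returns; the trailing hint tells the compiler
// that control can never fall through to the end of the function.
inline __device__ float to_float_sketch(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);  // bf16 intrinsics require compute capability >= 8.0
#else
  return __bfloat162float(val);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}
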
1 change: 1 addition & 0 deletions csrc/quantization/awq/dequantize.cuh
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {

return result;
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

} // namespace awq
5 changes: 3 additions & 2 deletions csrc/quantization/fp8/nvidia/quant_utils.cuh
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
+  __builtin_unreachable(); // Suppress missing return statement warning
}

// float -> fp8
@@ -508,7 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert(false);
-  return {};  // Squash missing return statement warning
+  __builtin_unreachable();  // Suppress missing return statement warning
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@@ -521,7 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert(false);
-  return {};  // Squash missing return statement warning
+  __builtin_unreachable();  // Suppress missing return statement warning
}

// The following macro is used to dispatch the conversion function based on
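
Note on the fix: in convert and scaled_convert, the fallback after assert(false) changes from returning a default-constructed value to declaring the path unreachable, which silences the warning without fabricating a return value. A simplified before/after sketch (the real functions dispatch through ENABLE_FP8 #ifdef branches; the bodies here are reduced to the fallback path only, and the function names are illustrative):

#include <cassert>

// Before: squashes the warning, but pretends the dead path can
// produce a meaningful default-constructed Tout.
template <typename Tout, typename Tin>
__inline__ __device__ Tout convert_before(const Tin& x) {
  assert(false);
  return {};
}

// After: states that control never reaches this point.
template <typename Tout, typename Tin>
__inline__ __device__ Tout convert_after(const Tin& x) {
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}
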
18 changes: 12 additions & 6 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1130,12 +1130,12 @@ __global__ void Marlin(
};

auto fetch_zp_to_registers = [&](int k, int full_pipe) {
-    if constexpr (has_zp) {
-      // This code does not handle group_blocks == 0,
-      // which signifies act_order.
-      // has_zp implies AWQ, which doesn't have act_order,
-      static_assert(group_blocks != 0);
+    // This code does not handle group_blocks == 0,
+    // which signifies act_order.
+    // has_zp implies AWQ, which doesn't have act_order,
+    static_assert(!has_zp || group_blocks != 0);

+    if constexpr (has_zp) {
int pipe = full_pipe % stages;

if constexpr (group_blocks == -1) {
@@ -1161,7 +1161,13 @@
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
-      int cur_group_id = k_blocks / group_blocks;
+      int cur_group_id = 0;
+
+      // Suppress bogus and persistent divide-by-zero warning
+  #pragma nv_diagnostic push
+  #pragma nv_diag_suppress divide_by_zero
+      cur_group_id = k_blocks / group_blocks;
+  #pragma nv_diagnostic pop

int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;

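Note on the fix: this file gets two related changes. The static_assert is hoisted out of the if constexpr (has_zp) block and rewritten as !has_zp || group_blocks != 0, so it is enforced for every template instantiation rather than only inside the zero-point branch. The division k_blocks / group_blocks is then bracketed with nv_diag_suppress pragmas: for instantiations with group_blocks == 0 (the act_order case, where has_zp is false and the branch is never taken) nvcc evidently still flags the constant division — the in-diff comment calls the warning "bogus and persistent" — so it is suppressed locally. A minimal sketch of the suppression pattern in a hypothetical kernel (names are illustrative, not from gptq_marlin.cu):

template <bool has_zp, int group_blocks>
__global__ void marlin_sketch(const int* k_blocks_in, int* group_id_out) {
  // Enforced for every instantiation, not only when has_zp is true.
  static_assert(!has_zp || group_blocks != 0);

  if constexpr (has_zp) {
    int cur_group_id = 0;

    // Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
    cur_group_id = k_blocks_in[0] / group_blocks;
#pragma nv_diagnostic pop

    group_id_out[0] = cur_group_id;
  }
}

Splitting the declaration (int cur_group_id = 0;) from the guarded assignment keeps the suppressed region as narrow as possible, so any genuine divide-by-zero warnings elsewhere in the kernel still surface.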
