Commit 5724b9e

add some comments
1 parent 388e043 commit 5724b9e

extensions/csrc/cuda/activation_kernel.cu (+11 -6)
@@ -37,25 +37,30 @@ __global__ void act_and_mul_kernel(
 // silu(x[:half_1stdim]) * (x[half_1stdim:])
 torch::Tensor silu_and_mul(const torch::Tensor& ins)
 {
+  // Note(LiuYang): According to the torch docs, vec() may cost a lot, but I didn't find a
+  // better API for manipulating ins_shape, which is an IntArrayRef
   auto ins_shape = ins.sizes().vec();

   ins_shape[0] = ins_shape[0]/2;
   if (ins_shape[0] == 1) {
     ins_shape.erase(ins_shape.begin());
   }
   auto outs = torch::zeros(ins_shape,ins.options());
-  auto outs_shape = ins.sizes().vec();

   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   // Note(LiuYang): numel of ins must be divisible by 2
   int64_t numel = ((torch::numel(ins)) >> 1);

-  // TODO(LiuYang): Maybe we need to implement a function to get launch config
-  colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
-  auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
-  dim3 grid = config.grid;
-  dim3 block = config.block;
+  // Note(LiuYang): For better performance in the special case where the input is [2, 64, 11008],
+  // this block is commented out, since computing a tuned launch config itself costs a little time
+  // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
+  // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
+  // dim3 grid = config.grid;
+  // dim3 block = config.block;
+
+  dim3 grid((numel+255)/256);
+  dim3 block(256);

   DISPATCH_FLOAT_HALF_AND_BFLOAT(
     ins.scalar_type(),
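
The first added comment is about ins.sizes(), which returns an at::IntArrayRef: a non-owning, read-only view of the tensor's shape. To edit the shape it must first be copied into an owning std::vector<int64_t> via vec(), which is the copy cost the note refers to. A minimal standalone sketch of the same shape manipulation (the tensor and its dimensions here are illustrative, not from the commit):

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  // Illustrative input (not from the commit); the leading dimension must be even.
  torch::Tensor ins = torch::randn({2, 64, 8});

  // sizes() returns at::IntArrayRef, a non-owning read-only view of the shape;
  // vec() copies it into an owning std::vector<int64_t> that can be edited.
  std::vector<int64_t> ins_shape = ins.sizes().vec();

  ins_shape[0] /= 2;                      // halve the leading dimension
  if (ins_shape[0] == 1) {
    ins_shape.erase(ins_shape.begin());   // squeeze a leading dim of 1, as in the commit
  }

  auto outs = torch::zeros(ins_shape, ins.options());
  std::cout << outs.sizes() << std::endl; // prints [64, 8]
  return 0;
}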

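The second note swaps the computed launch configuration for a fixed one: 256 threads per block, with the grid size obtained by ceiling division, (numel + 255) / 256, so every element is covered by exactly one thread. A minimal sketch of that launch pattern with a hypothetical elementwise kernel (scale_kernel and its arguments are illustrative, not part of the commit):

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical elementwise kernel: one thread per element, guarded against
// the overshoot introduced by rounding the grid size up.
__global__ void scale_kernel(float* data, float factor, int64_t numel) {
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx < numel) {
    data[idx] *= factor;
  }
}

// Fixed 1D launch config mirroring the commit: 256-thread blocks and
// grid = ceil(numel / 256) via the (numel + 255) / 256 idiom.
void launch_scale(float* d_data, float factor, int64_t numel, cudaStream_t stream) {
  dim3 block(256);
  dim3 grid((numel + 255) / 256);
  scale_kernel<<<grid, block, 0, stream>>>(d_data, factor, numel);
}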