Commit b6e9785

Merge pull request #5457 from Courtesy-Xs/ly_add_implementation_for_launch_config
add implementation for GetGPULaunchConfig1D
2 parents f366a5e + 5724b9e commit b6e9785

File tree

7 files changed: +110 -100 lines changed

extensions/csrc/common/dev_info_mgr.h (-20)

This file was deleted.

extensions/csrc/common/target.h (+1 -1)

@@ -105,7 +105,7 @@ class Target {
   static Target DefaultAscendTarget();
 
   static Target DefaultCUDATarget() {
-    return Target(OS::Linux, Arch::CUDA, BitLen::k64);
+    return Target(OS::Linux, Arch::NVGPU, BitLen::k64);
   }
 
   friend std::ostream& operator<<(std::ostream& os, const Target& target);

extensions/csrc/cuda/activation_kernel.cu (+10 -2)

@@ -4,6 +4,7 @@
 
 #include "../common/micros.h"
 #include "../common/mp_type_traits.h"
+#include "utils/gpu_launch_config.h"
 
 template<typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
@@ -36,21 +37,28 @@ __global__ void act_and_mul_kernel(
 // silu(x[:half_1stdim]) * (x[half_1stdim:])
 torch::Tensor silu_and_mul(const torch::Tensor& ins)
 {
+  // Note(LiuYang): According to the torch docs, vec() may be costly, but I didn't
+  // find a better API for manipulating ins_shape, which is an IntArrayRef
   auto ins_shape = ins.sizes().vec();
 
   ins_shape[0] = ins_shape[0]/2;
   if (ins_shape[0] == 1) {
     ins_shape.erase(ins_shape.begin());
   }
   auto outs = torch::zeros(ins_shape,ins.options());
-  auto outs_shape = ins.sizes().vec();
 
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   // Note(LiuYang): numel of ins must be divisible by 2
   int64_t numel = ((torch::numel(ins)) >> 1);
 
-  // TODO(LiuYang): Maybe we need to implement a function to get launch config
+  // Note(LiuYang): For better performance in the special case where the input is
+  // [2, 64, 11008], this block is commented out, because computing a better config
+  // itself costs a little time
+  // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
+  // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
+  // dim3 grid = config.grid;
+  // dim3 block = config.block;
+
   dim3 grid((numel+255)/256);
   dim3 block(256);
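
For the special case named in the note, the arithmetic works out as in this short sketch (hypothetical [2, 64, 11008] input; the values are illustrative, not from the commit):

    // Worked example of the shape and launch math in silu_and_mul above.
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t ins_numel = 2LL * 64 * 11008;  // torch::numel(ins) = 1409024
      int64_t numel = ins_numel >> 1;        // elements actually written: 704512
      // Dim 0 is halved: [2, 64, 11008] -> [1, 64, 11008] -> squeezed to [64, 11008].
      int64_t grid_x = (numel + 255) / 256;  // ceil-division: 2752 blocks
      int64_t block_x = 256;                 // fixed block size kept by this commit
      std::printf("grid=%lld block=%lld\n", (long long)grid_x, (long long)block_x);
      return 0;
    }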

extensions/csrc/cuda/utils/gpu_launch_config.h (+59 -17)

@@ -3,32 +3,74 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 
+#include "nvgpu_dev_info.h"
+
 namespace colossalAI {
 namespace cuda {
 namespace utils {
 
-GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
+struct GPULaunchConfig {
+  dim3 block{1, 1, 1};
+  dim3 grid{1, 1, 1};
+};
+
+static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info,
+                                            int64_t numel, int64_t vec_size) {
+  const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock();
+  const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0];
+  const int64_t kMinimumSize = 64;
+  const int64_t kMaximumSize = 512;
+  int64_t active_threads = (numel + vec_size - 1) / vec_size;
+  int64_t sm_num = dev_info.GetMultiProcessorCount();
+
+  // Note(LiuYang): expected threads should generally be in [64, 128, 256, 512]
+  int64_t expected_threads_per_block = kMaximumSize;
 
-// TODO(LiuYang): to be implemented
-GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
+  auto RoundUpToPowerOfTwo = [](int64_t x) {
+    bool is_power_of_two = false;
+    int64_t ret = 1;
+    int64_t y = x;
+    while (y > 0) {
+      is_power_of_two = ((ret ^ x) == 0);
+      y = (y >> 1);
+      ret = (ret << 1);
+      if (y > 0) is_power_of_two = false;
+    }
+    if (is_power_of_two) return x;
+    return ret;
+  };
 
-// TODO(LiuYang): to be implemented
-GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
+  if ((active_threads / (sm_num << 1)) < max_threads_per_block) {
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(active_threads / (sm_num << 1));
+  } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) {
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(active_threads / (sm_num << 2));
+  }
 
-class GPULaunchConfig {
- public:
-  GPULaunchConfig(){};
-  GPULaunchConfig(const dim3& block, const dim3& grid)
-      : block_(block), grid_(grid) {}
-  friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
+  expected_threads_per_block =
+      std::max(expected_threads_per_block, kMinimumSize);
+  int64_t expect_block_per_grid =
+      ((active_threads + expected_threads_per_block - 1) /
+       expected_threads_per_block);
 
- protected:
-  void set_block(const dim3& dim) { block_ = dim; }
-  void set_grid(const dim3& dim) { grid_ = dim; }
+  if (expect_block_per_grid > max_blocks_per_grid) {
+    expect_block_per_grid = max_blocks_per_grid;
+    expected_threads_per_block =
+        (active_threads + expect_block_per_grid - 1) / expect_block_per_grid;
+    if (expected_threads_per_block > max_threads_per_block)
+      throw std::invalid_argument(
+          "Threads required for the current input exceed the current GPU's limit!");
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(expected_threads_per_block);
+    expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) /
+                             expected_threads_per_block);
+  }
 
- private:
-  dim3 block_(1, 1, 1);
-  dim3 grid_(1, 1, 1);
+  GPULaunchConfig config;
+  config.block.x = expected_threads_per_block;
+  config.grid.x = expect_block_per_grid;
+  return config;
 }
 
 } // namespace utils
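
Taken together, GetGPULaunchConfig1D sizes the block so the grid lands at roughly two (or four) blocks per SM, rounds the block size to a power of two no smaller than 64, and clamps both dimensions to device limits. A minimal caller sketch (device index, SM count, and element count are hypothetical, not from the commit):

    // Sketch: derive a 1D launch config on device 0 for 704512 elements,
    // one element per thread (vec_size = 1).
    #include "utils/gpu_launch_config.h"  // also pulls in nvgpu_dev_info.h

    void example_launch(int64_t numel) {
      using namespace colossalAI::cuda::utils;
      NVGPUDevInfo dev_info(0);  // wraps cudaGetDeviceProperties for device 0
      GPULaunchConfig cfg = GetGPULaunchConfig1D(dev_info, numel, /*vec_size=*/1);
      // E.g. on a 108-SM GPU with 1024 max threads/block and numel = 704512:
      // 704512/(108*2) = 3261 and 704512/(108*4) = 1630 both exceed 1024, so the
      // default of 512 threads/block stands and cfg.grid.x = 704512/512 = 1376.
      // my_kernel<<<cfg.grid, cfg.block, 0, stream>>>(...);
    }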

extensions/csrc/cuda/utils/micros.h (+8 -6)

@@ -3,10 +3,12 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 
-#define CUDA_CHECK(func)                                           \
-  {                                                                \
-    auto status = func;                                            \
-    if (status != cudaSuccess) {                                   \
-      LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
-    }                                                              \
+#include <exception>
+
+#define CUDA_CHECK(func)                                    \
+  {                                                         \
+    auto status = func;                                     \
+    if (status != cudaSuccess) {                            \
+      throw std::runtime_error(cudaGetErrorString(status)); \
+    }                                                       \
   }
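
Because CUDA_CHECK now throws instead of terminating through LOG(FATAL), callers can handle a failing CUDA call. A usage sketch (the wrapped call and function name are illustrative, not from the commit):

    #include <cstdio>
    #include <stdexcept>
    // #include "micros.h"  // provides the revised CUDA_CHECK

    void probe_device(int device_num) {
      try {
        CUDA_CHECK(cudaSetDevice(device_num));  // any call returning cudaError_t works
      } catch (const std::runtime_error& e) {
        // e.what() carries cudaGetErrorString() of the failing status
        std::fprintf(stderr, "CUDA call failed: %s\n", e.what());
      }
    }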

extensions/csrc/cuda/utils/nvgpu_dev_info.cc (-45)

This file was deleted.

extensions/csrc/cuda/utils/nvgpu_dev_info.h (+32 -9)

@@ -8,7 +8,6 @@
 #include <vector>
 
 #include "micros.h"
-#include "target.h"
 
 namespace colossalAI {
 namespace cuda {
@@ -17,19 +16,43 @@ namespace utils {
 class NVGPUDevInfo {
  public:
   explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
-    CUDA_CALL(cudaGetDeviceProperties(prop_, device));
+    CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
   }
 
-  std::array<int, 3> GetMaxGridDims() const;
-  std::array<int, 3> GetMaxBlockDims() const;
-  std::array<int, 2> GetCapability() const;
-  int GetMultiProcessorCount() const;
-  int GetMaxThreadsPerMultiProcessor() const;
-  int GetMaxThreadsPerBlock() const;
+  std::array<int, 3> GetMaxGridDims() const {
+    std::array<int, 3> ret;
+    ret[0] = prop_.maxGridSize[0];
+    ret[1] = prop_.maxGridSize[1];
+    ret[2] = prop_.maxGridSize[2];
+    return ret;
+  }
+
+  std::array<int, 3> GetMaxBlockDims() const {
+    std::array<int, 3> ret;
+    ret[0] = prop_.maxThreadsDim[0];
+    ret[1] = prop_.maxThreadsDim[1];
+    ret[2] = prop_.maxThreadsDim[2];
+    return ret;
+  }
+
+  std::array<int, 2> GetCapability() const {
+    std::array<int, 2> ret;
+    ret[0] = prop_.major;
+    ret[1] = prop_.minor;
+    return ret;
+  }
+
+  int GetMultiProcessorCount() const { return prop_.multiProcessorCount; }
+
+  int GetMaxThreadsPerMultiProcessor() const {
+    return prop_.maxThreadsPerMultiProcessor;
+  }
+
+  int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; }
 
  private:
   int device_num_;
-  cudaDeviceProp* prop_;
+  cudaDeviceProp prop_;
 };
 
 } // namespace utils
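
With the accessors inlined and cudaDeviceProp held by value, the separate nvgpu_dev_info.cc translation unit is no longer needed (hence its deletion above). A small sketch of the queries GetGPULaunchConfig1D relies on (device index and printing are illustrative, not from the commit):

    #include <cstdio>
    // #include "nvgpu_dev_info.h"

    void print_dev_info() {
      colossalAI::cuda::utils::NVGPUDevInfo info(0);  // fills prop_ in the constructor
      auto cap = info.GetCapability();    // {prop_.major, prop_.minor}
      auto grid = info.GetMaxGridDims();  // {maxGridSize[0], [1], [2]}
      std::printf("sm_%d%d: %d SMs, %d threads/block, grid.x up to %d\n",
                  cap[0], cap[1], info.GetMultiProcessorCount(),
                  info.GetMaxThreadsPerBlock(), grid[0]);
    }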
