Skip to content

Commit

Permalink
Reduce compile args to Im2Col kernel and add perf script (#2308)
Browse files Browse the repository at this point in the history
* Reduce compile args to Im2Col kernel and add perf script

* Reduce the scope of a variable to avoid a warning

* Silence warnings for more unused variables
  • Loading branch information
JehandadKhan authored and junliume committed Aug 13, 2023
1 parent a440524 commit a8b4c4f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 23 deletions.
41 changes: 26 additions & 15 deletions src/kernels/MIOpenIm2d2Col.cl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,20 @@ kernel void Im2d2Col(const int data_size_off,
const int stride_w,
const int dilation_h,
const int dilation_w,
global data_t* col)
global data_t* col,
const int num_ch_per_wg,
const int num_im_blks_x,
const int num_im_blks,
const int tile_sz_x,
const int tile_sz_y)
{
#define THREADS_PER_CH (256 / NUM_CH_PER_WG)
/// NUM_CH_PER_WG {1;4}
/// THREADS_PER_CH {256; 64}
(void)num_ch_per_wg;
(void)num_im_blks_x;
(void)num_im_blks;
(void)tile_sz_x;
(void)tile_sz_y;

#if USE_IM_OFF_GUARD
#define IM_OFF_GUARD(idx) (idx) < data_size_off ? im_off[(idx)] : 0
Expand All @@ -147,18 +156,20 @@ kernel void Im2d2Col(const int data_size_off,
/// c * NUM_IM_BLKS => c * out_w * out_h
index_t gid = get_group_id(0);

#if NUM_IM_BLKS == 1 && STRIDE_GT_1 == 0
#if NUM_IM_BLKS_EQ_1 == 1 && STRIDE_GT_1 == 0
// This does not need to be a division and should be a right shift
const int threads_per_ch = 256 / num_ch_per_wg;

// Load image into LDS
/// max (LOCAL_MEM_SIZE) = 65536
local data_t local_im[LOCAL_MEM_SIZE];

/// witem_ch [0;4)
int witem_ch = lid / THREADS_PER_CH;
int witem_ch = lid / threads_per_ch;

int im_lid = lid;
/// h*w < LOCAL_MEM_SIZE/witem_ch
int gid_stride = NUM_CH_PER_WG * h * w;
int gid_stride = num_ch_per_wg * h * w;
while(im_lid < gid_stride)
{
/// gid = max(1, (c_pack / NUM_CH_PER_WG)) => c
Expand All @@ -176,11 +187,11 @@ kernel void Im2d2Col(const int data_size_off,
/// if (NUM_IM_BLKS == 1) => (out_h < 8 && out_w < 32)
/// => out_hw_stride < 256
int out_hw_stride = out_h * out_w;
if(lid % THREADS_PER_CH < out_hw_stride)
if(lid % threads_per_ch < out_hw_stride)
{
/// lid[0, 255] % THREADS_PER_CH {256; 64} =>
/// max(inner_lid)=255; max(out_x)=max(out_y)=255
int inner_lid = lid % THREADS_PER_CH;
int inner_lid = lid % threads_per_ch;
int out_x = inner_lid % out_w;
int out_y = inner_lid / out_w;

Expand All @@ -191,7 +202,7 @@ kernel void Im2d2Col(const int data_size_off,
/// EXTREME_LARGE==0
/// => wei_h * wei_w * type_size * NUM_CH_PER_WG < max (LOCAL_MEM_SIZE)
/// gid * out_hw_stride * LOCAL_MEM_SIZE => c * 256 * 65536
index_t col_y = ((index_t)gid * NUM_CH_PER_WG + witem_ch) * out_hw_stride * wei_h * wei_w;
index_t col_y = ((index_t)gid * num_ch_per_wg + witem_ch) * out_hw_stride * wei_h * wei_w;

for(int y = 0; y < wei_h; y++)
{
Expand All @@ -216,17 +227,17 @@ kernel void Im2d2Col(const int data_size_off,

local data_t local_im[LOCAL_MEM_SIZE];

int wg_ch = gid / NUM_IM_BLKS;
int wg_ch = gid / num_im_blks;
/// TILE_SZ_X = 32, TILE_SZ_Y = 8;
/// gid = c * NUM_IM_BLKS => im_x = NUM_IM_BLKS*TILE_SZ_X = NUM_IM_BLKS*32
/// = NUM_IM_BLKS*32 = out_w * out_h / 8
int im_x = ((gid % NUM_IM_BLKS) % NUM_IM_BLKS_X) * TILE_SZ_X; /// < out_w
int im_y = ((gid % NUM_IM_BLKS) / NUM_IM_BLKS_X) * TILE_SZ_Y; /// < out_h
int im_x = ((gid % num_im_blks) % num_im_blks_x) * tile_sz_x; /// < out_w
int im_y = ((gid % num_im_blks) / num_im_blks_x) * tile_sz_y; /// < out_h

int out_cols_wg = (im_x + TILE_SZ_X) <= out_w ? TILE_SZ_X : (out_w - im_x); /// < out_w
int out_rows_wg = (im_y + TILE_SZ_Y) <= out_h ? TILE_SZ_Y : (out_h - im_y); /// < out_h
int out_cols_wg = (im_x + tile_sz_x) <= out_w ? tile_sz_x : (out_w - im_x); /// < out_w
int out_rows_wg = (im_y + tile_sz_y) <= out_h ? tile_sz_y : (out_h - im_y); /// < out_h

int im_cols_wg = (TILE_SZ_X - 1) * stride_w + (wei_w - 1) * dilation_w + 1;
int im_cols_wg = (tile_sz_x - 1) * stride_w + (wei_w - 1) * dilation_w + 1;

int inner_lid = lid;

Expand Down Expand Up @@ -265,7 +276,7 @@ kernel void Im2d2Col(const int data_size_off,

index_t col_x = (index_t)(im_y + out_y) * out_w + im_x + out_x; /// out_h * out_w
/// c * out_h * out_w * wei_h * wei_w
index_t col_y = (gid / NUM_IM_BLKS) * out_h * out_w * wei_h * wei_w;
index_t col_y = (gid / num_im_blks) * out_h * out_w * wei_h * wei_w;

for(int y = 0; y < wei_h; y++)
{
Expand Down
18 changes: 10 additions & 8 deletions src/ocl/utilocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,11 @@ float Im2d2ColGPU(const Handle& handle,
}
}

params += " -DNUM_CH_PER_WG=" + std::to_string(num_ch_per_wg);
params += " -DNUM_IM_BLKS_X=" + std::to_string(num_blks_x);
params += " -DNUM_IM_BLKS=" + std::to_string(num_blks);
params += " -DLOCAL_MEM_SIZE=" + std::to_string(local_mem_sz);
params += " -DLOCAL_MEM_SIZE=" +
std::to_string(local_mem_sz); // needs some changes to the kernel launch
params += " -DSTRIDE_GT_1=" + std::to_string(static_cast<int>(stride_h * stride_w > 1));
params += " -DTILE_SZ_X=" + std::to_string(tile_sz_x);
params += " -DTILE_SZ_Y=" + std::to_string(tile_sz_y);
params += " -DUSE_IM_OFF_GUARD=1";
params += " -DNUM_IM_BLKS_EQ_1=" + std::to_string(static_cast<int>(num_blks == 1));
params += " -DUSE_IM_OFF_GUARD=1"; // always one

params += GetDataTypeKernelParams(type);

Expand Down Expand Up @@ -272,7 +269,12 @@ float Im2d2ColGPU(const Handle& handle,
stride_w,
dilation_h,
dilation_w,
col);
col,
num_ch_per_wg,
num_blks_x,
num_blks,
tile_sz_x,
tile_sz_y);
}

return handle.GetKernelTime();
Expand Down
Loading

0 comments on commit a8b4c4f

Please sign in to comment.