[MetaSchedule] Enhance Conv2d NCHW Winograd Schedule Rules (#12127)
* Update winograd schedule rules.

* Remove extra part for setting local storage scope.

* Fix bgemm schedule.

* Add winograd tile size to annotation.

* Finish winograd schedule rules.

* Process add relu.

* Modify to nchw rules.

* Add missing nchw output rules.

* Add winograd conv2d nchw search space test.

* Fix lint.

* Leave consumer of output to autoinline.

* Remove bgemm rules.

* Remove bgemm schedule rule annotation.

* Update unit test.

* Fix test case.
zxybazh authored Aug 3, 2022
1 parent df29e82 commit 46a8498
Showing 3 changed files with 271 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/tvm/topi/cuda/conv2d_winograd.py
@@ -82,7 +82,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
(0, 0, pt, pl),
(0, 0, pb, pr),
name="data_pad",
attrs={"schedule_rule": "None"},
)

r = KW
@@ -118,7 +117,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
idxmod(idxdiv(p, nW), nH) * m + eps
][idxmod(p, nW) * m + nu],
name="d",
attrs={"schedule_rule": "None"},
)

# transform data
@@ -130,7 +128,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
),
name="data_pack",
attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"},
attrs={"schedule_rule": "meta_schedule.winograd_data_pack.nchw.cuda"},
)

# do batch gemm
@@ -152,7 +150,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
),
name="inverse",
attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"},
attrs={"schedule_rule": "meta_schedule.winograd_inverse.nchw.cuda"},
)

# output
@@ -163,6 +161,10 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
],
name="output",
tag="conv2d_nchw_winograd",
+ attrs={
+     "schedule_rule": "meta_schedule.winograd_output.nchw.cuda",
+     "winograd_tile_size": alpha - 3 + 1,
+ },
)

if isinstance(N, int):
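Note on the new annotation: winograd_tile_size stores the output tile size m of the Winograd transform. A minimal sketch of the arithmetic, assuming the standard F(m x m, r x r) relation alpha = m + r - 1 that this kernel instantiates with r = KW = 3 (the helper below is illustrative, not part of the change):

def winograd_tile_size(alpha, r=3):
    # alpha = m + r - 1, so m = alpha - r + 1; with r = 3 this is exactly
    # the "alpha - 3 + 1" written into the block annotation above.
    return alpha - r + 1

assert winograd_tile_size(6) == 4  # F(4x4, 3x3) uses a 6x6 transform
assert winograd_tile_size(4) == 2  # F(2x2, 3x3) uses a 4x4 transform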
101 changes: 101 additions & 0 deletions src/meta_schedule/schedule_rule/winograd.cc
@@ -36,6 +36,12 @@ inline BlockRV GetOnlyProducer(Schedule sch, BlockRV block) {
return producers[0];
}

inline BlockRV GetOnlyConsumer(Schedule sch, BlockRV block) {
Array<BlockRV> consumers = sch->GetConsumers(block);
ICHECK_EQ(consumers.size(), 1);
return consumers[0];
}

inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) {
Array<ExprRV> factors{nullptr};
Array<LoopRV> loops = sch->GetLoops(block);
@@ -74,6 +80,32 @@ inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) {
return t1[1];
}

inline LoopRV ScheduleDataPackNCHW(Schedule sch, BlockRV block) {
Array<LoopRV> loops = sch->GetLoops(block);
ICHECK_EQ(loops.size(), 6);

if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) {
if (*i <= 16) {
sch->Unroll(loops[0]);
}
}
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) {
if (*i <= 16) {
sch->Unroll(loops[1]);
}
}
sch->Unroll(loops[4]);
sch->Unroll(loops[5]);

Array<ExprRV> factors = sch->SamplePerfectTile(loops[3], /*n=*/2, /*max_innermost_factor=*/64);
Array<LoopRV> split =
sch->Split(loops[3], /*factors=*/{factors[0], factors[1]}, /*preserve_unit_loops=*/true);

LoopRV fused = sch->Fuse({loops[2], split[0]});
sch->Reorder({fused, split[1], loops[0], loops[1]});
return split[1];
}
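For readers who follow the Python side of meta schedule, the new ScheduleDataPackNCHW roughly corresponds to the tir.Schedule sketch below. It assumes sch is a tvm.tir.Schedule and that the data_pack block carries the six loops (eps, nu, ci, p, r_a, r_b) the ICHECK above requires; it is an approximation for illustration, not the rule itself:

import tvm

def schedule_data_pack_nchw(sch, block):
    # Loop order mirrors the data_pack compute: eps, nu, ci, p, r_a, r_b.
    eps, nu, ci, p, r_a, r_b = sch.get_loops(block)
    # Unroll the transform loops when their extent is a known constant <= 16.
    for loop in (eps, nu):
        extent = sch.get(loop).extent
        if isinstance(extent, tvm.tir.IntImm) and extent.value <= 16:
            sch.unroll(loop)
    # The reduction loops of the transform are unrolled unconditionally.
    sch.unroll(r_a)
    sch.unroll(r_b)
    # Tile the flattened tile-index loop p, fuse its outer half with ci,
    # and keep the transform loops innermost.
    f_outer, f_inner = sch.sample_perfect_tile(p, n=2, max_innermost_factor=64)
    p_outer, p_inner = sch.split(p, factors=[f_outer, f_inner])
    fused = sch.fuse(ci, p_outer)
    sch.reorder(fused, p_inner, eps, nu)
    return p_inner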

TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.llvm")
.set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> {
ScheduleDataPack(sch, block);
@@ -92,6 +124,37 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm")
return {sch};
});

TVM_REGISTER_GLOBAL("meta_schedule.winograd_output.nchw.cuda")
.set_body_typed([](Schedule sch, BlockRV output) -> Array<Schedule> {
// get loops
Array<LoopRV> loops = sch->GetLoops(output);
ICHECK_EQ(loops.size(), 4);

BlockRV OL{nullptr};

// tile
Optional<PrimExpr> tile_size =
tir::GetAnn<PrimExpr>(sch->GetSRef(output), "winograd_tile_size");
ICHECK(tile_size.defined()) << "Winograd tile size is not defined in block annotation!";
Array<LoopRV> split0 = sch->Split(loops[2], {NullOpt, tile_size.value()});
Array<LoopRV> split1 = sch->Split(loops[3], {NullOpt, tile_size.value()});
sch->Reorder({split0[0], split1[0], split0[1], split1[1]});

// compute_at
BlockRV inverse = GetOnlyProducer(sch, output);
sch->ComputeAt(inverse, /*loop_rv=*/split1[0],
/*preserve_unit_loops=*/true);

// fuse
LoopRV fused = sch->Fuse({loops[0], loops[1], split0[0], split1[0]});

int64_t max_threadblocks = 256;
int64_t max_threads_per_block = 1024;
auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
BindBlockThreadIdx(sch, output, max_threadblocks, max_threads_per_block, get_factor);
return {sch};
});
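The tile size set in conv2d_winograd.py reaches this rule through the block annotation. A hedged sketch of the same lookup and tiling in Python, assuming sch is a tvm.tir.Schedule and output is the annotated block with loops (n, co, h, w):

# The C++ rule reads the value with tir::GetAnn; in Python the annotation
# is visible on the block node returned by sch.get().
n, co, h, w = sch.get_loops(output)
annotations = sch.get(output).annotations
assert "winograd_tile_size" in annotations
tile_size = annotations["winograd_tile_size"]
# Split the spatial loops by the tile size and group the outer halves,
# mirroring the Split/Reorder sequence above.
h_outer, h_inner = sch.split(h, factors=[None, tile_size])
w_outer, w_inner = sch.split(w, factors=[None, tile_size])
sch.reorder(h_outer, w_outer, h_inner, w_inner)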

TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.cuda")
.set_body_typed([](Schedule sch, BlockRV block) -> Array<Schedule> {
ScheduleDataPack(sch, block);
@@ -102,6 +165,26 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.cuda")
return {sch};
});

TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.nchw.cuda")
.set_body_typed([](Schedule sch, BlockRV inverse) -> Array<Schedule> {
sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local");
Array<LoopRV> loops = sch->GetLoops(inverse);
ICHECK_EQ(loops.size(), 6);
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[2]))) {
if (*i <= 16) {
sch->Unroll(loops[2]);
}
}
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[3]))) {
if (*i <= 16) {
sch->Unroll(loops[3]);
}
}
sch->Unroll(loops[4]);
sch->Unroll(loops[5]);
return {sch};
});
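This rule only stages the inverse result in registers and unrolls the small transform loops; tiling and thread binding are left to the output rule, which computes this block at the output tile via the ComputeAt above. A rough Python equivalent under the same six-loop assumption (co, p, vh, vw, r_a, r_b):

import tvm

sch.set_scope(inverse, buffer_index=0, storage_scope="local")
co, p, vh, vw, r_a, r_b = sch.get_loops(inverse)
for loop in (vh, vw):
    # Unroll only when the extent is a known constant <= 16, as in the rule.
    extent = sch.get(loop).extent
    if isinstance(extent, tvm.tir.IntImm) and extent.value <= 16:
        sch.unroll(loop)
sch.unroll(r_a)
sch.unroll(r_b)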

TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda")
.set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
BlockRV input_tile = GetOnlyProducer(sch, data_pack);
@@ -117,5 +200,23 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda")
return {sch};
});

TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.nchw.cuda")
.set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
BlockRV input_tile = GetOnlyProducer(sch, data_pack);
BlockRV data_pad = GetOnlyProducer(sch, input_tile);

BlockRV data_l = sch->CacheWrite(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local");
LoopRV loop = ScheduleDataPackNCHW(sch, data_pack);
sch->ReverseComputeAt(data_l, loop, /*preserve_unit_loops=*/true);
sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true);
sch->ComputeInline(data_pad);

int64_t max_threadblocks = 256;
int64_t max_threads_per_block = 1024;
auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
BindBlockThreadIdx(sch, data_pack, max_threadblocks, max_threads_per_block, get_factor);
return {sch};
});
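These TVM_REGISTER_GLOBAL entries are resolved by name: when meta schedule meets a block whose schedule_rule annotation matches a registered packed function, it calls that function with (Schedule, BlockRV) and expects an array of schedules back. A sketch of registering an equivalent rule from Python under an illustrative name; it reuses the schedule_data_pack_nchw sketch from earlier and omits the BindBlockThreadIdx step the C++ rule performs:

import tvm
from tvm.tir.schedule import BlockRV, Schedule

@tvm.register_func("meta_schedule.winograd_data_pack.nchw.cuda.demo")
def _schedule_data_pack(sch: Schedule, data_pack: BlockRV):
    # Mirror the producer chain data_pad -> input_tile -> data_pack.
    input_tile = sch.get_producers(data_pack)[0]
    data_pad = sch.get_producers(input_tile)[0]
    # Stage the packed result in registers, schedule the block, then anchor
    # the cache write and the input tile at the loop returned by the sketch.
    data_l = sch.cache_write(data_pack, 0, "local")
    loop = schedule_data_pack_nchw(sch, data_pack)
    sch.reverse_compute_at(data_l, loop, preserve_unit_loops=True)
    sch.compute_at(input_tile, loop, preserve_unit_loops=True)
    sch.compute_inline(data_pad)
    return [sch]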

} // namespace meta_schedule
} // namespace tvm