forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompute_at.cc
622 lines (577 loc) · 27.2 KB
/
compute_at.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"
namespace tvm {
namespace tir {
using support::NDIntSet;
/******** Error Classes ********/
/*!
 * \brief Error raised when some of the blocks the primitive depends on (the producers or the
 * consumers of the block being moved) are not located under the target loop.
 * \tparam is_consumer True if the required blocks are consumers, false if they are producers
 */
template <bool is_consumer>
class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
 public:
  /*!
   * \param mod The IRModule the error is reported against
   * \param num_not_visited The number of required blocks that are not under the loop
   * \param required The srefs of every required block; each must refer to a Block
   */
  explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited,
                                               const Array<StmtSRef>& required)
      : mod_(mod), num_not_visited_(num_not_visited) {
    required_.reserve(required.size());
    for (const StmtSRef& block_sref : required) {
      // Macro arguments are stringified into the failure message; keep these names.
      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
      required_.push_back(GetRef<Block>(block));
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: Not all required blocks are under the loop scope";
  }

  String DetailRenderTemplate() const final {
    const char* const relation = is_consumer ? "consumer(s)" : "producer(s)";
    std::ostringstream detail;
    detail << "The primitive requires all the " << relation
           << " of the given block to be present under the target loop. However, there are "
           << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the "
           << relation << ":";
    // One placeholder per required block, resolved against LocationsOfInterest().
    const int num_required = required_.size();
    for (int idx = 0; idx < num_required; ++idx) {
      detail << "{" << idx << "}";
    }
    return detail.str();
  }

  IRModule mod() const final { return mod_; }

  Array<ObjectRef> LocationsOfInterest() const final {
    return Array<ObjectRef>(required_.begin(), required_.end());
  }

 private:
  /*! \brief The module the error is reported against */
  IRModule mod_;
  /*! \brief How many required blocks are missing under the loop */
  int num_not_visited_;
  /*! \brief All the required blocks, in the order they were given */
  Array<Block> required_;
};
/*!
* \brief An error raised when the given block is not in the same block scope as the given loop,
* or the given loop is the ancestor of the given block.
*/
class NotInSameScopeError : public ScheduleError {
public:
static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, const StmtSRef& scope_root_sref,
arith::Analyzer* analyzer) {
for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) {
if (const ForNode* loop = p->StmtAs<ForNode>()) {
analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
} else if (p != scope_root_sref.get()) {
throw NotInSameScopeError(self->mod, block_sref, loop_sref);
} else {
break;
}
}
for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) {
if (p == loop_sref.get()) {
throw NotInSameScopeError(self->mod, block_sref, loop_sref);
}
}
}
String FastErrorString() const final {
return "ScheduleError: Expected the block and loop to be under the same block scope, and loop "
"not to be the ancestor of block";
}
String DetailRenderTemplate() const final {
return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, "
"and loop not to be the ancestor of block";
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; }
private:
explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref)
: mod_(mod),
block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())),
loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {}
IRModule mod_;
Block block_;
For loop_;
};
/******** Helper Functions/Classes ********/
/*!
 * \brief Find the position among the subtrees of a loop where a block can be inserted
 * \tparam require_all_producers_visited Requires all producer blocks to be present in `subtrees`
 * \tparam require_all_consumers_visited Requires all consumer blocks to be present in `subtrees`
 * \param self The schedule state
 * \param subtrees The subtrees under the loop, among which the insertion point is sought
 * \param producer_srefs The producer blocks
 * \param consumer_srefs The consumer blocks
 * \param block2realize A cache that maps a block to its realize, filled while scanning
 * \return The last position the new block can be inserted at such that the
 * producer-consumer relationship is still satisfied (the position of the first consumer)
 * \throws ScheduleError if a required producer/consumer is missing from `subtrees`
 */
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
    const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
    const Array<StmtSRef>& consumer_srefs,
    std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
  ProducerConsumerSplit split =
      ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
  // Producers outside `subtrees` are only an error when all of them are required.
  if (require_all_producers_visited) {
    const int missing_producers =
        static_cast<int>(producer_srefs.size()) - split.n_producers_visited;
    if (missing_producers > 0) {
      throw NotAllRequiredBlocksAreVisitedError<false>(self->mod, missing_producers,
                                                       producer_srefs);
    }
  }
  // Same for consumers.
  if (require_all_consumers_visited) {
    const int missing_consumers =
        static_cast<int>(consumer_srefs.size()) - split.n_consumers_visited;
    if (missing_consumers > 0) {
      throw NotAllRequiredBlocksAreVisitedError<true>(self->mod, missing_consumers,
                                                      consumer_srefs);
    }
  }
  // Valid insertion indices form (last_producer_position, first_consumer_position]; it must be
  // non-empty, and its right end is returned.
  ICHECK(split.last_producer_position < split.first_consumer_position);
  return split.first_consumer_position;
}
/*!
* \brief A helper to reconstruct the block scope where the given block is moved under the given
* loop, and the given block's induced loop nest is regenerated to satisfy the required region.
*
* Usage: fill in `rm_src_stmt_`/`rm_tgt_stmt_` (the removal plan), call MakeNewLoop to build
* `new_loop_`, then apply the mutator to `scope_root_`. While mutating, `rm_src_stmt_` is swapped
* for `rm_tgt_stmt_` (which removes the block from its original place), and `loop_` is replaced by
* `new_loop_` (which holds the block at its new place).
*/
class ScopeReconstructor : private StmtMutator {
public:
explicit ScopeReconstructor(Block scope_root, Block block, For loop)
: scope_root_(scope_root), block_(block), loop_(loop) {}
using StmtMutator::operator();
/*!
* \brief Create the loop nest on top of the block, induced by the given block var's domain
* \param insert_position The position among the subtrees where the block and its induced loop
* nest is inserted
* \param iter_doms The domain of each block var
* \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
* \note One serial int32 loop `ax<k>` starting at 0 is created per kept block var, and the block
* var is bound to `dom.min + ax<k>`; block vars whose loop is dropped are bound to `dom.min`.
* \note `block_` is moved into `new_block_realize_`, so `block_` is empty after this call.
*/
void MakeNewLoop(int insert_position, std::vector<Range> iter_doms, bool preserve_unit_loops) {
int n_iters = iter_doms.size();
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Array<PrimExpr> iter_values;
loop_vars.reserve(n_iters);
loop_extents.reserve(n_iters);
iter_values.reserve(n_iters);
for (int i = 0; i < n_iters; ++i) {
const Range& iter_dom = iter_doms[i];
if (preserve_unit_loops || !is_one(iter_dom->extent)) {
// Loop vars are numbered by the loops actually created, not by block var index
Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32));
loop_vars.push_back(var);
loop_extents.push_back(iter_dom->extent);
iter_values.push_back(iter_dom->min + var);
} else {
// Unit extent and unit loops not preserved: bind the block var to the constant min
iter_values.push_back(iter_dom->min);
}
}
this->new_block_realize_ =
BlockRealize(std::move(iter_values), const_true(), std::move(block_));
Stmt new_subtree = this->new_block_realize_;
// Wrap the realize from the innermost loop outwards
for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
const Var& loop_var = loop_vars[i];
const PrimExpr& loop_extent = loop_extents[i];
new_subtree = For(/*loop_var=*/loop_var,
/*min=*/Integer(0),
/*extent=*/loop_extent,
/*ForKind=*/ForKind::kSerial,
/*body=*/std::move(new_subtree));
}
// NOTE(review): assumes AsArray yields the statements of `loop_->body` as a sequence -- confirm
Array<Stmt> subtrees = AsArray(loop_->body);
subtrees.insert(subtrees.begin() + insert_position, std::move(new_subtree));
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop_.get());
new_loop->body = SeqStmt(std::move(subtrees));
this->new_loop_ = For(std::move(new_loop));
}
private:
// Only the scope root is recursed into; any other block is returned unchanged. If the scope root
// is itself the removal source, it is swapped for the removal target before recursing.
Stmt VisitStmt_(const BlockNode* block) final {
if (block != scope_root_.get()) {
return GetRef<Block>(block);
}
if (block == rm_src_stmt_.get()) {
block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode);
}
return StmtMutator::VisitStmt_(block);
}
// Applies the removal plan first, then replaces the target loop with the rebuilt one.
Stmt VisitStmt_(const ForNode* loop) final {
if (loop == rm_src_stmt_.get()) {
loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode);
}
if (loop == loop_.get()) {
return new_loop_;
}
return StmtMutator::VisitStmt_(loop);
}
public:
/*! \brief The root block of the block scope */
Block scope_root_;
/*! \brief The given block to be moved (moved-from after MakeNewLoop) */
Block block_;
/*! \brief The given loop the block and its loop nest to be put under */
For loop_;
/*! \brief The new loop to replace the original loop */
For new_loop_{nullptr};
/*! \brief The new block realize to the moved block */
BlockRealize new_block_realize_{nullptr};
/*! \brief The loop/block in the AST that is replaced in order to remove the given block */
Stmt rm_src_stmt_{nullptr};
/*! \brief The loop/block that replaces `rm_src_stmt_` in the AST */
Stmt rm_tgt_stmt_{nullptr};
};
/*!
* \brief Calculate a list of accessed buffer regions under a path of loops
* \tparam relax_storage_scope Whether to relax beyond the path according to the storage and
* execution scope
* \param binding The block binding, used to unbind the buffer regions
* \param buffer_regions The buffer regions to be calculated
* \param relax_path_low_inclusive The lowest point in the loop path, inclusive
* \param relax_path_high_exclusive The highest point in the loop path, exclusive
* \param relaxed Where the calculation result is stored
*/
template <bool relax_storage_scope>
void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
const Array<BufferRegion>& buffer_regions,
const StmtSRef& relax_path_low_inclusive,
const StmtSRef& relax_path_high_exclusive,
std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* relaxed) {
runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""};
// We cache the variable domains
runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal;
Optional<Map<Var, arith::IntSet>> var_dom = NullOpt;
// Enumerate every buffer region
for (const BufferRegion& buffer_region : buffer_regions) {
const Buffer& buffer = buffer_region->buffer;
const Array<Range>& region = buffer_region->region;
// Skip the buffer regions we are not interested in
auto it = relaxed->find(buffer.get());
if (it == relaxed->end()) {
continue;
}
std::vector<NDIntSet>& relaxed_regions = it->second;
// Check and update the cached `var_dom`
runtime::StorageScope scope =
relax_storage_scope ? runtime::StorageScope::Create(buffer.scope()) : global_scope;
runtime::StorageRank rank = scope.rank;
if (rank != previous_rank || !var_dom.defined()) {
previous_rank = rank;
var_dom = AsIntSet(LoopDomainOfSRefTreePath(
/*low_inclusive=*/relax_path_low_inclusive,
/*high_exclusive=*/relax_path_high_exclusive,
/*extra_relax_scope=*/scope));
}
// Relax the region
Array<arith::IntSet> relaxed_region =
arith::EvalSet(Substitute(region, binding), var_dom.value());
relaxed_regions.push_back({relaxed_region.begin(), relaxed_region.end()});
}
}
/*!
 * \brief Calculate the iteration domain of a block var so that the integer sets it provides fully
 * cover the required domain, and record that domain in `iter_doms`
 * \param provided The integer set provided by one execution instance of the block. Its lower bound
 * must match `v * e` or `e * v`, or be a bare var `v` when the provided extent is 1
 * \param required The required domain to be covered
 * \param iter_doms The result iteration domains to be updated; a domain is appended only if the
 * matched var is already a key of this map
 * \param analyzer The arithmetic analyzer
 * \note If the matched var is not a key of `iter_doms`, the provided and required sets are instead
 * checked to be provably identical.
 */
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
                          std::unordered_map<const VarNode*, std::vector<arith::IntSet>>* iter_doms,
                          arith::Analyzer* analyzer) {
  PrimExpr provided_min = analyzer->Simplify(provided.min());
  PrimExpr provided_extent = analyzer->Simplify(provided.max() - provided_min + 1);
  PrimExpr required_min = analyzer->Simplify(required.min());
  PrimExpr required_extent = analyzer->Simplify(required.max() - required_min + 1);
  PrimExpr dom_min{nullptr}, dom_extent{nullptr};
  Var dom_var{ObjectPtr<VarNode>{nullptr}};
  arith::PVar<Var> p_v;
  arith::PVar<PrimExpr> p_e;
  if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
    // Instance `v` provides the chunk starting at `v * e`, so a required element `x` is provided by
    // `v = floordiv(x, e)`. Hence `v` must span [floordiv(req_min, e), floordiv(req_max, e)].
    // Using `ceil(required_extent / e)` alone as the extent misses the last chunk whenever
    // `required_min` is not a multiple of `e` (e.g. required [1, 5) with e = 4 needs v in {0, 1}).
    // NOTE(review): assumes each instance provides the `e` consecutive elements from `v * e`.
    PrimExpr e = p_e.Eval();
    dom_var = p_v.Eval();
    dom_min = floordiv(required_min, e);
    PrimExpr dom_max = floordiv(required_min + required_extent - 1, e);
    dom_extent = analyzer->Simplify(dom_max - dom_min + 1);
  } else if (analyzer->CanProveEqual(provided_extent, 1) && p_v.Match(provided_min)) {
    // Each instance provides exactly the element at `v`: `v` ranges over the required domain.
    dom_var = p_v.Eval();
    dom_min = required_min;
    dom_extent = required_extent;
  } else {
    ICHECK(false) << "ValueError: BufferRegion pattern match failed";
  }
  auto it = iter_doms->find(dom_var.get());
  if (it != iter_doms->end()) {
    std::vector<arith::IntSet>& doms = it->second;
    doms.push_back(arith::IntSet::FromMinExtent(dom_min, dom_extent));
  } else {
    // The matched var is not one of the block vars being solved for: nothing can be adjusted, so
    // the provided set must already equal the required one.
    ICHECK(analyzer->CanProveEqual(provided_min, required_min));
    ICHECK(analyzer->CanProveEqual(provided_extent, required_extent));
  }
}
/*!
* \brief Calculate the domain of block vars to cover the required region
* \param iter_vars The list of block vars to cover the required region
* \param provided_regions The region provided by one iteration instance of the block vars
* \param required_regions The region required to be covered
* \param analyzer The arithmetic analyzer
* \return A list of iteration domain corresponding to the given list of block vars
*/
std::vector<Range> CalculateBlockVarDomain(
const Array<IterVar>& iter_vars,
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions,
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions,
arith::Analyzer* analyzer) {
int n_iters = iter_vars.size();
// Step 1. Construct the mapping from block var to their iteration domain (initialized to empty)
std::unordered_map<const VarNode*, std::vector<arith::IntSet>> iter_doms;
iter_doms.reserve(n_iters);
for (const IterVar& iter_var : iter_vars) {
iter_doms[iter_var->var.get()] = {};
}
// Step 2. For each buffer, update the domain according to the provided and required regions
for (const auto& kv : provided_regions) {
const BufferNode* buffer = kv.first;
const std::vector<NDIntSet>& many_provided_regions = kv.second;
// Calculate `provided_region` and `required_region`
auto it = required_regions.find(buffer);
if (it == required_regions.end() || it->second.empty()) {
continue;
}
NDIntSet required_region = support::NDIntSetUnion(it->second);
NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
ICHECK_EQ(provided_region.size(), buffer->shape.size());
ICHECK_EQ(required_region.size(), buffer->shape.size());
// For each dimension, update the iteration domain
int ndim = buffer->shape.size();
for (int i = 0; i < ndim; ++i) {
arith::IntSet provided = provided_region[i];
arith::IntSet required = required_region[i];
required = arith::Intersect(
{std::move(required), arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i])});
UpdateBlockVarDomain(provided, required, &iter_doms, analyzer);
}
}
// Union the iter var domains, put them in the same order of block vars, and return
std::vector<Range> result;
result.reserve(n_iters);
for (const IterVar& iter_var : iter_vars) {
const std::vector<arith::IntSet>& doms = iter_doms.at(iter_var->var.get());
arith::IntSet dom = arith::IntSet::FromRange(iter_var->dom);
if (!doms.empty()) {
dom = arith::Intersect({std::move(dom), arith::Union(doms)});
}
PrimExpr min = analyzer->Simplify(dom.min());
PrimExpr extent = analyzer->Simplify(dom.max() - min + 1);
result.push_back(Range::FromMinExtent(min, extent));
}
return result;
}
/*!
* \brief Calculate the provided region of the given block by one single of its execution instance,
* as well as the required buffer regions relaxed to the given loop
* \tparam is_compute_at Indicates if the operation is compute-at or reverse-compute-at
* \param block The given block that provides buffer regions
* \param loop_sref The given loop under which the block is going to be moved to
* \param block2realize Maps a block to its corresponding BlockRealize
* \param producer_srefs The producers of the given block
* \param consumer_srefs The consumers of the given block
* \param provided_regions The calculated regions provided by the block
* \param required_regions The calculated regions required by its consumers (in compute-at) or
* producers (in reverse-compute-at)
*/
template <bool is_compute_at>
void CalculateProvidedRequiredRegions(
const BlockNode* block, const StmtSRef& loop_sref,
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize,
Array<StmtSRef> producer_srefs, Array<StmtSRef> consumer_srefs,
std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* provided_regions,
std::unordered_map<const BufferNode*, std::vector<NDIntSet>>* required_regions) {
// Step 1. Calculate the region provided by a single execution instance of `block`
const Array<BufferRegion>& provided_buffers = is_compute_at ? block->writes : block->reads;
provided_regions->reserve(provided_buffers.size());
required_regions->reserve(provided_buffers.size());
for (const BufferRegion& provided_buffer_region : provided_buffers) {
const BufferNode* buffer = provided_buffer_region->buffer.get();
const Array<Range>& region = provided_buffer_region->region;
(*provided_regions)[buffer].push_back(support::NDIntSetFromRegion(region));
(*required_regions)[buffer].clear();
}
// Step 2. Calculate the region required by dependent blocks under `loop`
for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) {
const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref);
ICHECK(block2realize.count(required_block));
RelaxBufferRegions</*relax_storage_scope=*/is_compute_at>(
/*binding=*/GetBindings(GetRef<BlockRealize>(block2realize.at(required_block))),
/*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes,
/*relax_path_low_inclusive=*/GetRef<StmtSRef>(required_block_sref->parent),
/*relax_path_high_exclusive=*/loop_sref, /*relaxed=*/required_regions);
}
}
/******** Main Implementation ********/
/*!
* \brief Shared implementation of compute-at (is_compute_at = true) and reverse-compute-at
* (is_compute_at = false): moves `block` under `loop`, regenerating a loop nest around it that
* covers exactly the region its dependent blocks under `loop` need.
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop under which the block is moved
* \param preserve_unit_loops Whether to keep loops of extent 1 in the regenerated loop nest
* \param analyzer The arithmetic analyzer; domains of `loop` and its ancestor loops within the
* scope are bound into it during the checks (also when `check_only` is true)
* \param check_only If true, run every check and build the new scope, but return before
* `self->Replace` and the block_info update
* \throws ScheduleError if any of the checks fails
*/
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
// Check condition 1) : scope stage pipeline
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/true);
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
// Check condition 2) : `block` is a complete or reduction block
CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref);
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
analyzer);
// Check condition 4): `block` is not an output block (compute-at only)
if (is_compute_at) {
CheckNotOutputBlock(self, block_sref, scope_root_sref);
}
// Step 2. Plan for the removal of `block`
ScopeReconstructor reconstructor(scope_root, GetRef<Block>(block), GetRef<For>(loop));
LeafBlockRemovalPlan(self, block_sref, &reconstructor.rm_src_stmt_, &reconstructor.rm_tgt_stmt_);
// Step 3. Find the insertion point under `loop`
// Check condition 5): all the required blocks are under the given loop
// (consumers for compute-at, producers for reverse-compute-at)
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
block2realize.reserve(self->block_info.size());
int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
/*self=*/self,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
// - In compute-at, `provide` means `produce`, and `require` means `consume`
// - In reverse-compute-at, `provide` means `consume`, and `require` means `produce`
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions;
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions;
CalculateProvidedRequiredRegions<is_compute_at>(
/*block=*/block, /*loop_sref=*/loop_sref, /*block2realize=*/std::move(block2realize),
/*producer_srefs=*/std::move(producer_srefs),
/*consumer_srefs=*/std::move(consumer_srefs),
/*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions);
// Step 5. Calculate the iteration domain for each block var
std::vector<Range> iter_doms =
CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
/*provided_regions=*/std::move(provided_regions),
/*required_regions=*/std::move(required_regions),
/*analyzer=*/analyzer);
// Step 6. Create the new scope according to the iteration domain
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
/*preserve_unit_loops=*/preserve_unit_loops);
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
// Step 7. Do the actual replacement, unless only a feasibility check was requested
if (check_only) {
return;
}
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
// Step 8. Update the cached flags: recompute whether the moved block's new binding is affine
// with respect to the loops now above it
BlockInfo& block_info = self->block_info[block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/reconstructor.new_block_realize_,
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
/*analyzer=*/analyzer);
}
/*!
 * \brief Move `block` under `loop` and regenerate the loops needed to produce what its consumers
 * under `loop` read
 */
void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
               bool preserve_unit_loops) {
  arith::Analyzer analyzer;  // fresh per call; loop domains get bound into it
  ComputeAtOrReverseComputeAtImpl</*is_compute_at=*/true>(
      /*self=*/self, /*block_sref=*/block_sref, /*loop_sref=*/loop_sref,
      /*preserve_unit_loops=*/preserve_unit_loops, /*analyzer=*/&analyzer);
}
/*!
 * \brief Move `block` under `loop` and regenerate the loops needed to consume what its producers
 * under `loop` write
 */
void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                      bool preserve_unit_loops) {
  arith::Analyzer analyzer;  // fresh per call; loop domains get bound into it
  ComputeAtOrReverseComputeAtImpl</*is_compute_at=*/false>(
      /*self=*/self, /*block_sref=*/block_sref, /*loop_sref=*/loop_sref,
      /*preserve_unit_loops=*/preserve_unit_loops, /*analyzer=*/&analyzer);
}
bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}
/******** InstructionKind Registration ********/
/*! \brief Instruction-kind traits of the `ComputeAt` schedule primitive */
struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
  static constexpr const char* kName = "ComputeAt";
  static constexpr bool kIsPure = false;

 private:
  // Inputs: (block, loop). Attributes: (preserve_unit_loops). No sampling decisions.
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops) {
    const bool keep_unit_loops = preserve_unit_loops.operator bool();
    sch->ComputeAt(block_rv, loop_rv, keep_unit_loops);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops) {
    const bool keep_unit_loops = preserve_unit_loops.operator bool();
    PythonAPICall py("compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", keep_unit_loops);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
/*! \brief Instruction-kind traits of the `ReverseComputeAt` schedule primitive */
struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits> {
  static constexpr const char* kName = "ReverseComputeAt";
  static constexpr bool kIsPure = false;

 private:
  // Inputs: (block, loop). Attributes: (preserve_unit_loops). No sampling decisions.
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
                                      Bool preserve_unit_loops) {
    const bool keep_unit_loops = preserve_unit_loops.operator bool();
    sch->ReverseComputeAt(block_rv, loop_rv, keep_unit_loops);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
                                 Bool preserve_unit_loops) {
    const bool keep_unit_loops = preserve_unit_loops.operator bool();
    PythonAPICall py("reverse_compute_at");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.Input("preserve_unit_loops", keep_unit_loops);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
// Register the instruction-kind traits of ComputeAt and ReverseComputeAt.
TVM_REGISTER_INST_KIND_TRAITS(ComputeAtTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeAtTraits);
} // namespace tir
} // namespace tvm