-
Notifications
You must be signed in to change notification settings - Fork 595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[VectorDistribution] Add layout analysis for distributing multi-dim reduction (2/4) #18800
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
…pport for the case when the reduction is inside scf.for operation Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
// Returns a boolean flag indicating whether the input value 'val' is a | ||
// vector, determined by checking its rank. | ||
bool isVector(VectorValue val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This strikes me as an odd helper: you give 'vector' a new meaning without introducing a name. Instead, I'd either flip it and add a helper like isRank0(VectorValue val)
, or just expand the check where you need it.
bool isVector(VectorValue val) { | ||
if (val.getType().getRank() != 0) | ||
return true; | ||
return false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool isVector(VectorValue val) { | |
if (val.getType().getRank() != 0) | |
return true; | |
return false; | |
} | |
bool isVector(VectorValue val) { | |
return val.getType().getRank() != 0; | |
} |
// Result lattice not has a layout yet. | ||
if (resultLattices.empty()) | ||
return; | ||
|
||
// We do not support multiple results yet. | ||
if (resultLattices.size() != 1) | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first check is redundant
return; | ||
|
||
for (RegionSuccessor successor : successors) { | ||
if (auto succ = successor.getSuccessor()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use the actual type here instead of auto
? It's not clear based on the RHS
assert(isa<VectorType>(inputType) && | ||
"Scalar broadcast not supported for now."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assertion is obsolete now.
if (isa<VectorType>(replacement.getType()) && | ||
cast<ShapedType>(replacement.getType()).getRank() != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use dyn_cast instead:
if (auto x = dyn_cast<Y>(y)) {
if (x.something() == Z) {
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); | ||
} else { | ||
Value accReducedVal = rewriter.create<vector::ExtractOp>( | ||
loc, accReduction, SmallVector<int64_t>{0}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need a vector here, you can do: ArrayRef{int64_t(0)}
if (accVector) { | ||
locallyReduced = dyn_cast<VectorValue>(localReduction.getResult()); | ||
} else { | ||
VectorType vecType = VectorType::get(SmallVector<int64_t>{1}, elemTy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here, no need to use a vector
if (accVector) { | ||
accElemTy = accVector.getType().getElementType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use getElementTypeOrSelf
. Also below.
bool isSrcVector = (srcVector) && (isVector(srcVector)); | ||
if (!isSrcVector) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only use of isSrcVector
, so it makes sense to inline it
bool isSrcVector = (srcVector) && (isVector(srcVector)); | |
if (!isSrcVector) { | |
if (!srcVector || !isVector(srcVector)) { |
Splitting #18519 into four patches.
Depends #18784
This is the second one, adding the corresponding layout analysis and especially supporting the case where reduction is performed inside scf.for operation.
Also, the relevant tests are added.