Skip to content

Commit

Permalink
[Bugfix] Simplify reduce expression in te.gradient (apache#6611)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu authored and Tushar Dey committed Oct 15, 2020
1 parent d241781 commit 63e95e3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1172,8 +1172,10 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A

new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
new_red = SimplifyReductionDomain(new_red, combined_vranges);
// Update original red pointer for later use.
red = new_red.as<ReduceNode>();
// If the reduction disappears completely then transform the result as a non-reduction
if (!new_red.as<ReduceNode>()) {
if (!red) {
return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
}

Expand Down

0 comments on commit 63e95e3

Please sign in to comment.