Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Fold NORM2() #66240

Merged
merged 1 commit into from
Sep 18, 2023
Merged

[flang] Fold NORM2() #66240

merged 1 commit into from
Sep 18, 2023

Conversation

klausler
Copy link
Contributor

Fold references to the (relatively new) intrinsic function NORM2 at compilation time when the argument(s) are all constants. (Getting this done right involved some changes to the API of the accumulator function objects used by the DoReduction<> template, which rippled through some other reduction function folding code.)

Fold references to the (relatively new) intrinsic function
NORM2 at compilation time when the argument(s) are all constants.
(Getting this done right involved some changes to the API of the
accumulator function objects used by the DoReduction<> template,
which rippled through some other reduction function folding code.)

Pull request: llvm#66240
@klausler klausler requested a review from jeanPerier September 13, 2023 16:54
@klausler klausler requested a review from a team as a code owner September 13, 2023 16:54
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:semantics labels Sep 13, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-flang-semantics

Changes Fold references to the (relatively new) intrinsic function NORM2 at compilation time when the argument(s) are all constants. (Getting this done right involved some changes to the API of the accumulator function objects used by the DoReduction<> template, which rippled through some other reduction function folding code.) --

Patch is 20.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66240.diff

5 Files Affected:

  • (modified) flang/lib/Evaluate/fold-integer.cpp (+23-13)
  • (modified) flang/lib/Evaluate/fold-logical.cpp (+1-4)
  • (modified) flang/lib/Evaluate/fold-real.cpp (+77-1)
  • (modified) flang/lib/Evaluate/fold-reduction.h (+118-50)
  • (added) flang/test/Evaluate/fold-norm2.f90 (+29)

<pre>
diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp
index fe38c81d976822d..dedfc20a491cd88 100644
--- a/flang/lib/Evaluate/fold-integer.cpp
+++ b/flang/lib/Evaluate/fold-integer.cpp
@@ -264,6 +264,26 @@ Expr&lt;Type&lt;TypeCategory::Integer, KIND&gt;&gt; UBOUND(FoldingContext &amp;context,
}

// COUNT()
+template &lt;typename T, int MASK_KIND&gt; class CountAccumulator {

  • using MaskT = Type&lt;TypeCategory::Logical, MASK_KIND&gt;;

+public:

  • CountAccumulator(const Constant&lt;MaskT&gt; &amp;mask) : mask_{mask} {}
  • void operator()(Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) {
  • if (mask_.At(at).IsTrue()) {
  •  auto incremented{element.AddSigned(Scalar&amp;lt;T&amp;gt;{1})};
    
  •  overflow_ |= incremented.overflow;
    
  •  element = incremented.value;
    
  • }
  • }
  • bool overflow() const { return overflow_; }
  • void Done(Scalar&lt;T&gt; &amp;) const {}

+private:

  • const Constant&lt;MaskT&gt; &amp;mask_;
  • bool overflow_{false};
    +};

template &lt;typename T, int maskKind&gt;
static Expr&lt;T&gt; FoldCount(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref) {
using LogicalResult = Type&lt;TypeCategory::Logical, maskKind&gt;;
@@ -274,17 +294,9 @@ static Expr&lt;T&gt; FoldCount(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref) {
: Folder&lt;LogicalResult&gt;{context}.Folding(arg[0])}) {
std::optional&lt;int&gt; dim;
if (CheckReductionDIM(dim, context, arg, 1, mask-&gt;Rank())) {

  •  bool overflow{false};
    
  •  auto accumulator{
    
  •      [&amp;amp;mask, &amp;amp;overflow](Scalar&amp;lt;T&amp;gt; &amp;amp;element, const ConstantSubscripts &amp;amp;at) {
    
  •        if (mask-&amp;gt;At(at).IsTrue()) {
    
  •          auto incremented{element.AddSigned(Scalar&amp;lt;T&amp;gt;{1})};
    
  •          overflow |= incremented.overflow;
    
  •          element = incremented.value;
    
  •        }
    
  •      }};
    
  •  CountAccumulator&amp;lt;T, maskKind&amp;gt; accumulator{*mask};
     Constant&amp;lt;T&amp;gt; result{DoReduction&amp;lt;T&amp;gt;(*mask, dim, Scalar&amp;lt;T&amp;gt;{}, accumulator)};
    
  •  if (overflow) {
    
  •  if (accumulator.overflow()) {
       context.messages().Say(
           &amp;quot;Result of intrinsic function COUNT overflows its result type&amp;quot;_warn_en_US);
     }
    

@@ -513,9 +525,7 @@ static Expr&lt;T&gt; FoldBitReduction(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref,
if (std::optional&lt;Constant&lt;T&gt;&gt; array{
ProcessReductionArgs&lt;T&gt;(context, ref.arguments(), dim, identity,
/ARRAY=/0, /DIM=/1, /MASK=/2)}) {

  • auto accumulator{[&amp;](Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) {
  •  element = (element.*operation)(array-&amp;gt;At(at));
    
  • }};
  • OperationAccumulator&lt;T&gt; accumulator{*array, operation};
    return Expr&lt;T&gt;{DoReduction&lt;T&gt;(*array, dim, identity, accumulator)};
    }
    return Expr&lt;T&gt;{std::move(ref)};
    diff --git a/flang/lib/Evaluate/fold-logical.cpp b/flang/lib/Evaluate/fold-logical.cpp
    index 95335f7f48bbedf..9fc42adf805f468 100644
    --- a/flang/lib/Evaluate/fold-logical.cpp
    +++ b/flang/lib/Evaluate/fold-logical.cpp
    @@ -28,14 +28,11 @@ static Expr&lt;T&gt; FoldAllAnyParity(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref,
    Scalar&lt;T&gt; (Scalar&lt;T&gt;::*operation)(const Scalar&lt;T&gt; &amp;) const,
    Scalar&lt;T&gt; identity) {
    static_assert(T::category == TypeCategory::Logical);
  • using Element = Scalar&lt;T&gt;;
    std::optional&lt;int&gt; dim;
    if (std::optional&lt;Constant&lt;T&gt;&gt; array{
    ProcessReductionArgs&lt;T&gt;(context, ref.arguments(), dim, identity,
    /ARRAY(MASK)=/0, /DIM=/1)}) {
  • auto accumulator{[&amp;](Element &amp;element, const ConstantSubscripts &amp;at) {
  •  element = (element.*operation)(array-&amp;gt;At(at));
    
  • }};
  • OperationAccumulator accumulator{*array, operation};
    return Expr&lt;T&gt;{DoReduction&lt;T&gt;(*array, dim, identity, accumulator)};
    }
    return Expr&lt;T&gt;{std::move(ref)};
    diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
    index 671d897ef7b2f82..8e3ab1d8fd30b09 100644
    --- a/flang/lib/Evaluate/fold-real.cpp
    +++ b/flang/lib/Evaluate/fold-real.cpp
    @@ -43,6 +43,80 @@ static Expr&lt;T&gt; FoldTransformationalBessel(
    return Expr&lt;T&gt;{std::move(funcRef)};
    }

+// NORM2
+template &lt;int KIND&gt; class Norm2Accumulator {

  • using T = Type&lt;TypeCategory::Real, KIND&gt;;

+public:

  • Norm2Accumulator(
  •  const Constant&amp;lt;T&amp;gt; &amp;amp;array, const Constant&amp;lt;T&amp;gt; &amp;amp;maxAbs, Rounding rounding)
    
  •  : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
    
  • void operator()(Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) {
  • // Kahan summation of scaled elements
  • auto scale{maxAbs_.At(maxAbsAt_)};
  • if (scale.IsZero()) {
  •  // If maxAbs is zero, so are all elements, and result
    
  •  element = scale;
    
  • } else {
  •  auto item{array_.At(at)};
    
  •  auto scaled{item.Divide(scale).value};
    
  •  auto square{item.Multiply(scaled).value};
    
  •  auto next{square.Add(correction_, rounding_)};
    
  •  overflow_ |= next.flags.test(RealFlag::Overflow);
    
  •  auto sum{element.Add(next.value, rounding_)};
    
  •  overflow_ |= sum.flags.test(RealFlag::Overflow);
    
  •  correction_ = sum.value.Subtract(element, rounding_)
    
  •                    .value.Subtract(next.value, rounding_)
    
  •                    .value;
    
  •  element = sum.value;
    
  • }
  • }
  • bool overflow() const { return overflow_; }
  • void Done(Scalar&lt;T&gt; &amp;result) {
  • auto corrected{result.Add(correction_, rounding_)};
  • overflow_ |= corrected.flags.test(RealFlag::Overflow);
  • correction_ = Scalar&lt;T&gt;{};
  • auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
  • maxAbs_.IncrementSubscripts(maxAbsAt_);
  • overflow_ |= rescaled.flags.test(RealFlag::Overflow);
  • result = rescaled.value.SQRT().value;
  • }

+private:

  • const Constant&lt;T&gt; &amp;array_;
  • const Constant&lt;T&gt; &amp;maxAbs_;
  • const Rounding rounding_;
  • bool overflow_{false};
  • Scalar&lt;T&gt; correction_{};
  • ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()};
    +};

+template &lt;int KIND&gt;
+static Expr&lt;Type&lt;TypeCategory::Real, KIND&gt;&gt; FoldNorm2(FoldingContext &amp;context,

  • FunctionRef&lt;Type&lt;TypeCategory::Real, KIND&gt;&gt; &amp;&amp;funcRef) {
  • using T = Type&lt;TypeCategory::Real, KIND&gt;;
  • using Element = typename Constant&lt;T&gt;::Element;
  • std::optional&lt;int&gt; dim;
  • const Element identity{};
  • if (std::optional&lt;Constant&lt;T&gt;&gt; array{
  •      ProcessReductionArgs&amp;lt;T&amp;gt;(context, funcRef.arguments(), dim, identity,
    
  •          /*X=*/0, /*DIM=*/1)}) {
    
  • MaxvalMinvalAccumulator&lt;T, /ABS=/true&gt; maxAbsAccumulator{
  •    RelationalOperator::GT, context, *array};
    
  • Constant&lt;T&gt; maxAbs{
  •    DoReduction&amp;lt;T&amp;gt;(*array, dim, identity, maxAbsAccumulator)};
    
  • Norm2Accumulator norm2Accumulator{
  •    *array, maxAbs, context.targetCharacteristics().roundingMode()};
    
  • Constant&lt;T&gt; result{DoReduction&lt;T&gt;(*array, dim, identity, norm2Accumulator)};
  • if (norm2Accumulator.overflow()) {
  •  context.messages().Say(
    
  •      &amp;quot;NORM2() of REAL(%d) data overflowed&amp;quot;_warn_en_US, KIND);
    
  • }
  • return Expr&lt;T&gt;{std::move(result)};
  • }
  • return Expr&lt;T&gt;{std::move(funcRef)};
    +}

template &lt;int KIND&gt;
Expr&lt;Type&lt;TypeCategory::Real, KIND&gt;&gt; FoldIntrinsicFunction(
FoldingContext &amp;context,
@@ -238,6 +312,8 @@ Expr&lt;Type&lt;TypeCategory::Real, KIND&gt;&gt; FoldIntrinsicFunction(
},
sExpr-&gt;u);
}

  • } else if (name == &quot;norm2&quot;) {
  • return FoldNorm2&lt;T::kind&gt;(context, std::move(funcRef));
    } else if (name == &quot;product&quot;) {
    auto one{Scalar&lt;T&gt;::FromInteger(value::Integer&lt;8&gt;{1}).value};
    return FoldProduct&lt;T&gt;(context, std::move(funcRef), one);
    @@ -354,7 +430,7 @@ Expr&lt;Type&lt;TypeCategory::Real, KIND&gt;&gt; FoldIntrinsicFunction(
    return result.value;
    }));
    }
  • // TODO: dot_product, matmul, norm2
  • // TODO: matmul
    return Expr&lt;T&gt;{std::move(funcRef)};
    }

diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index b76cecffaf1c639..cff7f54c60d91ba 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//

-// TODO: NORM2, PARITY

#ifndef FORTRAN_EVALUATE_FOLD_REDUCTION_H_
#define FORTRAN_EVALUATE_FOLD_REDUCTION_H_

@@ -77,7 +75,8 @@ static Expr&lt;T&gt; FoldDotProduct(
overflow |= next.overflow;
sum = std::move(next.value);
}

  • } else { // T::category == TypeCategory::Real
  • } else {
  •  static_assert(T::category == TypeCategory::Real);
     Expr&amp;lt;T&amp;gt; products{
         Fold(context, Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{*va}} * Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{*vb}})};
     Constant&amp;lt;T&amp;gt; &amp;amp;cProducts{DEREF(UnwrapConstantValue&amp;lt;T&amp;gt;(products))};
    

@@ -172,7 +171,8 @@ static std::optional&lt;Constant&lt;T&gt;&gt; ProcessReductionArgs(FoldingContext &amp;context,
}

// Generalized reduction to an array of one dimension fewer (w/ DIM=)
-// or to a scalar (w/o DIM=).
+// or to a scalar (w/o DIM=). The ACCUMULATOR type must define
+// operator()(Scalar&lt;T&gt; &amp;, const ConstantSubscripts &amp;) and Done(Scalar&lt;T&gt; &amp;).
template &lt;typename T, typename ACCUMULATOR, typename ARRAY&gt;
static Constant&lt;T&gt; DoReduction(const Constant&lt;ARRAY&gt; &amp;array,
std::optional&lt;int&gt; &amp;dim, const Scalar&lt;T&gt; &amp;identity,
@@ -193,6 +193,7 @@ static Constant&lt;T&gt; DoReduction(const Constant&lt;ARRAY&gt; &amp;array,
for (ConstantSubscript j{0}; j &lt; dimExtent; ++j, ++dimAt) {
accumulator(elements.back(), at);
}

  •  accumulator.Done(elements.back());
    
    }
    } else { // no DIM=, result is scalar
    elements.push_back(identity);
    @@ -200,6 +201,7 @@ static Constant&lt;T&gt; DoReduction(const Constant&lt;ARRAY&gt; &amp;array,
    IncrementSubscripts(at, array.shape())) {
    accumulator(elements.back(), at);
    }
  • accumulator.Done(elements.back());
    }
    if constexpr (T::category == TypeCategory::Character) {
    return {static_cast&lt;ConstantSubscript&gt;(identity.size()),
    @@ -210,58 +212,85 @@ static Constant&lt;T&gt; DoReduction(const Constant&lt;ARRAY&gt; &amp;array,
    }

// MAXVAL &amp; MINVAL
+template &lt;typename T, bool ABS = false&gt; class MaxvalMinvalAccumulator {
+public:

  • MaxvalMinvalAccumulator(
  •  RelationalOperator opr, FoldingContext &amp;amp;context, const Constant&amp;lt;T&amp;gt; &amp;amp;array)
    
  •  : opr_{opr}, context_{context}, array_{array} {};
    
  • void operator()(Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) const {
  • auto aAt{array_.At(at)};
  • if constexpr (ABS) {
  •  aAt = aAt.ABS();
    
  • }
  • Expr&lt;LogicalResult&gt; test{PackageRelation(
  •    opr_, Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{aAt}}, Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{element}})};
    
  • auto folded{GetScalarConstantValue&lt;LogicalResult&gt;(
  •    test.Rewrite(context_, std::move(test)))};
    
  • CHECK(folded.has_value());
  • if (folded-&gt;IsTrue()) {
  •  element = array_.At(at);
    
  • }
  • }
  • void Done(Scalar&lt;T&gt; &amp;) const {}

+private:

  • RelationalOperator opr_;
  • FoldingContext &amp;context_;
  • const Constant&lt;T&gt; &amp;array_;
    +};

template &lt;typename T&gt;
static Expr&lt;T&gt; FoldMaxvalMinval(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref,
RelationalOperator opr, const Scalar&lt;T&gt; &amp;identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Character);

  • using Element = Scalar&lt;T&gt;;
    std::optional&lt;int&gt; dim;
    if (std::optional&lt;Constant&lt;T&gt;&gt; array{
    ProcessReductionArgs&lt;T&gt;(context, ref.arguments(), dim, identity,
    /ARRAY=/0, /DIM=/1, /MASK=/2)}) {
  • auto accumulator{[&amp;](Element &amp;element, const ConstantSubscripts &amp;at) {
  •  Expr&amp;lt;LogicalResult&amp;gt; test{PackageRelation(opr,
    
  •      Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{array-&amp;gt;At(at)}}, Expr&amp;lt;T&amp;gt;{Constant&amp;lt;T&amp;gt;{element}})};
    
  •  auto folded{GetScalarConstantValue&amp;lt;LogicalResult&amp;gt;(
    
  •      test.Rewrite(context, std::move(test)))};
    
  •  CHECK(folded.has_value());
    
  •  if (folded-&amp;gt;IsTrue()) {
    
  •    element = array-&amp;gt;At(at);
    
  •  }
    
  • }};
  • MaxvalMinvalAccumulator accumulator{opr, context, *array};
    return Expr&lt;T&gt;{DoReduction&lt;T&gt;(*array, dim, identity, accumulator)};
    }
    return Expr&lt;T&gt;{std::move(ref)};
    }

// PRODUCT
+template &lt;typename T&gt; class ProductAccumulator {
+public:

  • ProductAccumulator(const Constant&lt;T&gt; &amp;array) : array_{array} {}
  • void operator()(Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) {
  • if constexpr (T::category == TypeCategory::Integer) {
  •  auto prod{element.MultiplySigned(array_.At(at))};
    
  •  overflow_ |= prod.SignedMultiplicationOverflowed();
    
  •  element = prod.lower;
    
  • } else { // Real &amp; Complex
  •  auto prod{element.Multiply(array_.At(at))};
    
  •  overflow_ |= prod.flags.test(RealFlag::Overflow);
    
  •  element = prod.value;
    
  • }
  • }
  • bool overflow() const { return overflow_; }
  • void Done(Scalar&lt;T&gt; &amp;) const {}

+private:

  • const Constant&lt;T&gt; &amp;array_;
  • bool overflow_{false};
    +};

template &lt;typename T&gt;
static Expr&lt;T&gt; FoldProduct(
FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref, Scalar&lt;T&gt; identity) {
static_assert(T::category == TypeCategory::Integer ||
T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex);

  • using Element = typename Constant&lt;T&gt;::Element;
    std::optional&lt;int&gt; dim;
    if (std::optional&lt;Constant&lt;T&gt;&gt; array{
    ProcessReductionArgs&lt;T&gt;(context, ref.arguments(), dim, identity,
    /ARRAY=/0, /DIM=/1, /MASK=/2)}) {
  • bool overflow{false};
  • auto accumulator{[&amp;](Element &amp;element, const ConstantSubscripts &amp;at) {
  •  if constexpr (T::category == TypeCategory::Integer) {
    
  •    auto prod{element.MultiplySigned(array-&amp;gt;At(at))};
    
  •    overflow |= prod.SignedMultiplicationOverflowed();
    
  •    element = prod.lower;
    
  •  } else { // Real &amp;amp; Complex
    
  •    auto prod{element.Multiply(array-&amp;gt;At(at))};
    
  •    overflow |= prod.flags.test(RealFlag::Overflow);
    
  •    element = prod.value;
    
  •  }
    
  • }};
  • ProductAccumulator accumulator{*array};
    auto result{Expr&lt;T&gt;{DoReduction&lt;T&gt;(*array, dim, identity, accumulator)}};
  • if (overflow) {
  • if (accumulator.overflow()) {
    context.messages().Say(
    &quot;PRODUCT() of %s data overflowed&quot;_warn_en_US, T::AsFortran());
    }
    @@ -271,6 +300,46 @@ static Expr&lt;T&gt; FoldProduct(
    }

// SUM
+template &lt;typename T&gt; class SumAccumulator {

  • using Element = typename Constant&lt;T&gt;::Element;

+public:

  • SumAccumulator(const Constant&lt;T&gt; &amp;array, Rounding rounding)
  •  : array_{array}, rounding_{rounding} {}
    
  • void operator()(Element &amp;element, const ConstantSubscripts &amp;at) {
  • if constexpr (T::category == TypeCategory::Integer) {
  •  auto sum{element.AddSigned(array_.At(at))};
    
  •  overflow_ |= sum.overflow;
    
  •  element = sum.value;
    
  • } else { // Real &amp; Complex: use Kahan summation
  •  auto next{array_.At(at).Add(correction_, rounding_)};
    
  •  overflow_ |= next.flags.test(RealFlag::Overflow);
    
  •  auto sum{element.Add(next.value, rounding_)};
    
  •  overflow_ |= sum.flags.test(RealFlag::Overflow);
    
  •  // correction = (sum - element) - next; algebraically zero
    
  •  correction_ = sum.value.Subtract(element, rounding_)
    
  •                    .value.Subtract(next.value, rounding_)
    
  •                    .value;
    
  •  element = sum.value;
    
  • }
  • }
  • bool overflow() const { return overflow_; }
  • void Done([[maybe_unused]] Element &amp;element) {
  • if constexpr (T::category != TypeCategory::Integer) {
  •  auto corrected{element.Add(correction_, rounding_)};
    
  •  overflow_ |= corrected.flags.test(RealFlag::Overflow);
    
  •  correction_ = Scalar&amp;lt;T&amp;gt;{};
    
  •  element = corrected.value;
    
  • }
  • }

+private:

  • const Constant&lt;T&gt; &amp;array_;
  • Rounding rounding_;
  • bool overflow_{false};
  • Element correction_{};
    +};

template &lt;typename T&gt;
static Expr&lt;T&gt; FoldSum(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref) {
static_assert(T::category == TypeCategory::Integer ||
@@ -278,31 +347,14 @@ static Expr&lt;T&gt; FoldSum(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref) {
T::category == TypeCategory::Complex);
using Element = typename Constant&lt;T&gt;::Element;
std::optional&lt;int&gt; dim;

  • Element identity{}, correction{};
  • Element identity{};
    if (std::optional&lt;Constant&lt;T&gt;&gt; array{
    ProcessReductionArgs&lt;T&gt;(context, ref.arguments(), dim, identity,
    /ARRAY=/0, /DIM=/1, /MASK=/2)}) {
  • bool overflow{false};
  • auto accumulator{[&amp;](Element &amp;element, const ConstantSubscripts &amp;at) {
  •  if constexpr (T::category == TypeCategory::Integer) {
    
  •    auto sum{element.AddSigned(array-&amp;gt;At(at))};
    
  •    overflow |= sum.overflow;
    
  •    element = sum.value;
    
  •  } else { // Real &amp;amp; Complex: use Kahan summation
    
  •    const auto &amp;amp;rounding{context.targetCharacteristics().roundingMode()};
    
  •    auto next{array-&amp;gt;At(at).Add(correction, rounding)};
    
  •    overflow |= next.flags.test(RealFlag::Overflow);
    
  •    auto sum{element.Add(next.value, rounding)};
    
  •    overflow |= sum.flags.test(RealFlag::Overflow);
    
  •    // correction = (sum - element) - next; algebraically zero
    
  •    correction = sum.value.Subtract(element, rounding)
    
  •                     .value.Subtract(next.value, rounding)
    
  •                     .value;
    
  •    element = sum.value;
    
  •  }
    
  • }};
  • SumAccumulator accumulator{
  •    *array, context.targetCharacteristics().roundingMode()};
    
    auto result{Expr&lt;T&gt;{DoReduction&lt;T&gt;(*array, dim, identity, accumulator)}};
  • if (overflow) {
  • if (accumulator.overflow()) {
    context.messages().Say(
    &quot;SUM() of %s data overflowed&quot;_warn_en_US, T::AsFortran());
    }
    @@ -311,5 +363,21 @@ static Expr&lt;T&gt; FoldSum(FoldingContext &amp;context, FunctionRef&lt;T&gt; &amp;&amp;ref) {
    return Expr&lt;T&gt;{std::move(ref)};
    }

+// Utility for IALL, IANY, IPARITY, ALL, ANY, &amp; PARITY
+template &lt;typename T&gt; class OperationAccumulator {
+public:

  • OperationAccumulator(const Constant&lt;T&gt; &amp;array,
  •  Scalar&amp;lt;T&amp;gt; (Scalar&amp;lt;T&amp;gt;::*operation)(const Scalar&amp;lt;T&amp;gt; &amp;amp;) const)
    
  •  : array_{array}, operation_{operation} {}
    
  • void operator()(Scalar&lt;T&gt; &amp;element, const ConstantSubscripts &amp;at) {
  • element = (element.*operation_)(array_.At(at));
  • }
  • void Done(Scalar&lt;T&gt; &amp;) const {}

+private:

  • const Constant&lt;T&gt; &amp;array_;
  • Scalar&lt;T&gt; (Scalar&lt;T&gt;::*operation_)(const Scalar&lt;T&gt; &amp;) const;
    +};

} // namespace Fortran::evaluate
#endif // FORTRAN_EVALUATE_FOLD_REDUCTION_H_
diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90
new file mode 100644
index 000000000000000..30d5289b5a6e33c
--- /dev/null
+++ b/flang/test/Evaluate/fold-norm2.f90
@@ -0,0 +1,29 @@
+! RUN: %python %S/test_folding.py %s %flang_fc1
+! Tests folding of NORM2(), F&#x27;2023 16.9.153
+module m

  • ! Examples from the standard
  • logical, parameter :: test_ex1 = norm2([3.,4.]) == 5.
  • real, parameter :: ex2(2,2) = resha...

@klausler klausler merged commit 39f1860 into llvm:main Sep 18, 2023
@klausler klausler deleted the bug1353 branch September 18, 2023 15:58
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
Fold references to the (relatively new) intrinsic function NORM2 at
compilation time when the argument(s) are all constants. (Getting this
done right involved some changes to the API of the accumulator function
objects used by the DoReduction<> template, which rippled through some
other reduction function folding code.)
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
Fold references to the (relatively new) intrinsic function NORM2 at
compilation time when the argument(s) are all constants. (Getting this
done right involved some changes to the API of the accumulator function
objects used by the DoReduction<> template, which rippled through some
other reduction function folding code.)
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
Fold references to the (relatively new) intrinsic function NORM2 at
compilation time when the argument(s) are all constants. (Getting this
done right involved some changes to the API of the accumulator function
objects used by the DoReduction<> template, which rippled through some
other reduction function folding code.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants