Skip to content

Commit

Permalink
[libc][math][c23] Add f16sqrtf C23 math function (llvm#95251)
Browse files Browse the repository at this point in the history
Part of llvm#95250.
  • Loading branch information
overmighty authored and EthanLuisMcDonough committed Aug 13, 2024
1 parent df3727d commit 2243a4c
Show file tree
Hide file tree
Showing 29 changed files with 352 additions and 123 deletions.
1 change: 1 addition & 0 deletions libc/config/linux/aarch64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/x86_64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
libc.src.math.canonicalizef16
libc.src.math.ceilf16
libc.src.math.copysignf16
libc.src.math.f16sqrtf
libc.src.math.fabsf16
libc.src.math.fdimf16
libc.src.math.floorf16
Expand Down
2 changes: 2 additions & 0 deletions libc/docs/math/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ Higher Math Functions
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fma | |check| | |check| | | | | 7.12.13.1 | F.10.10.1 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| f16sqrt | |check| | | | N/A | | 7.12.14.6 | F.10.11 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| fsqrt | N/A | | | N/A | | 7.12.14.6 | F.10.11 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| hypot | |check| | |check| | | | | 7.12.7.4 | F.10.4.4 |
Expand Down
2 changes: 2 additions & 0 deletions libc/spec/stdc.td
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,

GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
]
>;

Expand Down
1 change: 1 addition & 0 deletions libc/src/__support/FPUtil/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_header_library(
sqrt.h
sqrt_80_bit_long_double.h
DEPENDS
libc.hdr.fenv_macros
libc.src.__support.common
libc.src.__support.CPP.bit
libc.src.__support.CPP.type_traits
Expand Down
128 changes: 99 additions & 29 deletions libc/src/__support/FPUtil/generic/sqrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "src/__support/common.h"
#include "src/__support/uint128.h"

#include "hdr/fenv_macros.h"

namespace LIBC_NAMESPACE {
namespace fputil {

Expand Down Expand Up @@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {

// Correctly rounded IEEE 754 SQRT for all rounding modes.
// Shift-and-add algorithm.
template <typename T>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {

if constexpr (internal::SpecialLongDouble<T>::VALUE) {
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
cpp::is_floating_point_v<InType> &&
sizeof(OutType) <= sizeof(InType),
OutType>
sqrt(InType x) {
if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
internal::SpecialLongDouble<InType>::VALUE) {
// Special 80-bit long double.
return x86::sqrt(x);
} else {
// IEEE floating points formats.
using FPBits_t = typename fputil::FPBits<T>;
using StorageType = typename FPBits_t::StorageType;
constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();

FPBits_t bits(x);

if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
using OutFPBits = typename fputil::FPBits<OutType>;
using OutStorageType = typename OutFPBits::StorageType;
using InFPBits = typename fputil::FPBits<InType>;
using InStorageType = typename InFPBits::StorageType;
constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
constexpr int EXTRA_FRACTION_LEN =
InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
constexpr InStorageType EXTRA_FRACTION_MASK =
(InStorageType(1) << EXTRA_FRACTION_LEN) - 1;

InFPBits bits(x);

if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
// sqrt(+Inf) = +Inf
// sqrt(+0) = +0
// sqrt(-0) = -0
// sqrt(NaN) = NaN
// sqrt(-NaN) = -NaN
return x;
return static_cast<OutType>(x);
} else if (bits.is_neg()) {
// sqrt(-Inf) = NaN
// sqrt(-x) = NaN
return FLT_NAN;
} else {
int x_exp = bits.get_exponent();
StorageType x_mant = bits.get_mantissa();
InStorageType x_mant = bits.get_mantissa();

// Step 1a: Normalize denormal input and append hidden bit to the mantissa
if (bits.is_subnormal()) {
++x_exp; // let x_exp be the correct exponent of ONE bit.
internal::normalize<T>(x_exp, x_mant);
internal::normalize<InType>(x_exp, x_mant);
} else {
x_mant |= ONE;
}
Expand All @@ -120,47 +132,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
// 0 otherwise.
StorageType y = ONE;
StorageType r = x_mant - ONE;
InStorageType y = ONE;
InStorageType r = x_mant - ONE;

for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
for (InStorageType current_bit = ONE >> 1; current_bit;
current_bit >>= 1) {
r <<= 1;
StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
if (r >= tmp) {
r -= tmp;
y += current_bit;
}
}

// We compute one more iteration in order to round correctly.
bool lsb = static_cast<bool>(y & 1); // Least significant bit
bool rb = false; // Round bit
bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
0; // Least significant bit
bool rb = false; // Round bit
r <<= 2;
StorageType tmp = (y << 2) + 1;
InStorageType tmp = (y << 2) + 1;
if (r >= tmp) {
r -= tmp;
rb = true;
}

bool sticky = false;

if constexpr (EXTRA_FRACTION_LEN > 0) {
sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
}

// Remove hidden bit and append the exponent field.
x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);

OutStorageType y_out = static_cast<OutStorageType>(
((y - ONE) >> EXTRA_FRACTION_LEN) |
(static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));

if constexpr (EXTRA_FRACTION_LEN > 0) {
if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
switch (quick_get_round()) {
case FE_TONEAREST:
case FE_UPWARD:
return OutFPBits::inf().get_val();
default:
return OutFPBits::max_normal().get_val();
}
}

if (x_exp <
-OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
switch (quick_get_round()) {
case FE_UPWARD:
return OutFPBits::min_subnormal().get_val();
default:
return OutType(0.0);
}
}

y = (y - ONE) |
(static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
if (x_exp <= 0) {
int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
InStorageType underflow_extra_fraction_mask =
(InStorageType(1) << underflow_extra_fraction_len) - 1;

rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
0;
OutStorageType subnormal_mant =
static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
lsb = (subnormal_mant & 1) != 0;
sticky = sticky || (y & underflow_extra_fraction_mask) != 0;

switch (quick_get_round()) {
case FE_TONEAREST:
if (rb && (lsb || sticky))
++subnormal_mant;
break;
case FE_UPWARD:
if (rb || sticky)
++subnormal_mant;
break;
}

return cpp::bit_cast<OutType>(subnormal_mant);
}
}

switch (quick_get_round()) {
case FE_TONEAREST:
// Round to nearest, ties to even
if (rb && (lsb || (r != 0)))
++y;
++y_out;
break;
case FE_UPWARD:
if (rb || (r != 0))
++y;
if (rb || (r != 0) || sticky)
++y_out;
break;
}

return cpp::bit_cast<T>(y);
return cpp::bit_cast<OutType>(y_out);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions libc/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
add_math_entrypoint_object(expm1)
add_math_entrypoint_object(expm1f)

add_math_entrypoint_object(f16sqrtf)

add_math_entrypoint_object(fabs)
add_math_entrypoint_object(fabsf)
add_math_entrypoint_object(fabsl)
Expand Down
20 changes: 20 additions & 0 deletions libc/src/math/f16sqrtf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
#define LLVM_LIBC_SRC_MATH_F16SQRTF_H

#include "src/__support/macros/properties/types.h"

namespace LIBC_NAMESPACE {

float16 f16sqrtf(float x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H
13 changes: 13 additions & 0 deletions libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3601,3 +3601,16 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O3
)

add_entrypoint_object(
f16sqrtf
SRCS
f16sqrtf.cpp
HDRS
../f16sqrtf.h
DEPENDS
libc.src.__support.macros.properties.types
libc.src.__support.FPUtil.sqrt
COMPILE_OPTIONS
-O3
)
2 changes: 1 addition & 1 deletion libc/src/math/generic/acosf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) {
xbits.set_sign(Sign::POS);
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
double cv = 2 * fputil::sqrt(u);
double cv = 2 * fputil::sqrt<double>(u);

double r3 = asin_eval(u);
double r = fputil::multiply_add(cv * u, r3, cv);
Expand Down
4 changes: 2 additions & 2 deletions libc/src/math/generic/acoshf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) {

double x_d = static_cast<double>(x);
// acosh(x) = log(x + sqrt(x^2 - 1))
return static_cast<float>(
log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0))));
return static_cast<float>(log_eval(
x_d + fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, -1.0))));
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/asinf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) {
double sign = SIGN[x_sign];
double xd = static_cast<double>(xbits.get_val());
double u = fputil::multiply_add(-0.5, xd, 0.5);
double c1 = sign * (-2 * fputil::sqrt(u));
double c1 = sign * (-2 * fputil::sqrt<double>(u));
double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1);
double c3 = c1 * u;

Expand Down
6 changes: 3 additions & 3 deletions libc/src/math/generic/asinhf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) {

// asinh(x) = log(x + sqrt(x^2 + 1))
return static_cast<float>(
x_sign *
log_eval(fputil::multiply_add(
x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0)))));
x_sign * log_eval(fputil::multiply_add(
x_d, x_sign,
fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, 1.0)))));
}

} // namespace LIBC_NAMESPACE
19 changes: 19 additions & 0 deletions libc/src/math/generic/f16sqrtf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- Implementation of f16sqrtf function -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/f16sqrtf.h"
#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/common.h"

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
return fputil::sqrt<float16>(x);
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/hypotf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;

// Take sqrt in double precision.
DoubleBits result(fputil::sqrt(sum_sq));
DoubleBits result(fputil::sqrt<double>(sum_sq));

if (!DoubleBits(sum_sq).is_inf_or_nan()) {
// Correct rounding.
Expand Down
2 changes: 1 addition & 1 deletion libc/src/math/generic/powf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) {
switch (y_u) {
case 0x3f00'0000: // y = 0.5f
// pow(x, 1/2) = sqrt(x)
return fputil::sqrt(x);
return fputil::sqrt<float>(x);
case 0x3f80'0000: // y = 1.0f
return x;
case 0x4000'0000: // y = 2.0f
Expand Down
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrtf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }

} // namespace LIBC_NAMESPACE
4 changes: 3 additions & 1 deletion libc/src/math/generic/sqrtf128.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
return fputil::sqrt<float128>(x);
}

} // namespace LIBC_NAMESPACE
2 changes: 1 addition & 1 deletion libc/src/math/generic/sqrtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
return fputil::sqrt(x);
return fputil::sqrt<long double>(x);
}

} // namespace LIBC_NAMESPACE
Loading

0 comments on commit 2243a4c

Please sign in to comment.