Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize TensorPrimitives.Exp #93018

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,8 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Exp(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Exp(x[i]);
}
}
public static void Exp(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<ExpOperator>(x, destination);

/// <summary>Searches for the index of the largest single-precision floating-point number in the specified tensor.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
#endif
}

private readonly struct ExpOperator : IUnaryOperator
{
// This code is based on `vrs4_expf` from amd/aocl-libm-ose
// Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Implementation Notes:
// 1. Argument Reduction:
// e^x = 2^(x/ln2) --- (1)
//
// Let x/ln(2) = z --- (2)
//
// Let z = n + r , where n is an integer --- (3)
// |r| <= 1/2
//
// From (1), (2) and (3),
// e^x = 2^z
// = 2^(N+r)
// = (2^N)*(2^r) --- (4)
//
// 2. Polynomial Evaluation
// From (4),
// r = z - N
// 2^r = C1 + C2*r + C3*r^2 + C4*r^3 + C5 *r^4 + C6*r^5
//
// 4. Reconstruction
// Thus,
// e^x = (2^N) * (2^r)

private const uint V_ARG_MAX = 0x42AE0000;
private const uint V_MASK = 0x7FFFFFFF;

private const float V_EXPF_MIN = -103.97208f;
private const float V_EXPF_MAX = 88.72284f;

private const double V_EXPF_HUGE = 6755399441055744;
private const double V_TBL_LN2 = 1.4426950408889634;

private const double C1 = 1.0000000754895704;
private const double C2 = 0.6931472254087585;
private const double C3 = 0.2402210737432219;
private const double C4 = 0.05550297297702539;
private const double C5 = 0.009676036358193323;
private const double C6 = 0.001341000536524434;

public static float Invoke(float x) => MathF.Exp(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
// Convert x to double precision
(Vector128<double> xl, Vector128<double> xu) = Vector128.Widen(x);

// x * (64.0 / ln(2))
Vector128<double> v_tbl_ln2 = Vector128.Create(V_TBL_LN2);

Vector128<double> zl = xl * v_tbl_ln2;
Vector128<double> zu = xu * v_tbl_ln2;

Vector128<double> v_expf_huge = Vector128.Create(V_EXPF_HUGE);

Vector128<double> dnl = zl + v_expf_huge;
Vector128<double> dnu = zu + v_expf_huge;

// n = int (z)
Vector128<ulong> nl = dnl.AsUInt64();
Vector128<ulong> nu = dnu.AsUInt64();

// dn = double(n)
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector128<double> c1 = Vector128.Create(C1);
Vector128<double> c2 = Vector128.Create(C2);
Vector128<double> c3 = Vector128.Create(C3);
Vector128<double> c4 = Vector128.Create(C4);
Vector128<double> c5 = Vector128.Create(C5);
Vector128<double> c6 = Vector128.Create(C6);

Vector128<double> rl = zl - dnl;

Vector128<double> rl2 = rl * rl;
Vector128<double> rl4 = rl2 * rl2;

Vector128<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


Vector128<double> ru = zu - dnu;

Vector128<double> ru2 = ru * ru;
Vector128<double> ru4 = ru2 * ru2;

Vector128<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
Vector128<float> ret = Vector128.Narrow(
(polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble()
);

// Check if -103 < |x| < 88
if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : x
Vector128<float> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));

ret = Vector128.ConditionalSelect(
infinityMask,
Vector128.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
}

return ret;
}

public static Vector256<float> Invoke(Vector256<float> x)
{
// Convert x to double precision
(Vector256<double> xl, Vector256<double> xu) = Vector256.Widen(x);

// x * (64.0 / ln(2))
Vector256<double> v_tbl_ln2 = Vector256.Create(V_TBL_LN2);

Vector256<double> zl = xl * v_tbl_ln2;
Vector256<double> zu = xu * v_tbl_ln2;

Vector256<double> v_expf_huge = Vector256.Create(V_EXPF_HUGE);

Vector256<double> dnl = zl + v_expf_huge;
Vector256<double> dnu = zu + v_expf_huge;

// n = int (z)
Vector256<ulong> nl = dnl.AsUInt64();
Vector256<ulong> nu = dnu.AsUInt64();

// dn = double(n)
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector256<double> c1 = Vector256.Create(C1);
Vector256<double> c2 = Vector256.Create(C2);
Vector256<double> c3 = Vector256.Create(C3);
Vector256<double> c4 = Vector256.Create(C4);
Vector256<double> c5 = Vector256.Create(C5);
Vector256<double> c6 = Vector256.Create(C6);

Vector256<double> rl = zl - dnl;

Vector256<double> rl2 = rl * rl;
Vector256<double> rl4 = rl2 * rl2;

Vector256<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


Vector256<double> ru = zu - dnu;

Vector256<double> ru2 = ru * ru;
Vector256<double> ru4 = ru2 * ru2;

Vector256<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
Vector256<float> ret = Vector256.Narrow(
(polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble()
);

// Check if -103 < |x| < 88
if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : x
Vector256<float> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));

ret = Vector256.ConditionalSelect(
infinityMask,
Vector256.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
}

return ret;
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
// Convert x to double precision
(Vector512<double> xl, Vector512<double> xu) = Vector512.Widen(x);

// x * (64.0 / ln(2))
Vector512<double> v_tbl_ln2 = Vector512.Create(V_TBL_LN2);

Vector512<double> zl = xl * v_tbl_ln2;
Vector512<double> zu = xu * v_tbl_ln2;

Vector512<double> v_expf_huge = Vector512.Create(V_EXPF_HUGE);

Vector512<double> dnl = zl + v_expf_huge;
Vector512<double> dnu = zu + v_expf_huge;

// n = int (z)
Vector512<ulong> nl = dnl.AsUInt64();
Vector512<ulong> nu = dnu.AsUInt64();

// dn = double(n)
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector512<double> c1 = Vector512.Create(C1);
Vector512<double> c2 = Vector512.Create(C2);
Vector512<double> c3 = Vector512.Create(C3);
Vector512<double> c4 = Vector512.Create(C4);
Vector512<double> c5 = Vector512.Create(C5);
Vector512<double> c6 = Vector512.Create(C6);

Vector512<double> rl = zl - dnl;

Vector512<double> rl2 = rl * rl;
Vector512<double> rl4 = rl2 * rl2;

Vector512<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


Vector512<double> ru = zu - dnu;

Vector512<double> ru2 = ru * ru;
Vector512<double> ru4 = ru2 * ru2;

Vector512<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
Vector512<float> ret = Vector512.Narrow(
(polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble()
);

// Check if -103 < |x| < 88
if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : x
Vector512<float> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));

ret = Vector512.ConditionalSelect(
infinityMask,
Vector512.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : x
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
}

return ret;
}
#endif
}

private readonly struct LogOperator : IUnaryOperator
{
// This code is based on `vrs4_logf` from amd/aocl-libm-ose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,19 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

private readonly struct ExpOperator : IUnaryOperator
{
public bool CanVectorize => false;

public float Invoke(float x) => MathF.Exp(x);

public Vector<float> Invoke(Vector<float> x)
{
// Vectorizing requires shift right support, which is .NET 7 or later
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
throw new NotImplementedException();
}
}

private readonly struct LogOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand Down