From cdb474e80ed143999a1b5d90cd1b27e8a1ce654f Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 26 Aug 2024 13:35:42 -0600 Subject: [PATCH] slicing fixes --- .../System.Numerics.Tensors.sln | 26 +- .../ref/System.Numerics.Tensors.netcore.cs | 10 +- .../src/Properties/InternalVisibleTo.cs | 6 + .../src/System.Numerics.Tensors.csproj | 1 + .../Tensors/netcore/TensorExtensions.cs | 507 ++++++++++++------ .../Numerics/Tensors/netcore/TensorHelpers.cs | 15 + .../System.Numerics.Tensors/tests/Helpers.cs | 4 +- .../tests/TensorSpanTests.cs | 455 +++++++++++++--- 8 files changed, 766 insertions(+), 258 deletions(-) create mode 100644 src/libraries/System.Numerics.Tensors/src/Properties/InternalVisibleTo.cs diff --git a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln index 0afa9da9f3c074..d265c704697a0a 100644 --- a/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln +++ b/src/libraries/System.Numerics.Tensors/System.Numerics.Tensors.sln @@ -1,4 +1,8 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.11.35118.90 +MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{9F20CEA1-2216-4432-BBBD-F01E05D17F23}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Bcl.Numerics", "..\Microsoft.Bcl.Numerics\src\Microsoft.Bcl.Numerics.csproj", "{1578185F-C4FA-4866-936B-E62AAEDD03B7}" @@ -33,11 +37,11 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{7AC4B2C7-A55 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{841A2FA4-A95F-4612-A8B9-AD2EF769BC71}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "tools\gen", "{A21C99E7-E22B-470E-BF48-56B00AFE3D34}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{A21C99E7-E22B-470E-BF48-56B00AFE3D34}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "tools\src", "{25B37C75-C737-4AE8-9260-74A79870C8B8}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{25B37C75-C737-4AE8-9260-74A79870C8B8}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "tools\ref", "{9482D7C5-F37C-40FC-B057-A16C1ED1C121}" +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{9482D7C5-F37C-40FC-B057-A16C1ED1C121}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B}" EndProject @@ -105,23 +109,27 @@ Global EndGlobalSection GlobalSection(NestedProjects) = preSolution {9F20CEA1-2216-4432-BBBD-F01E05D17F23} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {46AD9423-D8C3-44BB-A201-1CCCAB4C6DAF} = {DE94CA7D-BB10-4865-85A6-6B694631247F} - {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {1578185F-C4FA-4866-936B-E62AAEDD03B7} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} - {848DD000-3D22-4A25-A9D9-05AFF857A116} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} {21CB448A-3882-4337-B416-D1A3E0BCFFC5} = {7AC4B2C7-A55C-4C4F-9B02-77F5CBFFF4AB} + {848DD000-3D22-4A25-A9D9-05AFF857A116} = {DF0561A1-3AB8-4B51-AFB4-392EE1DD6247} + {46AD9423-D8C3-44BB-A201-1CCCAB4C6DAF} = {DE94CA7D-BB10-4865-85A6-6B694631247F} + {4AF6A02D-82C8-4898-9EDF-01F107C25061} = {DE94CA7D-BB10-4865-85A6-6B694631247F} {4588351F-4233-4957-B84C-7F8E22B8888A} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} {DB954E01-898A-4FE2-A3AA-180D041AB08F} = {841A2FA4-A95F-4612-A8B9-AD2EF769BC71} {04FC0651-B9D0-448A-A28B-11B1D4A897F4} = {A21C99E7-E22B-470E-BF48-56B00AFE3D34} {683A7D28-CC55-4375-848D-E659075ECEE4} = {A21C99E7-E22B-470E-BF48-56B00AFE3D34} - {A21C99E7-E22B-470E-BF48-56B00AFE3D34} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {1CBEAEA8-2CA1-4B07-9930-35A785205852} = {25B37C75-C737-4AE8-9260-74A79870C8B8} {BA7828B1-7953-47A0-AE5A-E22B501C4BD0} = {25B37C75-C737-4AE8-9260-74A79870C8B8} - {25B37C75-C737-4AE8-9260-74A79870C8B8} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {57E57290-3A6A-43F8-8764-D4DC8151F89C} = {9482D7C5-F37C-40FC-B057-A16C1ED1C121} + {A21C99E7-E22B-470E-BF48-56B00AFE3D34} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} + {25B37C75-C737-4AE8-9260-74A79870C8B8} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} {9482D7C5-F37C-40FC-B057-A16C1ED1C121} = {F9C2AAB1-C7B0-4E43-BB18-4FB16F6E272B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {10A5F2C3-5230-4916-9D4D-BBDB94851037} EndGlobalSection + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{683a7d28-cc55-4375-848d-e659075ecee4}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{ba7828b1-7953-47a0-ae5a-e22b501c4bd0}*SharedItemsImports = 5 + EndGlobalSection EndGlobal diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 9a4b08633db1d7..c3a9f5be9ee36f 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -299,7 +299,7 @@ public static void BroadcastTo(this System.Numerics.Tensors.Tensor source, public static bool GreaterThanAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanAny(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } - public static bool GreaterThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan s, T y) where T : System.Numerics.IComparisonOperators { throw null; } + public static bool GreaterThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanOrEqualAll(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanOrEqualAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool GreaterThanOrEqualAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } @@ -333,16 +333,16 @@ public static void BroadcastTo(this System.Numerics.Tensors.Tensor source, public static System.Numerics.Tensors.Tensor LeadingZeroCount(in System.Numerics.Tensors.ReadOnlyTensorSpan x) where T : System.Numerics.IBinaryInteger { throw null; } public static ref readonly System.Numerics.Tensors.TensorSpan LeadingZeroCount(scoped in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.TensorSpan destination) where T : System.Numerics.IBinaryInteger { throw null; } public static bool LessThanAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } - public static bool LessThanAll(in System.Numerics.Tensors.ReadOnlyTensorSpan f, T x) where T : System.Numerics.IComparisonOperators { throw null; } + public static bool LessThanAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanAll(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } - public static bool LessThanAny(in System.Numerics.Tensors.ReadOnlyTensorSpan f, T x) where T : System.Numerics.IComparisonOperators { throw null; } + public static bool LessThanAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanAny(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } - public static bool LessThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan f, T x) where T : System.Numerics.IComparisonOperators { throw null; } + public static bool LessThanOrEqualAll(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanOrEqualAll(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanOrEqualAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } - public static bool LessThanOrEqualAny(in System.Numerics.Tensors.ReadOnlyTensorSpan f, T x) where T : System.Numerics.IComparisonOperators { throw null; } + public static bool LessThanOrEqualAny(in System.Numerics.Tensors.ReadOnlyTensorSpan x, T y) where T : System.Numerics.IComparisonOperators { throw null; } public static bool LessThanOrEqualAny(T x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static System.Numerics.Tensors.Tensor LessThanOrEqual(in System.Numerics.Tensors.ReadOnlyTensorSpan x, in System.Numerics.Tensors.ReadOnlyTensorSpan y) where T : System.Numerics.IComparisonOperators { throw null; } public static ref readonly System.Numerics.Tensors.TensorSpan LessThanOrEqual(scoped in System.Numerics.Tensors.ReadOnlyTensorSpan x, scoped in System.Numerics.Tensors.ReadOnlyTensorSpan y, in System.Numerics.Tensors.TensorSpan destination) where T : System.Numerics.IComparisonOperators { throw null; } diff --git a/src/libraries/System.Numerics.Tensors/src/Properties/InternalVisibleTo.cs b/src/libraries/System.Numerics.Tensors/src/Properties/InternalVisibleTo.cs new file mode 100644 index 00000000000000..8858b4b76a81bf --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/Properties/InternalVisibleTo.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("System.Numerics.Tensors.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index 29c38e142ba808..454a1ce98fbe36 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -11,6 +11,7 @@ + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs index c01da7a52160d2..176a2b4ef5ec77 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs @@ -56,7 +56,7 @@ public static T Average(scoped in ReadOnlyTensorSpan x) where T : IFloatingPoint { T sum = Sum(x); - return T.CreateChecked(sum / T.CreateChecked(x._shape._memoryLength)); + return T.CreateChecked(sum / T.CreateChecked(x.FlattenedLength)); } #endregion @@ -395,10 +395,11 @@ public static ref readonly TensorSpan ConcatenateOnDimension(int dimension scoped Span curIndex; nint[]? curIndexArray; - if (tensors[0].Rank > 6) + if (tensors[0].Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(tensors[0].Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, tensors[0].Rank); } else { @@ -499,10 +500,11 @@ public static ref readonly TensorSpan Equals(scoped in ReadOnlyTensorSp scoped Span curIndex; nint[]? curIndexArray; - if (right.Rank > 6) + if (right.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(right.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, right.Rank); } else { @@ -554,10 +556,11 @@ public static ref readonly TensorSpan Equals(scoped in ReadOnlyTensorSp scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -598,10 +601,11 @@ public static bool EqualsAll(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpa scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedLeft.Rank > 6) + if (broadcastedLeft.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -636,10 +640,11 @@ public static bool EqualsAll(in ReadOnlyTensorSpan x, T y) scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -680,10 +685,11 @@ public static bool EqualsAny(in ReadOnlyTensorSpan x, in ReadOnlyTensorSpa scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -718,10 +724,11 @@ public static bool EqualsAny(in ReadOnlyTensorSpan x, T y) scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -864,10 +871,11 @@ public static ref readonly TensorSpan GreaterThan(scoped in ReadOnlyTen scoped Span curIndex; nint[]? curIndexArray; - if (right.Rank > 6) + if (right.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(right.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, right.Rank); } else { @@ -923,10 +931,11 @@ public static ref readonly TensorSpan GreaterThan(scoped in ReadOnlyTen scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -982,10 +991,11 @@ public static ref readonly TensorSpan GreaterThan(T x, scoped in ReadOn scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1070,10 +1080,11 @@ public static ref readonly TensorSpan GreaterThanOrEqual(scoped in Read scoped Span curIndex; nint[]? curIndexArray; - if (right.Rank > 6) + if (right.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(right.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, right.Rank); } else { @@ -1129,10 +1140,11 @@ public static ref readonly TensorSpan GreaterThanOrEqual(scoped in Read scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -1188,10 +1200,11 @@ public static ref readonly TensorSpan GreaterThanOrEqual(T x, scoped in scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1231,10 +1244,11 @@ public static bool GreaterThanAny(in ReadOnlyTensorSpan x, in ReadOnlyTens scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -1269,10 +1283,11 @@ public static bool GreaterThanAny(in ReadOnlyTensorSpan x, T y) scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -1307,10 +1322,11 @@ public static bool GreaterThanAny(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1351,10 +1367,11 @@ public static bool GreaterThanOrEqualAny(in ReadOnlyTensorSpan x, in ReadO scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -1389,10 +1406,11 @@ public static bool GreaterThanOrEqualAny(in ReadOnlyTensorSpan x, T y) scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -1427,10 +1445,11 @@ public static bool GreaterThanOrEqualAny(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1472,10 +1491,11 @@ public static bool GreaterThanAll(in ReadOnlyTensorSpan x, in ReadOnlyTens scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedLeft.Rank > 6) + if (broadcastedLeft.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -1510,10 +1530,11 @@ public static bool GreaterThanAll(in ReadOnlyTensorSpan x, T y) scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -1548,10 +1569,11 @@ public static bool GreaterThanAll(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1593,10 +1615,11 @@ public static bool GreaterThanOrEqualAll(in ReadOnlyTensorSpan x, in ReadO scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedLeft.Rank > 6) + if (broadcastedLeft.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -1618,35 +1641,36 @@ public static bool GreaterThanOrEqualAll(in ReadOnlyTensorSpan x, in ReadO } /// - /// Compares the elements of two to see if all elements of are greater than . + /// Compares the elements of two to see if all elements of are greater than . /// If the shapes are not the same, the tensors are broadcasted to the smallest broadcastable size before they are compared. - /// It returns a where the value is true if all elements in are greater than . + /// It returns a where the value is true if all elements in are greater than . /// - /// First to compare. + /// First to compare. /// Second to compare against. - /// where the value is true if all elements in are greater than . - public static bool GreaterThanOrEqualAll(in ReadOnlyTensorSpan s, T y) + /// where the value is true if all elements in are greater than . + public static bool GreaterThanOrEqualAll(in ReadOnlyTensorSpan x, T y) where T : IComparisonOperators { scoped Span curIndex; nint[]? curIndexArray; - if (s.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { - curIndexArray = ArrayPool.Shared.Rent(s.Rank); + curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { curIndexArray = null; - curIndex = stackalloc nint[s.Rank]; + curIndex = stackalloc nint[x.Rank]; } - for (int i = 0; i < s.FlattenedLength; i++) + for (int i = 0; i < x.FlattenedLength; i++) { - if (s[curIndex] < y) + if (x[curIndex] < y) return false; - TensorSpanHelpers.AdjustIndexes(s.Rank - 1, 1, curIndex, s.Lengths); + TensorSpanHelpers.AdjustIndexes(x.Rank - 1, 1, curIndex, x.Lengths); } if (curIndexArray != null) @@ -1669,10 +1693,11 @@ public static bool GreaterThanOrEqualAll(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1758,10 +1783,11 @@ public static ref readonly TensorSpan LessThan(scoped in ReadOnlyTensor scoped Span curIndex; nint[]? curIndexArray; - if (right.Rank > 6) + if (right.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(right.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, right.Rank); } else { @@ -1817,10 +1843,11 @@ public static ref readonly TensorSpan LessThan(scoped in ReadOnlyTensor scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -1876,10 +1903,11 @@ public static ref readonly TensorSpan LessThan(T x, scoped in ReadOnlyT scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -1964,10 +1992,11 @@ public static ref readonly TensorSpan LessThanOrEqual(scoped in ReadOnl scoped Span curIndex; nint[]? curIndexArray; - if (right.Rank > 6) + if (right.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(right.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, right.Rank); } else { @@ -2023,10 +2052,11 @@ public static ref readonly TensorSpan LessThanOrEqual(scoped in ReadOnl scoped Span curIndex; nint[]? curIndexArray; - if (x.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { @@ -2082,10 +2112,11 @@ public static ref readonly TensorSpan LessThanOrEqual(T x, scoped in Re scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -2126,10 +2157,11 @@ public static bool LessThanAny(in ReadOnlyTensorSpan x, in ReadOnlyTensorS scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -2151,35 +2183,36 @@ public static bool LessThanAny(in ReadOnlyTensorSpan x, in ReadOnlyTensorS } /// - /// Compares the elements of two to see if any elements of are less than . + /// Compares the elements of two to see if any elements of are less than . /// If the shapes are not the same, the tensors are broadcasted to the smallest broadcastable size before they are compared. - /// It returns a where the value is true if any elements in are less than . + /// It returns a where the value is true if any elements in are less than . /// - /// First to compare. - /// Second value to compare against. - /// where the value is true if any elements in are less than . - public static bool LessThanAny(in ReadOnlyTensorSpan f, T x) + /// First to compare. + /// Second value to compare against. + /// where the value is true if any elements in are less than . + public static bool LessThanAny(in ReadOnlyTensorSpan x, T y) where T : IComparisonOperators { scoped Span curIndex; nint[]? curIndexArray; - if (f.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { - curIndexArray = ArrayPool.Shared.Rent(f.Rank); + curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { curIndexArray = null; - curIndex = stackalloc nint[f.Rank]; + curIndex = stackalloc nint[x.Rank]; } - for (int i = 0; i < f.FlattenedLength; i++) + for (int i = 0; i < x.FlattenedLength; i++) { - if (f[curIndex] < x) + if (x[curIndex] < y) return true; - TensorSpanHelpers.AdjustIndexes(f.Rank - 1, 1, curIndex, f.Lengths); + TensorSpanHelpers.AdjustIndexes(x.Rank - 1, 1, curIndex, x.Lengths); } if (curIndexArray != null) @@ -2202,10 +2235,11 @@ public static bool LessThanAny(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -2247,10 +2281,11 @@ public static bool LessThanOrEqualAny(in ReadOnlyTensorSpan x, in ReadOnly scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -2272,35 +2307,36 @@ public static bool LessThanOrEqualAny(in ReadOnlyTensorSpan x, in ReadOnly } /// - /// Compares the elements of two to see if any elements of are less than . + /// Compares the elements of two to see if any elements of are less than . /// If the shapes are not the same, the tensors are broadcasted to the smallest broadcastable size before they are compared. - /// It returns a where the value is true if any elements in are less than . + /// It returns a where the value is true if any elements in are less than . /// - /// First to compare. - /// Second value to compare against. - /// where the value is true if any elements in are less than . - public static bool LessThanOrEqualAny(in ReadOnlyTensorSpan f, T x) + /// First to compare. + /// Second value to compare against. + /// where the value is true if any elements in are less than . + public static bool LessThanOrEqualAny(in ReadOnlyTensorSpan x, T y) where T : IComparisonOperators { scoped Span curIndex; nint[]? curIndexArray; - if (f.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { - curIndexArray = ArrayPool.Shared.Rent(f.Rank); + curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { curIndexArray = null; - curIndex = stackalloc nint[f.Rank]; + curIndex = stackalloc nint[x.Rank]; } - for (int i = 0; i < f.FlattenedLength; i++) + for (int i = 0; i < x.FlattenedLength; i++) { - if (f[curIndex] <= x) + if (x[curIndex] <= y) return true; - TensorSpanHelpers.AdjustIndexes(f.Rank - 1, 1, curIndex, f.Lengths); + TensorSpanHelpers.AdjustIndexes(x.Rank - 1, 1, curIndex, x.Lengths); } if (curIndexArray != null) @@ -2323,10 +2359,11 @@ public static bool LessThanOrEqualAny(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -2367,10 +2404,11 @@ public static bool LessThanAll(in ReadOnlyTensorSpan x, in ReadOnlyTensorS scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -2392,35 +2430,36 @@ public static bool LessThanAll(in ReadOnlyTensorSpan x, in ReadOnlyTensorS } /// - /// Compares the elements of two to see if all elements of are less than . + /// Compares the elements of two to see if all elements of are less than . /// If the shapes are not the same, the tensors are broadcasted to the smallest broadcastable size before they are compared. - /// It returns a where the value is true if all elements in are less than . + /// It returns a where the value is true if all elements in are less than . /// - /// First to compare. - /// Second value to compare against. - /// where the value is true if all elements in are less than . - public static bool LessThanAll(in ReadOnlyTensorSpan f, T x) + /// First to compare. + /// Second value to compare against. + /// where the value is true if all elements in are less than . + public static bool LessThanAll(in ReadOnlyTensorSpan x, T y) where T : IComparisonOperators { scoped Span curIndex; nint[]? curIndexArray; - if (f.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { - curIndexArray = ArrayPool.Shared.Rent(f.Rank); + curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { curIndexArray = null; - curIndex = stackalloc nint[f.Rank]; + curIndex = stackalloc nint[x.Rank]; } - for (int i = 0; i < f.FlattenedLength; i++) + for (int i = 0; i < x.FlattenedLength; i++) { - if (f[curIndex] >= x) + if (x[curIndex] >= y) return false; - TensorSpanHelpers.AdjustIndexes(f.Rank - 1, 1, curIndex, f.Lengths); + TensorSpanHelpers.AdjustIndexes(x.Rank - 1, 1, curIndex, x.Lengths); } if (curIndexArray != null) @@ -2443,10 +2482,11 @@ public static bool LessThanAll(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -2487,10 +2527,11 @@ public static bool LessThanOrEqualAll(in ReadOnlyTensorSpan x, in ReadOnly scoped Span curIndex; nint[]? curIndexArray; - if (broadcastedRight.Lengths.Length > 6) + if (broadcastedRight.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(broadcastedRight.Lengths.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, broadcastedRight.Rank); } else { @@ -2512,35 +2553,36 @@ public static bool LessThanOrEqualAll(in ReadOnlyTensorSpan x, in ReadOnly } /// - /// Compares the elements of two to see if all elements of are less than . + /// Compares the elements of two to see if all elements of are less than . /// If the shapes are not the same, the tensors are broadcasted to the smallest broadcastable size before they are compared. - /// It returns a where the value is true if all elements in are less than . + /// It returns a where the value is true if all elements in are less than . /// - /// First to compare. - /// Second value to compare against. - /// where the value is true if all elements in are less than . - public static bool LessThanOrEqualAll(in ReadOnlyTensorSpan f, T x) + /// First to compare. + /// Second value to compare against. + /// where the value is true if all elements in are less than . + public static bool LessThanOrEqualAll(in ReadOnlyTensorSpan x, T y) where T : IComparisonOperators { scoped Span curIndex; nint[]? curIndexArray; - if (f.Rank > 6) + if (x.Rank > TensorShape.MaxInlineRank) { - curIndexArray = ArrayPool.Shared.Rent(f.Rank); + curIndexArray = ArrayPool.Shared.Rent(x.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, x.Rank); } else { curIndexArray = null; - curIndex = stackalloc nint[f.Rank]; + curIndex = stackalloc nint[x.Rank]; } - for (int i = 0; i < f.FlattenedLength; i++) + for (int i = 0; i < x.FlattenedLength; i++) { - if (f[curIndex] > x) + if (x[curIndex] > y) return false; - TensorSpanHelpers.AdjustIndexes(f.Rank - 1, 1, curIndex, f.Lengths); + TensorSpanHelpers.AdjustIndexes(x.Rank - 1, 1, curIndex, x.Lengths); } if (curIndexArray != null) @@ -2563,10 +2605,11 @@ public static bool LessThanOrEqualAll(T x, in ReadOnlyTensorSpan y) scoped Span curIndex; nint[]? curIndexArray; - if (y.Rank > 6) + if (y.Rank > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(y.Rank); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, y.Rank); } else { @@ -4234,7 +4277,7 @@ public static ref readonly TensorSpan ConvertChecked(scoped in where TFrom : IEquatable, IEqualityOperators, INumberBase where TTo : INumberBase { - return ref TensorPrimitivesHelperTFromSpanInTToSpanOut(source, destination, TensorPrimitives.ConvertChecked); + return ref TensorPrimitivesHelperSpanInSpanOut(source, destination, TensorPrimitives.ConvertChecked); } #endregion @@ -4264,7 +4307,7 @@ public static ref readonly TensorSpan ConvertSaturating(scoped where TFrom : IEquatable, IEqualityOperators, INumberBase where TTo : INumberBase { - return ref TensorPrimitivesHelperTFromSpanInTToSpanOut(source, destination, TensorPrimitives.ConvertSaturating); + return ref TensorPrimitivesHelperSpanInSpanOut(source, destination, TensorPrimitives.ConvertSaturating); } #endregion @@ -4294,7 +4337,7 @@ public static ref readonly TensorSpan ConvertTruncating(scoped where TFrom : IEquatable, IEqualityOperators, INumberBase where TTo : INumberBase { - return ref TensorPrimitivesHelperTFromSpanInTToSpanOut(source, destination, TensorPrimitives.ConvertTruncating); + return ref TensorPrimitivesHelperSpanInSpanOut(source, destination, TensorPrimitives.ConvertTruncating); } #endregion @@ -4964,7 +5007,7 @@ public static Tensor ILogB(in ReadOnlyTensorSpan x) public static ref readonly TensorSpan ILogB(scoped in ReadOnlyTensorSpan x, in TensorSpan destination) where T : IFloatingPointIeee754 { - return ref TensorPrimitivesHelperSpanInIntSpanOut(x, destination, TensorPrimitives.ILogB); + return ref TensorPrimitivesHelperSpanInSpanOut(x, destination, TensorPrimitives.ILogB); } #endregion @@ -5824,8 +5867,7 @@ public static ref readonly TensorSpan Negate(scoped in ReadOnlyTensorSpan< public static T Norm(scoped in ReadOnlyTensorSpan x) where T : IRootFunctions { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape._memoryLength); - return TensorPrimitives.Norm(span); + return TensorPrimitivesHelperSpanInTOut(x, TensorPrimitives.Norm); } #endregion @@ -6423,8 +6465,7 @@ public static ref readonly TensorSpan Subtract(scoped in ReadOnlyTensorSpa public static T Sum(scoped in ReadOnlyTensorSpan x) where T : IAdditionOperators, IAdditiveIdentity { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref x._reference, (int)x._shape._memoryLength); - return TensorPrimitives.Sum(span); + return TensorPrimitivesHelperSpanInTOut(x, TensorPrimitives.Sum); } #endregion @@ -6616,7 +6657,7 @@ public static nint[] GetSmallestBroadcastableLengths(ReadOnlySpan shape1, } #region TensorPrimitivesHelpers - private delegate void PerformCalculationSpanInSpanOut(ReadOnlySpan input, Span output); + private delegate void PerformCalculationSpanInSpanOut(ReadOnlySpan input, Span output); private delegate void PerformCalculationSpanInTInSpanOut(ReadOnlySpan input, T value, Span output); @@ -6624,37 +6665,36 @@ public static nint[] GetSmallestBroadcastableLengths(ReadOnlySpan shape1, private delegate void PerformCalculationTwoSpanInSpanOut(ReadOnlySpan input, ReadOnlySpan inputTwo, Span output); - private delegate void PerformCalculationTFromSpanInTToSpanOut(ReadOnlySpan input, Span output) - where TFrom : INumberBase - where TTo : INumberBase; - private delegate T PerformCalculationTwoSpanInTOut(ReadOnlySpan input, ReadOnlySpan inputTwo); - private delegate void PerformCalculationSpanInIntSpanOut(ReadOnlySpan input, Span output); - private delegate T PerformCalculationSpanInTOut(ReadOnlySpan input); private static T TensorPrimitivesHelperSpanInTOut(scoped in ReadOnlyTensorSpan input, PerformCalculationSpanInTOut performCalculation) { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - return performCalculation(span); - } - - private static ref readonly TensorSpan TensorPrimitivesHelperSpanInIntSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination, PerformCalculationSpanInIntSpanOut performCalculation) - { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - Span data = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); - performCalculation(span, data); - return ref destination; + if (TensorHelpers.IsContiguousAndDense(input)) + { + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape.FlattenedLength); + return performCalculation(span); + } + // Flattening needs to happen + else + { + // TODO: Can optimize this to not need to realize the broadcasts + // That will need to be done on a per method basis. + nint flattenedLength = input.FlattenedLength; + T[] flattened = new T[flattenedLength]; + input.FlattenTo(flattened); + return performCalculation(flattened); + } } private static T TensorPrimitivesHelperTwoSpanInTOut(scoped in ReadOnlyTensorSpan left, scoped in ReadOnlyTensorSpan right, PerformCalculationTwoSpanInTOut performCalculation) { // If sizes are the same. - if (TensorHelpers.AreLengthsTheSame(left, right) && TensorHelpers.IsUnderlyingStorageSameSize(left, right)) + if (TensorHelpers.IsContiguousAndDense(left) && TensorHelpers.IsContiguousAndDense(right) && TensorHelpers.AreLengthsTheSame(left, right)) { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left._shape._memoryLength); - ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right._shape._memoryLength); + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left._shape.FlattenedLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right._shape.FlattenedLength); return performCalculation(span, rspan); } // Broadcasting needs to happen. @@ -6665,6 +6705,8 @@ private static T TensorPrimitivesHelperTwoSpanInTOut(scoped in ReadOnlyTensor // 2 - One tensor has row contiguous memory and the right has column contiguous memory (i.e. a 1x5 and a 5x1) // Because we are returning a single T though we need to actual realize the broadcasts at this point to perform the calculations. + // TODO: Can optimize this to not need to realize the broadcasts + // That will need to be done on a per method basis. nint[] newLengths = Tensor.GetSmallestBroadcastableLengths(left.Lengths, right.Lengths); nint newLength = TensorSpanHelpers.CalculateTotalLength(newLengths); TensorSpan broadcastedLeft = new TensorSpan(new T[newLength], newLengths, ReadOnlySpan.Empty); @@ -6678,59 +6720,177 @@ private static T TensorPrimitivesHelperTwoSpanInTOut(scoped in ReadOnlyTensor } } - private static ref readonly TensorSpan TensorPrimitivesHelperSpanInSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination, PerformCalculationSpanInSpanOut performCalculation) + private static ref readonly TensorSpan TensorPrimitivesHelperSpanInSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination, PerformCalculationSpanInSpanOut performCalculation) { - if (destination._shape._memoryLength < input._shape._memoryLength) + // Make sure destination has enough memory + if (destination._shape._memoryLength < input._shape.FlattenedLength) ThrowHelper.ThrowArgumentException_DestinationTooShort(); - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); - performCalculation(span, ospan); + // Make sure destination shape works with input shape + TensorSpan slicedDestination = destination.Slice(input._shape.Lengths); + + Span destinationSpan; + ReadOnlySpan inputSpan; + + // Memory is contiguous for both input and destination + if (TensorHelpers.IsContiguousAndDense(input) && TensorHelpers.IsContiguousAndDense(slicedDestination)) + { + inputSpan = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape.FlattenedLength); + destinationSpan = MemoryMarshal.CreateSpan(ref slicedDestination._reference, (int)slicedDestination._shape.FlattenedLength); + performCalculation(inputSpan, destinationSpan); + } + else + { + scoped Span curIndex; + nint[]? curIndexArray; + if (input.Lengths.Length > TensorShape.MaxInlineRank) + { + curIndexArray = ArrayPool.Shared.Rent(input.Lengths.Length); + curIndex = curIndexArray; + curIndex = curIndex.Slice(0, input.Rank); + } + else + { + curIndexArray = null; + curIndex = stackalloc nint[input.Lengths.Length]; + } + + int copiedValues = 0; + nint rowLength = input.Lengths[^1]; + + while (copiedValues < slicedDestination.FlattenedLength) + { + inputSpan = MemoryMarshal.CreateReadOnlySpan(in input[curIndex], (int)rowLength); + destinationSpan = MemoryMarshal.CreateSpan(ref slicedDestination[curIndex], (int)rowLength); + performCalculation(inputSpan, destinationSpan); + copiedValues += (int)rowLength; + TensorSpanHelpers.AdjustIndexes(input.Rank - 2, 1, curIndex, input.Lengths); + } + + if (curIndexArray != null) + ArrayPool.Shared.Return(curIndexArray); + } + return ref destination; } private static ref readonly TensorSpan TensorPrimitivesHelperSpanInTInSpanOut(scoped in ReadOnlyTensorSpan input, T value, in TensorSpan destination, PerformCalculationSpanInTInSpanOut performCalculation) { - if (destination._shape._memoryLength < input._shape._memoryLength) + if (destination._shape._memoryLength < input._shape.FlattenedLength) ThrowHelper.ThrowArgumentException_DestinationTooShort(); - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); - performCalculation(span, value, ospan); + // Make sure destination shape works with input shape + TensorSpan slicedDestination = destination.Slice(input._shape.Lengths); + + ReadOnlySpan inputSpan; + Span destinationSpan; + + if (TensorHelpers.IsContiguousAndDense(input) && TensorHelpers.IsContiguousAndDense(slicedDestination)) + { + inputSpan = MemoryMarshal.CreateSpan(ref input._reference, (int)input.FlattenedLength); + destinationSpan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination.FlattenedLength); + performCalculation(inputSpan, value, destinationSpan); + } + else + { + scoped Span curIndex; + nint[]? curIndexArray; + if (input.Lengths.Length > TensorShape.MaxInlineRank) + { + curIndexArray = ArrayPool.Shared.Rent(input.Lengths.Length); + curIndex = curIndexArray; + curIndex = curIndex.Slice(0, input.Rank); + } + else + { + curIndexArray = null; + curIndex = stackalloc nint[input.Lengths.Length]; + } + + int copiedValues = 0; + nint rowLength = input.Lengths[^1]; + + while (copiedValues < slicedDestination.FlattenedLength) + { + inputSpan = MemoryMarshal.CreateReadOnlySpan(in input[curIndex], (int)rowLength); + destinationSpan = MemoryMarshal.CreateSpan(ref slicedDestination[curIndex], (int)rowLength); + performCalculation(inputSpan, value, destinationSpan); + copiedValues += (int)rowLength; + TensorSpanHelpers.AdjustIndexes(input.Rank - 2, 1, curIndex, input.Lengths); + } + + if (curIndexArray != null) + ArrayPool.Shared.Return(curIndexArray); + } + return ref destination; } private static ref readonly TensorSpan TensorPrimitivesHelperTInSpanInSpanOut(T value, scoped in ReadOnlyTensorSpan input, in TensorSpan destination, PerformCalculationTInSpanInSpanOut performCalculation) { - if (destination._shape._memoryLength < input._shape._memoryLength) + if (destination._shape._memoryLength < input._shape.FlattenedLength) ThrowHelper.ThrowArgumentException_DestinationTooShort(); - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); - performCalculation(value, span, ospan); - return ref destination; - } + // Make sure destination shape works with input shape + TensorSpan slicedDestination = destination.Slice(input._shape.Lengths); + + ReadOnlySpan inputSpan; + Span destinationSpan; + + if (TensorHelpers.IsContiguousAndDense(input) && TensorHelpers.IsContiguousAndDense(slicedDestination)) + { + inputSpan = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); + destinationSpan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); + performCalculation(value, inputSpan, destinationSpan); + } + else + { + scoped Span curIndex; + nint[]? curIndexArray; + if (input.Lengths.Length > TensorShape.MaxInlineRank) + { + curIndexArray = ArrayPool.Shared.Rent(input.Lengths.Length); + curIndex = curIndexArray; + curIndex = curIndex.Slice(0, input.Rank); + } + else + { + curIndexArray = null; + curIndex = stackalloc nint[input.Lengths.Length]; + } + + int copiedValues = 0; + nint rowLength = input.Lengths[^1]; + + while (copiedValues < slicedDestination.FlattenedLength) + { + inputSpan = MemoryMarshal.CreateReadOnlySpan(in input[curIndex], (int)rowLength); + destinationSpan = MemoryMarshal.CreateSpan(ref slicedDestination[curIndex], (int)rowLength); + performCalculation(value, inputSpan, destinationSpan); + copiedValues += (int)rowLength; + TensorSpanHelpers.AdjustIndexes(input.Rank - 2, 1, curIndex, input.Lengths); + } + + if (curIndexArray != null) + ArrayPool.Shared.Return(curIndexArray); + } - private static ref readonly TensorSpan TensorPrimitivesHelperTFromSpanInTToSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination, PerformCalculationTFromSpanInTToSpanOut performCalculation) - where TFrom : IEquatable, IEqualityOperators, INumberBase - where TTo : INumberBase - { - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref input._reference, (int)input._shape._memoryLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); - performCalculation(span, ospan); return ref destination; } private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut(scoped in ReadOnlyTensorSpan left, scoped in ReadOnlyTensorSpan right, in TensorSpan destination, PerformCalculationTwoSpanInSpanOut performCalculation) { - // If sizes are the same. - if (TensorHelpers.AreLengthsTheSame(left, right) && TensorHelpers.IsUnderlyingStorageSameSize(left, right)) + nint[] newSize = Tensor.GetSmallestBroadcastableLengths(left.Lengths, right.Lengths); + + TensorSpan slicedDestination = destination.Slice(newSize); + + // If sizes are the same and memory is contiguous for all tensors + if (TensorHelpers.AreLengthsTheSame(left, right) && TensorHelpers.IsUnderlyingStorageSameSize(left, right) && TensorHelpers.IsContiguousAndDense(left) + && TensorHelpers.IsContiguousAndDense(right) && TensorHelpers.IsContiguousAndDense(slicedDestination)) { - if (!TensorHelpers.IsUnderlyingStorageSameSize(left, destination)) - ThrowHelper.ThrowArgument_DestinationTooShort(); - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, (int)left._shape._memoryLength); - ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, (int)right._shape._memoryLength); - Span ospan = MemoryMarshal.CreateSpan(ref destination._reference, (int)destination._shape._memoryLength); + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref left._reference, left._shape._memoryLength <= left.FlattenedLength ? (int)left._shape._memoryLength : (int)left.FlattenedLength); + ReadOnlySpan rspan = MemoryMarshal.CreateSpan(ref right._reference, right._shape._memoryLength <= right.FlattenedLength ? (int)right._shape._memoryLength : (int)right.FlattenedLength); + Span ospan = MemoryMarshal.CreateSpan(ref slicedDestination._reference, (int)slicedDestination._shape._memoryLength); performCalculation(span, rspan, ospan); return ref destination; } @@ -6741,12 +6901,8 @@ private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut // 1 - Both tensors have row contiguous memory (i.e. a 1x5 being broadcast to a 5x5) // 2 - One tensor has row contiguous memory and the right has column contiguous memory (i.e. a 1x5 and a 5x1) - nint[] newSize = Tensor.GetSmallestBroadcastableLengths(left.Lengths, right.Lengths); - ReadOnlyTensorSpan broadcastedLeft = Tensor.LazyBroadcast(left, newSize); ReadOnlyTensorSpan broadcastedRight = Tensor.LazyBroadcast(right, newSize); - if (!destination.Lengths.SequenceEqual(newSize) || destination._shape._memoryLength < broadcastedLeft.FlattenedLength) - ThrowHelper.ThrowArgument_LengthsNotBroadcastCompatible(); nint rowLength = newSize[^1]; Span ospan; @@ -6755,10 +6911,11 @@ private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut scoped Span curIndex; nint[]? curIndexArray; - if (newSize.Length > 6) + if (newSize.Length > TensorShape.MaxInlineRank) { curIndexArray = ArrayPool.Shared.Rent(newSize.Length); curIndex = curIndexArray; + curIndex = curIndex.Slice(0, newSize.Length); } else { @@ -6767,13 +6924,27 @@ private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut } int outputOffset = 0; - // ADD IN CASE WHERE NEITHER ARE ROW CONTIGUOUS + // neither row contiguous + if (broadcastedLeft.Strides[^1] == 0 && broadcastedRight.Strides[^1] == 0) + { + Span buffer2 = new T[rowLength]; + + while (outputOffset < slicedDestination.FlattenedLength) + { + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref slicedDestination._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, slicedDestination.Strides, slicedDestination.Lengths)), (int)rowLength); + buffer.Fill(broadcastedLeft[curIndex]); + buffer2.Fill(broadcastedRight[curIndex]); + performCalculation(buffer, buffer2, ospan); + outputOffset += (int)rowLength; + TensorSpanHelpers.AdjustIndexes(broadcastedLeft.Rank - 2, 1, curIndex, broadcastedLeft.Lengths); + } + } // tensor not row contiguous - if (broadcastedLeft.Strides[^1] == 0) + else if (broadcastedLeft.Strides[^1] == 0) { - while (outputOffset < destination.FlattenedLength) + while (outputOffset < slicedDestination.FlattenedLength) { - ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref destination._reference, outputOffset), (int)rowLength); + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref slicedDestination._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, slicedDestination.Strides, slicedDestination.Lengths)), (int)rowLength); buffer.Fill(broadcastedLeft[curIndex]); ispan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref broadcastedRight._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, broadcastedRight.Strides, broadcastedRight.Lengths)), (int)rowLength); performCalculation(buffer, ispan, ospan); @@ -6784,9 +6955,9 @@ private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut // right not row contiguous else if (broadcastedRight.Strides[^1] == 0) { - while (outputOffset < destination.FlattenedLength) + while (outputOffset < slicedDestination.FlattenedLength) { - ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref destination._reference, outputOffset), (int)rowLength); + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref slicedDestination._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, slicedDestination.Strides, slicedDestination.Lengths)), (int)rowLength); buffer.Fill(broadcastedRight[curIndex]); ispan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref broadcastedLeft._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, broadcastedLeft.Strides, broadcastedLeft.Lengths)), (int)rowLength); performCalculation(ispan, buffer, ospan); @@ -6798,9 +6969,9 @@ private static ref readonly TensorSpan TensorPrimitivesHelperTwoSpanInSpanOut else { Span rspan; - while (outputOffset < destination.FlattenedLength) + while (outputOffset < slicedDestination.FlattenedLength) { - ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref destination._reference, outputOffset), (int)rowLength); + ospan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref slicedDestination._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, slicedDestination.Strides, slicedDestination.Lengths)), (int)rowLength); ispan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref broadcastedLeft._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, broadcastedLeft.Strides, broadcastedLeft.Lengths)), (int)rowLength); rspan = MemoryMarshal.CreateSpan(ref Unsafe.Add(ref broadcastedRight._reference, TensorSpanHelpers.ComputeLinearIndex(curIndex, broadcastedRight.Strides, broadcastedRight.Lengths)), (int)rowLength); performCalculation(ispan, rspan, ospan); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs index 163075ce5c13dc..c169c082af48ae 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorHelpers.cs @@ -97,6 +97,21 @@ internal static bool AreLengthsTheSame(scoped in ReadOnlyTensorSpan tensor internal static bool AreLengthsTheSame(ReadOnlySpan lengths1, ReadOnlySpan lengths2) => lengths1.SequenceEqual(lengths2); + internal static bool IsContiguousAndDense(scoped in ReadOnlyTensorSpan tensor) + { + // Right most dimension must be 1 for a dense tensor. + if (tensor._shape.Strides[^1] != 1) + return false; + + // For other dimensions, the stride must be equal to the product of the dimensions to the right. + for (int i = tensor._shape._rank - 2; i >= 0; i--) + { + if (tensor._shape.Strides[i] != TensorPrimitives.Product(tensor.Lengths.Slice(i + 1, tensor.Lengths.Length - i - 1))) + return false; + } + return true; + } + internal static void PermuteIndices(Span indices, Span permutedIndices, ReadOnlySpan permutation) { for (int i = 0; i < indices.Length; i++) diff --git a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs index d55745614646fe..4ea2988e6f47af 100644 --- a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs +++ b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs @@ -14,7 +14,9 @@ public static class Helpers public static IEnumerable TensorLengthsIncluding0 => Enumerable.Range(0, 257); public static IEnumerable TensorLengths => Enumerable.Range(1, 256); - public static IEnumerable TensorShapes => [[1], [2], [10], [1,1], [1,2], [2,2], [5, 5], [2, 2, 2], [5, 5, 5], [3, 3, 3, 3], [4, 4, 4, 4, 4], [1, 2, 3, 4, 5, 6, 7]]; + public static IEnumerable TensorShapes => [[1], [2], [10], [1, 1], [1, 2], [2, 2], [5, 5], [2, 2, 2], [5, 5, 5], [3, 3, 3, 3], [4, 4, 4, 4, 4], [1, 2, 3, 4, 5, 6, 7, 1, 2]]; + public static nint[][] TensorSliceShapes => [[1], [1], [5], [1, 1], [1, 1], [1, 2], [3, 3], [2, 2, 1], [5, 3, 5], [3, 2, 1, 3], [4, 3, 2, 1, 2], [1, 2, 2, 2, 2, 1, 1, 1, 1]]; + public static nint[][] TensorSliceShapesForBroadcast => [[1], [1], [1], [1, 1], [1, 1], [1, 2], [1, 1], [2, 2, 1], [1, 5, 5], [3, 1, 1, 3], [4, 1, 4, 1, 4], [1, 2, 1, 4, 1, 1, 7, 1, 1]]; // Tolerances taken from testing in the scalar math routines: // cf. https://github.com/dotnet/runtime/blob/89f7ad3b276fb0b48f20cb4e8408bdce85c2b415/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Math.cs diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs index e43a7d55e7dfae..f03f682f94f704 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorSpanTests.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.InteropServices; using Xunit; @@ -34,93 +35,179 @@ private static nint CalculateTotalLength(ReadOnlySpan lengths) return totalLength; } - public delegate void TensorPrimitivesSpanInSpanOut(ReadOnlySpan input, Span output); - public delegate ref readonly TensorSpan TensorSpanInSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination); - public delegate ref readonly TensorSpan TensorSpanInSpanOutInPlace(in TensorSpan input); + public delegate void TensorPrimitivesSpanInSpanOut(ReadOnlySpan input, Span output); + public delegate ref readonly TensorSpan TensorSpanInSpanOut(scoped in ReadOnlyTensorSpan input, in TensorSpan destination); public static IEnumerable SpanInSpanOutData() { - yield return Create(TensorPrimitives.Abs, Tensor.Abs); - yield return Create(TensorPrimitives.Acos, Tensor.Acos); - yield return Create(TensorPrimitives.Acosh, Tensor.Acosh); - yield return Create(TensorPrimitives.AcosPi, Tensor.AcosPi); - yield return Create(TensorPrimitives.Asin, Tensor.Asin); - yield return Create(TensorPrimitives.Asinh, Tensor.Asinh); - yield return Create(TensorPrimitives.AsinPi, Tensor.AsinPi); - yield return Create(TensorPrimitives.Atan, Tensor.Atan); - yield return Create(TensorPrimitives.Atanh, Tensor.Atanh); - yield return Create(TensorPrimitives.AtanPi, Tensor.AtanPi); - yield return Create(TensorPrimitives.Cbrt, Tensor.Cbrt); - yield return Create(TensorPrimitives.Ceiling, Tensor.Ceiling); - yield return Create(TensorPrimitives.Cos, Tensor.Cos); - yield return Create(TensorPrimitives.Cosh, Tensor.Cosh); - yield return Create(TensorPrimitives.CosPi, Tensor.CosPi); - yield return Create(TensorPrimitives.DegreesToRadians, Tensor.DegreesToRadians); - yield return Create(TensorPrimitives.Exp, Tensor.Exp); - yield return Create(TensorPrimitives.Exp10, Tensor.Exp10); - yield return Create(TensorPrimitives.Exp10M1, Tensor.Exp10M1); - yield return Create(TensorPrimitives.Exp2, Tensor.Exp2); - yield return Create(TensorPrimitives.Exp2M1, Tensor.Exp2M1); - yield return Create(TensorPrimitives.ExpM1, Tensor.ExpM1); - yield return Create(TensorPrimitives.Floor, Tensor.Floor); - yield return Create(TensorPrimitives.LeadingZeroCount, Tensor.LeadingZeroCount); - yield return Create(TensorPrimitives.Log, Tensor.Log); - yield return Create(TensorPrimitives.Log10, Tensor.Log10); - yield return Create(TensorPrimitives.Log10P1, Tensor.Log10P1); - yield return Create(TensorPrimitives.Log2, Tensor.Log2); - yield return Create(TensorPrimitives.Log2P1, Tensor.Log2P1); - yield return Create(TensorPrimitives.LogP1, Tensor.LogP1); - yield return Create(TensorPrimitives.Negate, Tensor.Negate); - yield return Create(TensorPrimitives.OnesComplement, Tensor.OnesComplement); - yield return Create(TensorPrimitives.PopCount, Tensor.PopCount); - yield return Create(TensorPrimitives.RadiansToDegrees, Tensor.RadiansToDegrees); - yield return Create(TensorPrimitives.Reciprocal, Tensor.Reciprocal); - yield return Create(TensorPrimitives.Round, Tensor.Round); - yield return Create(TensorPrimitives.Sigmoid, Tensor.Sigmoid); - yield return Create(TensorPrimitives.Sin, Tensor.Sin); - yield return Create(TensorPrimitives.Sinh, Tensor.Sinh); - yield return Create(TensorPrimitives.SinPi, Tensor.SinPi); - yield return Create(TensorPrimitives.SoftMax, Tensor.SoftMax); - yield return Create(TensorPrimitives.Sqrt, Tensor.Sqrt); - yield return Create(TensorPrimitives.Tan, Tensor.Tan); - yield return Create(TensorPrimitives.Tanh, Tensor.Tanh); - yield return Create(TensorPrimitives.TanPi, Tensor.TanPi); - yield return Create(TensorPrimitives.Truncate, Tensor.Truncate); - - static object[] Create(TensorPrimitivesSpanInSpanOut tensorPrimitivesMethod, TensorSpanInSpanOut tensorOperation) + yield return Create(TensorPrimitives.Abs, Tensor.Abs); + yield return Create(TensorPrimitives.Acos, Tensor.Acos); + yield return Create(TensorPrimitives.Acosh, Tensor.Acosh); + yield return Create(TensorPrimitives.AcosPi, Tensor.AcosPi); + yield return Create(TensorPrimitives.Asin, Tensor.Asin); + yield return Create(TensorPrimitives.Asinh, Tensor.Asinh); + yield return Create(TensorPrimitives.AsinPi, Tensor.AsinPi); + yield return Create(TensorPrimitives.Atan, Tensor.Atan); + yield return Create(TensorPrimitives.Atanh, Tensor.Atanh); + yield return Create(TensorPrimitives.AtanPi, Tensor.AtanPi); + yield return Create(TensorPrimitives.Cbrt, Tensor.Cbrt); + yield return Create(TensorPrimitives.Ceiling, Tensor.Ceiling); + yield return Create(TensorPrimitives.Cos, Tensor.Cos); + yield return Create(TensorPrimitives.Cosh, Tensor.Cosh); + yield return Create(TensorPrimitives.CosPi, Tensor.CosPi); + yield return Create(TensorPrimitives.DegreesToRadians, Tensor.DegreesToRadians); + yield return Create(TensorPrimitives.Exp, Tensor.Exp); + yield return Create(TensorPrimitives.Exp10, Tensor.Exp10); + yield return Create(TensorPrimitives.Exp10M1, Tensor.Exp10M1); + yield return Create(TensorPrimitives.Exp2, Tensor.Exp2); + yield return Create(TensorPrimitives.Exp2M1, Tensor.Exp2M1); + yield return Create(TensorPrimitives.ExpM1, Tensor.ExpM1); + yield return Create(TensorPrimitives.Floor, Tensor.Floor); + yield return Create(TensorPrimitives.LeadingZeroCount, Tensor.LeadingZeroCount); + yield return Create(TensorPrimitives.Log, Tensor.Log); + yield return Create(TensorPrimitives.Log10, Tensor.Log10); + yield return Create(TensorPrimitives.Log10P1, Tensor.Log10P1); + yield return Create(TensorPrimitives.Log2, Tensor.Log2); + yield return Create(TensorPrimitives.Log2P1, Tensor.Log2P1); + yield return Create(TensorPrimitives.LogP1, Tensor.LogP1); + yield return Create(TensorPrimitives.Negate, Tensor.Negate); + yield return Create(TensorPrimitives.OnesComplement, Tensor.OnesComplement); + yield return Create(TensorPrimitives.PopCount, Tensor.PopCount); + yield return Create(TensorPrimitives.RadiansToDegrees, Tensor.RadiansToDegrees); + yield return Create(TensorPrimitives.Reciprocal, Tensor.Reciprocal); + yield return Create(TensorPrimitives.Round, Tensor.Round); + yield return Create(TensorPrimitives.Sigmoid, Tensor.Sigmoid); + yield return Create(TensorPrimitives.Sin, Tensor.Sin); + yield return Create(TensorPrimitives.Sinh, Tensor.Sinh); + yield return Create(TensorPrimitives.SinPi, Tensor.SinPi); + yield return Create(TensorPrimitives.SoftMax, Tensor.SoftMax); + yield return Create(TensorPrimitives.Sqrt, Tensor.Sqrt); + yield return Create(TensorPrimitives.Tan, Tensor.Tan); + yield return Create(TensorPrimitives.Tanh, Tensor.Tanh); + yield return Create(TensorPrimitives.TanPi, Tensor.TanPi); + yield return Create(TensorPrimitives.Truncate, Tensor.Truncate); + yield return Create(TensorPrimitives.ILogB, Tensor.ILogB); + yield return Create(TensorPrimitives.ConvertChecked, Tensor.ConvertChecked); + yield return Create(TensorPrimitives.ConvertSaturating, Tensor.ConvertSaturating); + yield return Create(TensorPrimitives.ConvertTruncating, Tensor.ConvertTruncating); + + static object[] Create(TensorPrimitivesSpanInSpanOut tensorPrimitivesMethod, TensorSpanInSpanOut tensorOperation) => new object[] { tensorPrimitivesMethod, tensorOperation }; } [Theory, MemberData(nameof(SpanInSpanOutData))] - public void TensorExtensionsSpanInSpanOut(TensorPrimitivesSpanInSpanOut tensorPrimitivesOperation, TensorSpanInSpanOut tensorOperation) - where T : INumberBase + public void TensorExtensionsSpanInSpanOut(TensorPrimitivesSpanInSpanOut tensorPrimitivesOperation, TensorSpanInSpanOut tensorOperation) + where TIn : INumberBase + where TOut: INumber { - Assert.All(Helpers.TensorShapes, tensorLength => + Assert.All(Helpers.TensorShapes, (tensorLength, index) => { nint length = CalculateTotalLength(tensorLength); - T[] data = new T[length]; - T[] data2 = new T[length]; - T[] expectedOutput = new T[length]; - FillTensor(data); - TensorSpan x = Tensor.Create(data, tensorLength, []); - TensorSpan destination = Tensor.Create(data2, tensorLength, []); - tensorPrimitivesOperation((ReadOnlySpan)data, expectedOutput); - TensorSpan results = tensorOperation(x, destination); + TIn[] data = new TIn[length]; + TOut[] data2 = new TOut[length]; + TOut[] expectedOutput = new TOut[length]; + + FillTensor(data); + TensorSpan x = Tensor.Create(data, tensorLength, []); + TensorSpan destination = Tensor.Create(data2, tensorLength, []); + tensorPrimitivesOperation((ReadOnlySpan)data, expectedOutput); + TensorSpan tensorResults = tensorOperation(x, destination); - Assert.Equal(tensorLength, results.Lengths); + Assert.Equal(tensorLength, tensorResults.Lengths); nint[] startingIndex = new nint[tensorLength.Length]; // the "Return" value - ReadOnlySpan span = MemoryMarshal.CreateSpan(ref results[startingIndex], (int)length); + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref tensorResults[startingIndex], (int)length); // the "destination" value - ReadOnlySpan destSpan = MemoryMarshal.CreateSpan(ref destination[startingIndex], (int)length); + ReadOnlySpan destSpan = MemoryMarshal.CreateSpan(ref destination[startingIndex], (int)length); for (int i = 0; i < data.Length; i++) { Assert.Equal(expectedOutput[i], span[i]); Assert.Equal(expectedOutput[i], destSpan[i]); } + + // Now test if the source is sliced to be smaller then the destination that the destination is also sliced + // to the correct size. + NRange[] sliceLengths = Helpers.TensorSliceShapes[index].Select(i => new NRange(0, i)).ToArray(); + nint sliceFlattenedLength = CalculateTotalLength(Helpers.TensorSliceShapes[index]); + x = x.Slice(sliceLengths); + TIn[] sliceData = new TIn[sliceFlattenedLength]; + x.FlattenTo(sliceData); + expectedOutput = new TOut[sliceFlattenedLength]; + + if (TensorHelpers.IsContiguousAndDense(x)) + { + tensorPrimitivesOperation((ReadOnlySpan)sliceData, expectedOutput); + } + else + { + int rowLength = (int)Helpers.TensorSliceShapes[index][^1]; + for (int i = 0; i < sliceData.Length; i+= rowLength) + { + tensorPrimitivesOperation(((ReadOnlySpan)sliceData).Slice(i, rowLength), ((Span)expectedOutput).Slice(i, rowLength)); + } + + } + + tensorResults = tensorOperation(x, destination); + + // tensorResults lengths will still be the original tensorLength and not equal to the sliced length since that happened internally/automatically + Assert.Equal(tensorLength, tensorResults.Lengths); + + TensorSpan.Enumerator destEnum = destination.Slice(sliceLengths).GetEnumerator(); + TensorSpan.Enumerator tensorResultsEnum = tensorResults.Slice(sliceLengths).GetEnumerator(); + bool destEnumMove; + bool tensorResultsEnumMove; + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } + + // Now test if the source and destination are sliced (so neither is continuous) it works correctly. + destination = destination.Slice(sliceLengths); + x.FlattenTo(sliceData); + expectedOutput = new TOut[sliceFlattenedLength]; + + if (TensorHelpers.IsContiguousAndDense(x)) + { + tensorPrimitivesOperation((ReadOnlySpan)sliceData, expectedOutput); + } + else + { + int rowLength = (int)Helpers.TensorSliceShapes[index][^1]; + for (int i = 0; i < sliceData.Length; i += rowLength) + { + tensorPrimitivesOperation(((ReadOnlySpan)sliceData).Slice(i, rowLength), ((Span)expectedOutput).Slice(i, rowLength)); + } + + } + + tensorResults = tensorOperation(x, destination); + + Assert.Equal(Helpers.TensorSliceShapes[index], tensorResults.Lengths); + + destEnum = destination.GetEnumerator(); + tensorResultsEnum = tensorResults.GetEnumerator(); + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } }); } @@ -146,7 +233,7 @@ static object[] Create(TensorPrimitivesSpanInTOut tensorPrimitivesMethod, public void TensorExtensionsSpanInTOut(TensorPrimitivesSpanInTOut tensorPrimitivesOperation, TensorSpanInTOut tensorOperation) where T : INumberBase { - Assert.All(Helpers.TensorShapes, tensorLength => + Assert.All(Helpers.TensorShapes, (tensorLength, index) => { nint length = CalculateTotalLength(tensorLength); T[] data = new T[length]; @@ -157,6 +244,40 @@ public void TensorExtensionsSpanInTOut(TensorPrimitivesSpanInTOut tensorPr T results = tensorOperation(x); Assert.Equal(expectedOutput, results); + + float[] testData = [49.788437f, 32.736755f, -0.25761032f, -46.402596f, 4.5581512f, 21.813591f, 44.976646f, 12.691814f, -44.188023f, 40.35988f, -6.999405f, 4.713642f, 5.274975f, 21.312515f, -12.536407f, -34.888573f, -1.90839f, 28.734451f, -38.64155f, -28.840702f, 7.373543f, 18.600182f, 26.007828f, 0.71430206f, -6.8293495f, -13.327972f, -25.149017f, 9.331852f, 40.87751f, 28.321632f, 42.918175f, 25.213333f, -41.392017f, 36.727768f, 26.49012f, 3.8807983f, 24.933182f, -43.050568f, -42.6283f, 18.01947f, -47.62874f, -49.94487f, -1.036602f, -37.086433f, 32.77098f, -12.903477f, -45.100212f, -20.596504f, 33.67714f, 46.864395f, 44.437485f, -44.092155f, 37.122124f, 25.220505f, 41.994873f, -13.3394165f, -28.193134f, -21.329712f, -36.623306f, 3.3981133f, -26.475079f, 16.339478f, -44.07065f, 36.321762f, -24.63433f, 28.652397f, 4.096817f, 33.29615f, -2.3503838f, -7.509815f, 42.943604f, -32.52115f, -0.20326233f, 29.554626f, 18.044052f]; + nint[] testLengths = [5, 3, 5]; + Tensor testTensor = Tensor.Create(testData, testLengths, []); + float[] testSliceData = new float[75]; + testTensor.FlattenTo(testSliceData); + float testExpectedOutput = TensorPrimitives.Sum((ReadOnlySpan)testSliceData); + float testResults = Tensor.Sum(testTensor); + + + // Now test if the source is sliced to be non contiguous that it still gives expected result. + NRange[] sliceLengths = Helpers.TensorSliceShapes[index].Select(i => new NRange(0, i)).ToArray(); + nint sliceFlattenedLength = CalculateTotalLength(Helpers.TensorSliceShapes[index]); + x = x.Slice(sliceLengths); + T[] sliceData = new T[sliceFlattenedLength]; + x.FlattenTo(sliceData); + + IEnumerator enumerator = x.GetEnumerator(); + bool cont = enumerator.MoveNext(); + ReadOnlySpan span = MemoryMarshal.CreateSpan(ref x.AsReadOnlyTensorSpan()._reference, (int)x.FlattenedLength); + int i = 0; + Assert.True(span.SequenceEqual(sliceData)); + while (cont) + { + Assert.Equal(sliceData[i], enumerator.Current); + Assert.Equal(span[i], enumerator.Current); + Assert.Equal(span[i], sliceData[i++]); + cont = enumerator.MoveNext(); + } + + expectedOutput = tensorPrimitivesOperation((ReadOnlySpan)sliceData); + results = tensorOperation(x); + + Assert.Equal(expectedOutput, results); }); } @@ -184,9 +305,9 @@ static object[] Create(TensorPrimitivesTwoSpanInSpanOut tensorPrimitivesMe public void TensorExtensionsTwoSpanInSpanOut(TensorPrimitivesTwoSpanInSpanOut tensorPrimitivesOperation, TensorTwoSpanInSpanOut tensorOperation) where T : INumberBase { - Assert.All(Helpers.TensorShapes, tensorLength => + Assert.All(Helpers.TensorShapes, (tensorLengths, index) => { - nint length = CalculateTotalLength(tensorLength); + nint length = CalculateTotalLength(tensorLengths); T[] data1 = new T[length]; T[] data2 = new T[length]; T[] destData = new T[length]; @@ -194,14 +315,17 @@ public void TensorExtensionsTwoSpanInSpanOut(TensorPrimitivesTwoSpanInSpanOut FillTensor(data1); FillTensor(data2); - TensorSpan x = Tensor.Create(data1, tensorLength, []); - TensorSpan y = Tensor.Create(data2, tensorLength, []); - TensorSpan destination = Tensor.Create(destData, tensorLength, []); + + + // First test when everything is exact sizes + TensorSpan x = Tensor.Create(data1, tensorLengths, []); + TensorSpan y = Tensor.Create(data2, tensorLengths, []); + TensorSpan destination = Tensor.Create(destData, tensorLengths, []); tensorPrimitivesOperation((ReadOnlySpan)data1, data2, expectedOutput); TensorSpan results = tensorOperation(x, y, destination); - Assert.Equal(tensorLength, results.Lengths); - nint[] startingIndex = new nint[tensorLength.Length]; + Assert.Equal(tensorLengths, results.Lengths); + nint[] startingIndex = new nint[tensorLengths.Length]; // the "Return" value ReadOnlySpan span = MemoryMarshal.CreateSpan(ref results[startingIndex], (int)length); // the "destination" value @@ -212,6 +336,148 @@ public void TensorExtensionsTwoSpanInSpanOut(TensorPrimitivesTwoSpanInSpanOut Assert.Equal(expectedOutput[i], span[i]); Assert.Equal(expectedOutput[i], destSpan[i]); } + + // Now test when both sources are exact sizes but destination is too large and gets sliced internally. + nint[] tempLengths = tensorLengths.Select(i => i + 1).ToArray(); + T[] tempDestData = new T[CalculateTotalLength(tempLengths)]; + destination = Tensor.Create(tempDestData, tempLengths, []); + results = tensorOperation(x, y, destination); + + // Since the slice was internal the result lengths will be the extra large size. + Assert.Equal(tempLengths, results.Lengths); + startingIndex = new nint[tensorLengths.Length]; + + TensorSpan.Enumerator destEnum = destination.Slice(tensorLengths).GetEnumerator(); + TensorSpan.Enumerator tensorResultsEnum = results.Slice(tensorLengths).GetEnumerator(); + bool destEnumMove; + bool tensorResultsEnumMove; + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } + + // Now test if the first source is sliced to be smaller than the second (but is broadcast compatible) that broadcasting happens). + int rowLength = (int)Helpers.TensorSliceShapesForBroadcast[index][^1]; + + NRange[] sliceLengths = Helpers.TensorSliceShapesForBroadcast[index].Select(i => new NRange(0, i)).ToArray(); + nint sliceFlattenedLength = CalculateTotalLength(Helpers.TensorSliceShapesForBroadcast[index]); + destination = destination.Slice(tensorLengths); + x.Slice(sliceLengths).BroadcastTo(x); + x.FlattenTo(data1); + + if (TensorHelpers.IsContiguousAndDense(x.Slice(sliceLengths)) && TensorHelpers.IsContiguousAndDense(y)) + { + tensorPrimitivesOperation((ReadOnlySpan)data1, data2, expectedOutput); + } + else + { + for (int i = 0; i < data1.Length; i += rowLength) + { + tensorPrimitivesOperation(((ReadOnlySpan)data1).Slice(i, rowLength), ((ReadOnlySpan)data2).Slice(i, rowLength), ((Span)expectedOutput).Slice(i, rowLength)); + } + + } + + results = tensorOperation(x.Slice(sliceLengths), y, destination); + + // results lengths will still be the original tensorLength + Assert.Equal(tensorLengths, results.Lengths); + + destEnum = destination.GetEnumerator(); + tensorResultsEnum = results.GetEnumerator(); + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } + + // Now test if the second source is sliced to be smaller than the first (but is broadcast compatible) that broadcasting happens). + y.Slice(sliceLengths).BroadcastTo(y); + y.FlattenTo(data2); + + if (TensorHelpers.IsContiguousAndDense(x) && TensorHelpers.IsContiguousAndDense(y.Slice(sliceLengths))) + { + tensorPrimitivesOperation((ReadOnlySpan)data1, data2, expectedOutput); + } + else + { + for (int i = 0; i < data2.Length; i += rowLength) + { + tensorPrimitivesOperation(((ReadOnlySpan)data1).Slice(i, rowLength), ((ReadOnlySpan)data2).Slice(i, rowLength), ((Span)expectedOutput).Slice(i, rowLength)); + } + + } + + results = tensorOperation(x, y.Slice(sliceLengths), destination); + + // results lengths will still be the original tensorLength + Assert.Equal(tensorLengths, results.Lengths); + + destEnum = destination.GetEnumerator(); + tensorResultsEnum = results.GetEnumerator(); + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } + + // Now test if both sources are sliced to be smaller than the destination that the destination will be sliced automatically + T[] sliceData1 = new T[sliceFlattenedLength]; + T[] sliceData2 = new T[sliceFlattenedLength]; + expectedOutput = new T[sliceFlattenedLength]; + + x.Slice(sliceLengths).FlattenTo(sliceData1); + y.Slice(sliceLengths).FlattenTo(sliceData2); + + if (TensorHelpers.IsContiguousAndDense(x.Slice(sliceLengths)) && TensorHelpers.IsContiguousAndDense(y.Slice(sliceLengths))) + { + tensorPrimitivesOperation((ReadOnlySpan)sliceData1, sliceData2, expectedOutput); + } + else + { + for (int i = 0; i < sliceData1.Length; i += rowLength) + { + tensorPrimitivesOperation(((ReadOnlySpan)sliceData1).Slice(i, rowLength), ((ReadOnlySpan)sliceData2).Slice(i, rowLength), ((Span)expectedOutput).Slice(i, rowLength)); + } + + } + + results = tensorOperation(x.Slice(sliceLengths), y.Slice(sliceLengths), destination); + + Assert.Equal(tensorLengths, results.Lengths); + + destEnum = destination.Slice(sliceLengths).GetEnumerator(); + tensorResultsEnum = results.Slice(sliceLengths).GetEnumerator(); + + for (int i = 0; i < expectedOutput.Length; i++) + { + destEnumMove = destEnum.MoveNext(); + tensorResultsEnumMove = tensorResultsEnum.MoveNext(); + + Assert.True(destEnumMove); + Assert.True(tensorResultsEnumMove); + Assert.Equal(expectedOutput[i], destEnum.Current); + Assert.Equal(expectedOutput[i], tensorResultsEnum.Current); + } }); } @@ -220,7 +486,7 @@ public void TensorExtensionsTwoSpanInSpanOut(TensorPrimitivesTwoSpanInSpanOut public static IEnumerable TwoSpanInFloatOutData() { yield return Create(TensorPrimitives.Distance, Tensor.Distance); - yield return Create(TensorPrimitives.Dot, Tensor.Dot); + //yield return Create(TensorPrimitives.Dot, Tensor.Dot); static object[] Create(TensorPrimitivesTwoSpanInTOut tensorPrimitivesMethod, TensorTwoSpanInTOut tensorOperation) => new object[] { tensorPrimitivesMethod, tensorOperation }; @@ -230,11 +496,13 @@ static object[] Create(TensorPrimitivesTwoSpanInTOut tensorPrimitivesMetho public void TensorExtensionsTwoSpanInFloatOut(TensorPrimitivesTwoSpanInTOut tensorPrimitivesOperation, TensorTwoSpanInTOut tensorOperation) where T : INumberBase { - Assert.All(Helpers.TensorShapes, tensorLength => + Assert.All(Helpers.TensorShapes, (tensorLength, index) => { nint length = CalculateTotalLength(tensorLength); T[] data1 = new T[length]; T[] data2 = new T[length]; + T[] broadcastData1 = new T[length]; + T[] broadcastData2 = new T[length]; FillTensor(data1); FillTensor(data2); @@ -244,6 +512,43 @@ public void TensorExtensionsTwoSpanInFloatOut(TensorPrimitivesTwoSpanInTOut new NRange(0, i)).ToArray(); + TensorSpan broadcastX = Tensor.Create(broadcastData1, tensorLength, []); + x.Slice(sliceLengths).BroadcastTo(broadcastX); + TensorSpan.Enumerator enumerator = broadcastX.GetEnumerator(); + bool cont = enumerator.MoveNext(); + int i = 0; + while (cont) + { + Assert.Equal(broadcastData1[i++], enumerator.Current); + cont = enumerator.MoveNext(); + } + + expectedOutput = tensorPrimitivesOperation((ReadOnlySpan)broadcastData1, data2); + results = tensorOperation(x.Slice(sliceLengths), y); + + Assert.Equal(expectedOutput, results); + + // Now test if the second source is sliced to be non contiguous that it still gives expected result. + + TensorSpan broadcastY = Tensor.Create(broadcastData2, tensorLength, []); + y.Slice(sliceLengths).BroadcastTo(broadcastY); + + enumerator = broadcastY.GetEnumerator(); + cont = enumerator.MoveNext(); + i = 0; + while (cont) + { + Assert.Equal(broadcastData2[i++], enumerator.Current); + cont = enumerator.MoveNext(); + } + + expectedOutput = tensorPrimitivesOperation((ReadOnlySpan)data1, broadcastData2); + results = tensorOperation(x, y.Slice(sliceLengths)); + + Assert.Equal(expectedOutput, results); }); }