From 7acab7ac040622b1733ebda4c2faae75166fb36f Mon Sep 17 00:00:00 2001 From: Manodasan Wignarajah Date: Thu, 9 Sep 2021 11:35:15 -0700 Subject: [PATCH] Fix array marshalers to handle null (#983) --- .../UnitTest/TestComponentCSharp_Tests.cs | 21 ++++++++++++++++-- src/Tests/UnitTest/TestComponent_Tests.cs | 22 +++++++++++++++++++ src/WinRT.Runtime/Marshalers.cs | 20 +++++++++++++++-- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs index 65dd6010b..3319b3a9f 100644 --- a/src/Tests/UnitTest/TestComponentCSharp_Tests.cs +++ b/src/Tests/UnitTest/TestComponentCSharp_Tests.cs @@ -118,13 +118,19 @@ public void TestEnums() TestObject.EnumsProperty = expectedEnums; Assert.Equal(expectedEnums, TestObject.EnumsProperty); TestObject.CallForEnums(() => expectedEnums); - Assert.Equal(expectedEnums, TestObject.EnumsProperty); + Assert.Equal(expectedEnums, TestObject.EnumsProperty); + + TestObject.EnumsProperty = null; + Assert.Equal(null, TestObject.EnumsProperty); var expectedEnumStructs = new EnumStruct[] { new EnumStruct(EnumValue.One), new EnumStruct(EnumValue.Two) }; TestObject.EnumStructsProperty = expectedEnumStructs; Assert.Equal(expectedEnumStructs, TestObject.EnumStructsProperty); TestObject.CallForEnumStructs(() => expectedEnumStructs); - Assert.Equal(expectedEnumStructs, TestObject.EnumStructsProperty); + Assert.Equal(expectedEnumStructs, TestObject.EnumStructsProperty); + + TestObject.EnumStructsProperty = null; + Assert.Equal(null, TestObject.EnumStructsProperty); // Flags var expectedFlag = FlagValue.All; @@ -1167,6 +1173,17 @@ public void TestComposedNonBlittableStruct() Assert.True(val.bools.x); Assert.False(val.bools.y); Assert.True(val.bools.z); + } + + [Fact] + public void TestBlittableArrays() + { + int[] arr = new[] { 2, 4, 6, 8 }; + TestObject.SetInts(arr); + Assert.True(TestObject.GetInts().SequenceEqual(arr)); + + TestObject.SetInts(null); + Assert.Null(TestObject.GetInts()); } #if NETCOREAPP2_0 diff --git a/src/Tests/UnitTest/TestComponent_Tests.cs b/src/Tests/UnitTest/TestComponent_Tests.cs index fc29de985..aec109c45 100644 --- a/src/Tests/UnitTest/TestComponent_Tests.cs +++ b/src/Tests/UnitTest/TestComponent_Tests.cs @@ -376,6 +376,17 @@ public void Array_String() Assert.True(AllEqual(a, b, c, d)); } + [Fact] + public void Array_NullStringArray() + { + string[] a = null; + string[] b = null; + string[] c; + string[] d = Tests.Array12(a, b, out c); + Assert.Null(c); + Assert.Null(d); + } + [Fact] public void Array_Blittable() { @@ -437,6 +448,17 @@ public void Array_Stringable() Assert.True(AllEqual(a, b, c, d)); } + [Fact] + public void Array_NullInterfaces() + { + IStringable[] a = null; + IStringable[] b = null; + IStringable[] c; + IStringable[] d = Tests.Array16(a, b, out c); + Assert.Null(c); + Assert.Null(d); + } + private T[] Array_Call(T[] a, T[] b, out T[] c) { Assert.True(a.Length == b.Length); diff --git a/src/WinRT.Runtime/Marshalers.cs b/src/WinRT.Runtime/Marshalers.cs index 09f292393..7278ea81d 100644 --- a/src/WinRT.Runtime/Marshalers.cs +++ b/src/WinRT.Runtime/Marshalers.cs @@ -174,6 +174,10 @@ public static unsafe string[] FromAbiArray(object box) return null; } var abi = ((int length, IntPtr data))box; + if (abi.data == IntPtr.Zero) + { + return null; + } string[] array = new string[abi.length]; var data = (IntPtr*)abi.data.ToPointer(); for (int i = 0; i < abi.length; i++) @@ -277,7 +281,7 @@ public struct MarshalBlittable { public struct MarshalerArray { - public MarshalerArray(Array array) => _gchandle = GCHandle.Alloc(array, GCHandleType.Pinned); + public MarshalerArray(Array array) => _gchandle = array is null ? default : GCHandle.Alloc(array, GCHandleType.Pinned); public void Dispose() => _gchandle.Dispose(); public GCHandle _gchandle; @@ -288,7 +292,7 @@ public struct MarshalerArray public static (int length, IntPtr data) GetAbiArray(object box) { var m = (MarshalerArray)box; - return (((Array)m._gchandle.Target).Length, m._gchandle.AddrOfPinnedObject()); + return m._gchandle.IsAllocated ? (((Array)m._gchandle.Target).Length, m._gchandle.AddrOfPinnedObject()) : (0, IntPtr.Zero); } public static unsafe T[] FromAbiArray(object box) @@ -298,6 +302,10 @@ public static unsafe T[] FromAbiArray(object box) return null; } var abi = ((int length, IntPtr data))box; + if (abi.data == IntPtr.Zero) + { + return null; + } var abiSpan = new ReadOnlySpan(abi.data.ToPointer(), abi.length); return abiSpan.ToArray(); } @@ -583,6 +591,10 @@ public void Dispose() return null; } var abi = ((int length, IntPtr data))box; + if (abi.data == IntPtr.Zero) + { + return null; + } var array = new T[abi.length]; var data = (byte*)abi.data.ToPointer(); var abi_element_size = Marshal.SizeOf(AbiType); @@ -763,6 +775,10 @@ public static unsafe T[] FromAbiArray(object box, Func fromAbi) return null; } var abi = ((int length, IntPtr data))box; + if (abi.data == IntPtr.Zero) + { + return null; + } var array = new T[abi.length]; var data = (IntPtr*)abi.data.ToPointer(); for (int i = 0; i < abi.length; i++)