From a71cc52c405ed860431aaf8ff363a7c8622731ff Mon Sep 17 00:00:00 2001 From: Huo Yaoyuan Date: Tue, 27 Feb 2024 00:57:19 +0800 Subject: [PATCH] Fix ValueType.GetHashCode not calling overriden method on nested field (#98754) * Handle overridden value type method * Call virtual for NativeAot * Add test * Change to box * Fold GetHashCodeImpl * NativeAOT cleanup * Don't create duplicated HashCode * Update test to not use default equals of nested field * Return span from method instead --- .../src/System/ValueType.cs | 11 ++++- .../src/System/ValueType.cs | 45 ++++++------------- src/coreclr/vm/comutilnative.cpp | 19 +++++--- src/coreclr/vm/comutilnative.h | 2 +- .../System/ValueTypeTests.cs | 30 +++++++++++++ 5 files changed, 68 insertions(+), 39 deletions(-) diff --git a/src/coreclr/System.Private.CoreLib/src/System/ValueType.cs b/src/coreclr/System.Private.CoreLib/src/System/ValueType.cs index 78301866c36dce..f4c3acb31adf88 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/ValueType.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/ValueType.cs @@ -120,7 +120,7 @@ public override unsafe int GetHashCode() else { object thisRef = this; - switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize)) + switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT)) { case ValueTypeHashCodeStrategy.ReferenceField: hashCode.Add(Unsafe.As(ref Unsafe.AddByteOffset(ref rawData, fieldOffset)).GetHashCode()); @@ -138,6 +138,12 @@ public override unsafe int GetHashCode() Debug.Assert(fieldSize != 0); hashCode.AddBytes(MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AddByteOffset(ref rawData, fieldOffset), (int)fieldSize)); break; + + case ValueTypeHashCodeStrategy.ValueTypeOverride: + Debug.Assert(fieldMT != null); + // Box the field to handle complicated cases like mutable method and shared generic + hashCode.Add(RuntimeHelpers.Box(fieldMT, ref Unsafe.AddByteOffset(ref rawData, fieldOffset))?.GetHashCode() ?? 0); + break; } } @@ -152,11 +158,12 @@ private enum ValueTypeHashCodeStrategy DoubleField, SingleField, FastGetHashCode, + ValueTypeOverride, } [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "ValueType_GetHashCodeStrategy")] private static unsafe partial ValueTypeHashCodeStrategy GetHashCodeStrategy( - MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize); + MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT); public override string? ToString() { diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/ValueType.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/ValueType.cs index e8340e41191513..968e97c425cf81 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/ValueType.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/ValueType.cs @@ -95,43 +95,28 @@ public override unsafe bool Equals([NotNullWhen(true)] object? obj) public override unsafe int GetHashCode() { - int hashCode = (int)this.GetMethodTable()->HashCode; + HashCode hashCode = default; + hashCode.Add((IntPtr)this.GetMethodTable()); - hashCode ^= GetHashCodeImpl(); - - return hashCode; - } - - private unsafe int GetHashCodeImpl() - { int numFields = __GetFieldHelper(GetNumFields, out _); if (numFields == UseFastHelper) - return FastGetValueTypeHashCodeHelper(this.GetMethodTable(), ref this.GetRawData()); + hashCode.AddBytes(GetSpanForField(this.GetMethodTable(), ref this.GetRawData())); + else + RegularGetValueTypeHashCode(ref hashCode, ref this.GetRawData(), numFields); - return RegularGetValueTypeHashCode(ref this.GetRawData(), numFields); + return hashCode.ToHashCode(); } - private static unsafe int FastGetValueTypeHashCodeHelper(MethodTable* type, ref byte data) + private static unsafe ReadOnlySpan GetSpanForField(MethodTable* type, ref byte data) { // Sanity check - if there are GC references, we should not be hashing bytes Debug.Assert(!type->ContainsGCPointers); - - int size = (int)type->ValueTypeSize; - int hashCode = 0; - - for (int i = 0; i < size / 4; i++) - { - hashCode ^= Unsafe.As(ref Unsafe.Add(ref data, i * 4)); - } - - return hashCode; + return new ReadOnlySpan(ref data, (int)type->ValueTypeSize); } - private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields) + private unsafe void RegularGetValueTypeHashCode(ref HashCode hashCode, ref byte data, int numFields) { - int hashCode = 0; - // We only take the hashcode for the first non-null field. That's what the CLR does. for (int i = 0; i < numFields; i++) { @@ -142,15 +127,15 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields) if (fieldType->ElementType == EETypeElementType.Single) { - hashCode = Unsafe.As(ref fieldData).GetHashCode(); + hashCode.Add(Unsafe.As(ref fieldData)); } else if (fieldType->ElementType == EETypeElementType.Double) { - hashCode = Unsafe.As(ref fieldData).GetHashCode(); + hashCode.Add(Unsafe.As(ref fieldData)); } else if (fieldType->IsPrimitive) { - hashCode = FastGetValueTypeHashCodeHelper(fieldType, ref fieldData); + hashCode.AddBytes(GetSpanForField(fieldType, ref fieldData)); } else if (fieldType->IsValueType) { @@ -164,7 +149,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields) var fieldValue = (ValueType)RuntimeImports.RhBox(fieldType, ref fieldData); if (fieldValue != null) { - hashCode = fieldValue.GetHashCodeImpl(); + hashCode.Add(fieldValue); } else { @@ -177,7 +162,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields) object fieldValue = Unsafe.As(ref fieldData); if (fieldValue != null) { - hashCode = fieldValue.GetHashCode(); + hashCode.Add(fieldValue); } else { @@ -187,8 +172,6 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields) } break; } - - return hashCode; } } } diff --git a/src/coreclr/vm/comutilnative.cpp b/src/coreclr/vm/comutilnative.cpp index 612cb9d72dc0dd..6c7e2468d2744b 100644 --- a/src/coreclr/vm/comutilnative.cpp +++ b/src/coreclr/vm/comutilnative.cpp @@ -1703,9 +1703,10 @@ enum ValueTypeHashCodeStrategy DoubleField, SingleField, FastGetHashCode, + ValueTypeOverride, }; -static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize) +static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMTOut) { CONTRACTL { @@ -1772,10 +1773,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj *fieldSize = field->LoadSize(); ret = ValueTypeHashCodeStrategy::FastGetHashCode; } + else if (HasOverriddenMethod(fieldMT, + CoreLibBinder::GetClass(CLASS__VALUE_TYPE), + CoreLibBinder::GetMethod(METHOD__VALUE_TYPE__GET_HASH_CODE)->GetSlot())) + { + *fieldOffset += field->GetOffsetUnsafe(); + *fieldMTOut = fieldMT; + ret = ValueTypeHashCodeStrategy::ValueTypeOverride; + } else { *fieldOffset += field->GetOffsetUnsafe(); - ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize); + ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize, fieldMTOut); } } } @@ -1785,18 +1794,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj return ret; } -extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize) +extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT) { QCALL_CONTRACT; ValueTypeHashCodeStrategy ret = ValueTypeHashCodeStrategy::None; *fieldOffset = 0; *fieldSize = 0; + *fieldMT = NULL; BEGIN_QCALL; - - ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize); + ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize, fieldMT); END_QCALL; diff --git a/src/coreclr/vm/comutilnative.h b/src/coreclr/vm/comutilnative.h index a3c5ea65c3ca7c..0f305e0af90072 100644 --- a/src/coreclr/vm/comutilnative.h +++ b/src/coreclr/vm/comutilnative.h @@ -252,7 +252,7 @@ class MethodTableNative { extern "C" BOOL QCALLTYPE MethodTable_AreTypesEquivalent(MethodTable* mta, MethodTable* mtb); extern "C" BOOL QCALLTYPE MethodTable_CanCompareBitsOrUseFastGetHashCode(MethodTable* mt); -extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize); +extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT); class StreamNative { public: diff --git a/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/ValueTypeTests.cs b/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/ValueTypeTests.cs index 92a2c006ce2042..92c7000ed414d5 100644 --- a/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/ValueTypeTests.cs +++ b/src/libraries/System.Runtime/tests/System.Runtime.Tests/System/ValueTypeTests.cs @@ -315,6 +315,21 @@ public static void StructContainsPointerNestedCompareTest() Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode()); } + [Fact] + public static void StructWithNestedOverriddenNotBitwiseComparableTest() + { + StructWithNestedOverriddenNotBitwiseComparable obj1 = new StructWithNestedOverriddenNotBitwiseComparable(); + obj1.value1.value = 1; + obj1.value2.value = 0; + + StructWithNestedOverriddenNotBitwiseComparable obj2 = new StructWithNestedOverriddenNotBitwiseComparable(); + obj2.value1.value = -1; + obj2.value2.value = 0; + + Assert.True(obj1.Equals(obj2)); + Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode()); + } + public struct S { public int x; @@ -413,5 +428,20 @@ public struct StructContainsPointerNested public object o; public StructNonOverriddenEqualsOrGetHasCode value; } + + public struct StructOverriddenNotBitwiseComparable + { + public int value; + + public override bool Equals(object obj) => obj is StructOverriddenNotBitwiseComparable other && (value == other.value || value == -other.value); + + public override int GetHashCode() => value < 0 ? -value : value; + } + + public struct StructWithNestedOverriddenNotBitwiseComparable + { + public StructOverriddenNotBitwiseComparable value1; + public StructOverriddenNotBitwiseComparable value2; + } } }