Skip to content

Commit

Permalink
Fix ValueType.GetHashCode not calling overriden method on nested field (
Browse files Browse the repository at this point in the history
#98754)

* Handle overridden value type method

* Call virtual for NativeAot

* Add test

* Change to box

* Fold GetHashCodeImpl

* NativeAOT cleanup

* Don't create duplicated HashCode

* Update test to not use default equals of nested field

* Return span from method instead
  • Loading branch information
huoyaoyuan authored Feb 26, 2024
1 parent 66aebed commit a71cc52
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 39 deletions.
11 changes: 9 additions & 2 deletions src/coreclr/System.Private.CoreLib/src/System/ValueType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public override unsafe int GetHashCode()
else
{
object thisRef = this;
switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize))
switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT))
{
case ValueTypeHashCodeStrategy.ReferenceField:
hashCode.Add(Unsafe.As<byte, object>(ref Unsafe.AddByteOffset(ref rawData, fieldOffset)).GetHashCode());
Expand All @@ -138,6 +138,12 @@ public override unsafe int GetHashCode()
Debug.Assert(fieldSize != 0);
hashCode.AddBytes(MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AddByteOffset(ref rawData, fieldOffset), (int)fieldSize));
break;

case ValueTypeHashCodeStrategy.ValueTypeOverride:
Debug.Assert(fieldMT != null);
// Box the field to handle complicated cases like mutable method and shared generic
hashCode.Add(RuntimeHelpers.Box(fieldMT, ref Unsafe.AddByteOffset(ref rawData, fieldOffset))?.GetHashCode() ?? 0);
break;
}
}

Expand All @@ -152,11 +158,12 @@ private enum ValueTypeHashCodeStrategy
DoubleField,
SingleField,
FastGetHashCode,
ValueTypeOverride,
}

[LibraryImport(RuntimeHelpers.QCall, EntryPoint = "ValueType_GetHashCodeStrategy")]
private static unsafe partial ValueTypeHashCodeStrategy GetHashCodeStrategy(
MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize);
MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT);

public override string? ToString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,43 +95,28 @@ public override unsafe bool Equals([NotNullWhen(true)] object? obj)

public override unsafe int GetHashCode()
{
int hashCode = (int)this.GetMethodTable()->HashCode;
HashCode hashCode = default;
hashCode.Add((IntPtr)this.GetMethodTable());

hashCode ^= GetHashCodeImpl();

return hashCode;
}

private unsafe int GetHashCodeImpl()
{
int numFields = __GetFieldHelper(GetNumFields, out _);

if (numFields == UseFastHelper)
return FastGetValueTypeHashCodeHelper(this.GetMethodTable(), ref this.GetRawData());
hashCode.AddBytes(GetSpanForField(this.GetMethodTable(), ref this.GetRawData()));
else
RegularGetValueTypeHashCode(ref hashCode, ref this.GetRawData(), numFields);

return RegularGetValueTypeHashCode(ref this.GetRawData(), numFields);
return hashCode.ToHashCode();
}

private static unsafe int FastGetValueTypeHashCodeHelper(MethodTable* type, ref byte data)
private static unsafe ReadOnlySpan<byte> GetSpanForField(MethodTable* type, ref byte data)
{
// Sanity check - if there are GC references, we should not be hashing bytes
Debug.Assert(!type->ContainsGCPointers);

int size = (int)type->ValueTypeSize;
int hashCode = 0;

for (int i = 0; i < size / 4; i++)
{
hashCode ^= Unsafe.As<byte, int>(ref Unsafe.Add(ref data, i * 4));
}

return hashCode;
return new ReadOnlySpan<byte>(ref data, (int)type->ValueTypeSize);
}

private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
private unsafe void RegularGetValueTypeHashCode(ref HashCode hashCode, ref byte data, int numFields)
{
int hashCode = 0;

// We only take the hashcode for the first non-null field. That's what the CLR does.
for (int i = 0; i < numFields; i++)
{
Expand All @@ -142,15 +127,15 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)

if (fieldType->ElementType == EETypeElementType.Single)
{
hashCode = Unsafe.As<byte, float>(ref fieldData).GetHashCode();
hashCode.Add(Unsafe.As<byte, float>(ref fieldData));
}
else if (fieldType->ElementType == EETypeElementType.Double)
{
hashCode = Unsafe.As<byte, double>(ref fieldData).GetHashCode();
hashCode.Add(Unsafe.As<byte, double>(ref fieldData));
}
else if (fieldType->IsPrimitive)
{
hashCode = FastGetValueTypeHashCodeHelper(fieldType, ref fieldData);
hashCode.AddBytes(GetSpanForField(fieldType, ref fieldData));
}
else if (fieldType->IsValueType)
{
Expand All @@ -164,7 +149,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
var fieldValue = (ValueType)RuntimeImports.RhBox(fieldType, ref fieldData);
if (fieldValue != null)
{
hashCode = fieldValue.GetHashCodeImpl();
hashCode.Add(fieldValue);
}
else
{
Expand All @@ -177,7 +162,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
object fieldValue = Unsafe.As<byte, object>(ref fieldData);
if (fieldValue != null)
{
hashCode = fieldValue.GetHashCode();
hashCode.Add(fieldValue);
}
else
{
Expand All @@ -187,8 +172,6 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
}
break;
}

return hashCode;
}
}
}
19 changes: 14 additions & 5 deletions src/coreclr/vm/comutilnative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1703,9 +1703,10 @@ enum ValueTypeHashCodeStrategy
DoubleField,
SingleField,
FastGetHashCode,
ValueTypeOverride,
};

static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize)
static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMTOut)
{
CONTRACTL
{
Expand Down Expand Up @@ -1772,10 +1773,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj
*fieldSize = field->LoadSize();
ret = ValueTypeHashCodeStrategy::FastGetHashCode;
}
else if (HasOverriddenMethod(fieldMT,
CoreLibBinder::GetClass(CLASS__VALUE_TYPE),
CoreLibBinder::GetMethod(METHOD__VALUE_TYPE__GET_HASH_CODE)->GetSlot()))
{
*fieldOffset += field->GetOffsetUnsafe();
*fieldMTOut = fieldMT;
ret = ValueTypeHashCodeStrategy::ValueTypeOverride;
}
else
{
*fieldOffset += field->GetOffsetUnsafe();
ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize);
ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize, fieldMTOut);
}
}
}
Expand All @@ -1785,18 +1794,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj
return ret;
}

extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize)
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT)
{
QCALL_CONTRACT;

ValueTypeHashCodeStrategy ret = ValueTypeHashCodeStrategy::None;
*fieldOffset = 0;
*fieldSize = 0;
*fieldMT = NULL;

BEGIN_QCALL;


ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize);
ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize, fieldMT);

END_QCALL;

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/vm/comutilnative.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class MethodTableNative {

extern "C" BOOL QCALLTYPE MethodTable_AreTypesEquivalent(MethodTable* mta, MethodTable* mtb);
extern "C" BOOL QCALLTYPE MethodTable_CanCompareBitsOrUseFastGetHashCode(MethodTable* mt);
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize);
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT);

class StreamNative {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,21 @@ public static void StructContainsPointerNestedCompareTest()
Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode());
}

[Fact]
public static void StructWithNestedOverriddenNotBitwiseComparableTest()
{
StructWithNestedOverriddenNotBitwiseComparable obj1 = new StructWithNestedOverriddenNotBitwiseComparable();
obj1.value1.value = 1;
obj1.value2.value = 0;

StructWithNestedOverriddenNotBitwiseComparable obj2 = new StructWithNestedOverriddenNotBitwiseComparable();
obj2.value1.value = -1;
obj2.value2.value = 0;

Assert.True(obj1.Equals(obj2));
Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode());
}

public struct S
{
public int x;
Expand Down Expand Up @@ -413,5 +428,20 @@ public struct StructContainsPointerNested
public object o;
public StructNonOverriddenEqualsOrGetHasCode value;
}

public struct StructOverriddenNotBitwiseComparable
{
public int value;

public override bool Equals(object obj) => obj is StructOverriddenNotBitwiseComparable other && (value == other.value || value == -other.value);

public override int GetHashCode() => value < 0 ? -value : value;
}

public struct StructWithNestedOverriddenNotBitwiseComparable
{
public StructOverriddenNotBitwiseComparable value1;
public StructOverriddenNotBitwiseComparable value2;
}
}
}

0 comments on commit a71cc52

Please sign in to comment.