Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ValueType.GetHashCode not calling overriden method on nested field #98754

Merged
merged 11 commits into from
Feb 26, 2024
11 changes: 9 additions & 2 deletions src/coreclr/System.Private.CoreLib/src/System/ValueType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public override unsafe int GetHashCode()
else
{
object thisRef = this;
switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize))
switch (GetHashCodeStrategy(pMT, ObjectHandleOnStack.Create(ref thisRef), out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT))
{
case ValueTypeHashCodeStrategy.ReferenceField:
hashCode.Add(Unsafe.As<byte, object>(ref Unsafe.AddByteOffset(ref rawData, fieldOffset)).GetHashCode());
Expand All @@ -138,6 +138,12 @@ public override unsafe int GetHashCode()
Debug.Assert(fieldSize != 0);
hashCode.AddBytes(MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AddByteOffset(ref rawData, fieldOffset), (int)fieldSize));
break;

case ValueTypeHashCodeStrategy.ValueTypeOverride:
Debug.Assert(fieldMT != null);
// Box the field to handle complicated cases like mutable method and shared generic
hashCode.Add(RuntimeHelpers.Box(fieldMT, ref Unsafe.AddByteOffset(ref rawData, fieldOffset))?.GetHashCode() ?? 0);
break;
}
}

Expand All @@ -152,11 +158,12 @@ private enum ValueTypeHashCodeStrategy
DoubleField,
SingleField,
FastGetHashCode,
ValueTypeOverride,
}

[LibraryImport(RuntimeHelpers.QCall, EntryPoint = "ValueType_GetHashCodeStrategy")]
private static unsafe partial ValueTypeHashCodeStrategy GetHashCodeStrategy(
MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize);
MethodTable* pMT, ObjectHandleOnStack objHandle, out uint fieldOffset, out uint fieldSize, out MethodTable* fieldMT);

public override string? ToString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,43 +95,28 @@ public override unsafe bool Equals([NotNullWhen(true)] object? obj)

public override unsafe int GetHashCode()
{
int hashCode = (int)this.GetMethodTable()->HashCode;
HashCode hashCode = default;
hashCode.Add((IntPtr)this.GetMethodTable());

hashCode ^= GetHashCodeImpl();

return hashCode;
}

private unsafe int GetHashCodeImpl()
{
int numFields = __GetFieldHelper(GetNumFields, out _);

if (numFields == UseFastHelper)
return FastGetValueTypeHashCodeHelper(this.GetMethodTable(), ref this.GetRawData());
hashCode.AddBytes(GetSpanForField(this.GetMethodTable(), ref this.GetRawData()));
else
RegularGetValueTypeHashCode(ref hashCode, ref this.GetRawData(), numFields);

return RegularGetValueTypeHashCode(ref this.GetRawData(), numFields);
return hashCode.ToHashCode();
}

private static unsafe int FastGetValueTypeHashCodeHelper(MethodTable* type, ref byte data)
private static unsafe ReadOnlySpan<byte> GetSpanForField(MethodTable* type, ref byte data)
{
// Sanity check - if there are GC references, we should not be hashing bytes
Debug.Assert(!type->ContainsGCPointers);

int size = (int)type->ValueTypeSize;
int hashCode = 0;

for (int i = 0; i < size / 4; i++)
{
hashCode ^= Unsafe.As<byte, int>(ref Unsafe.Add(ref data, i * 4));
}

return hashCode;
return new ReadOnlySpan<byte>(ref data, (int)type->ValueTypeSize);
}

private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
private unsafe void RegularGetValueTypeHashCode(ref HashCode hashCode, ref byte data, int numFields)
{
int hashCode = 0;

// We only take the hashcode for the first non-null field. That's what the CLR does.
for (int i = 0; i < numFields; i++)
{
Expand All @@ -142,15 +127,15 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)

if (fieldType->ElementType == EETypeElementType.Single)
{
hashCode = Unsafe.As<byte, float>(ref fieldData).GetHashCode();
hashCode.Add(Unsafe.As<byte, float>(ref fieldData));
}
else if (fieldType->ElementType == EETypeElementType.Double)
{
hashCode = Unsafe.As<byte, double>(ref fieldData).GetHashCode();
hashCode.Add(Unsafe.As<byte, double>(ref fieldData));
}
else if (fieldType->IsPrimitive)
{
hashCode = FastGetValueTypeHashCodeHelper(fieldType, ref fieldData);
hashCode.AddBytes(GetSpanForField(fieldType, ref fieldData));
}
else if (fieldType->IsValueType)
{
Expand All @@ -164,7 +149,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
var fieldValue = (ValueType)RuntimeImports.RhBox(fieldType, ref fieldData);
if (fieldValue != null)
{
hashCode = fieldValue.GetHashCodeImpl();
hashCode.Add(fieldValue);
}
else
{
Expand All @@ -177,7 +162,7 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
object fieldValue = Unsafe.As<byte, object>(ref fieldData);
if (fieldValue != null)
{
hashCode = fieldValue.GetHashCode();
hashCode.Add(fieldValue);
}
else
{
Expand All @@ -187,8 +172,6 @@ private unsafe int RegularGetValueTypeHashCode(ref byte data, int numFields)
}
break;
}

return hashCode;
}
}
}
19 changes: 14 additions & 5 deletions src/coreclr/vm/comutilnative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1703,9 +1703,10 @@ enum ValueTypeHashCodeStrategy
DoubleField,
SingleField,
FastGetHashCode,
ValueTypeOverride,
};

static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize)
static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMTOut)
{
CONTRACTL
{
Expand Down Expand Up @@ -1772,10 +1773,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj
*fieldSize = field->LoadSize();
ret = ValueTypeHashCodeStrategy::FastGetHashCode;
}
else if (HasOverriddenMethod(fieldMT,
CoreLibBinder::GetClass(CLASS__VALUE_TYPE),
CoreLibBinder::GetMethod(METHOD__VALUE_TYPE__GET_HASH_CODE)->GetSlot()))
{
*fieldOffset += field->GetOffsetUnsafe();
*fieldMTOut = fieldMT;
ret = ValueTypeHashCodeStrategy::ValueTypeOverride;
}
else
{
*fieldOffset += field->GetOffsetUnsafe();
ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize);
ret = GetHashCodeStrategy(fieldMT, objHandle, fieldOffset, fieldSize, fieldMTOut);
}
}
}
Expand All @@ -1785,18 +1794,18 @@ static ValueTypeHashCodeStrategy GetHashCodeStrategy(MethodTable* mt, QCall::Obj
return ret;
}

extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize)
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT)
{
QCALL_CONTRACT;

ValueTypeHashCodeStrategy ret = ValueTypeHashCodeStrategy::None;
*fieldOffset = 0;
*fieldSize = 0;
*fieldMT = NULL;

BEGIN_QCALL;


ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize);
ret = GetHashCodeStrategy(mt, objHandle, fieldOffset, fieldSize, fieldMT);

END_QCALL;

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/vm/comutilnative.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class MethodTableNative {

extern "C" BOOL QCALLTYPE MethodTable_AreTypesEquivalent(MethodTable* mta, MethodTable* mtb);
extern "C" BOOL QCALLTYPE MethodTable_CanCompareBitsOrUseFastGetHashCode(MethodTable* mt);
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize);
extern "C" INT32 QCALLTYPE ValueType_GetHashCodeStrategy(MethodTable* mt, QCall::ObjectHandleOnStack objHandle, UINT32* fieldOffset, UINT32* fieldSize, MethodTable** fieldMT);

class StreamNative {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,21 @@ public static void StructContainsPointerNestedCompareTest()
Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode());
}

[Fact]
public static void StructWithNestedOverriddenNotBitwiseComparableTest()
{
StructWithNestedOverriddenNotBitwiseComparable obj1 = new StructWithNestedOverriddenNotBitwiseComparable();
obj1.value1.value = 1;
obj1.value2.value = 0;

StructWithNestedOverriddenNotBitwiseComparable obj2 = new StructWithNestedOverriddenNotBitwiseComparable();
obj2.value1.value = -1;
obj2.value2.value = 0;

Assert.True(obj1.Equals(obj2));
Assert.Equal(obj1.GetHashCode(), obj2.GetHashCode());
}

public struct S
{
public int x;
Expand Down Expand Up @@ -413,5 +428,20 @@ public struct StructContainsPointerNested
public object o;
public StructNonOverriddenEqualsOrGetHasCode value;
}

public struct StructOverriddenNotBitwiseComparable
{
public int value;

public override bool Equals(object obj) => obj is StructOverriddenNotBitwiseComparable other && (value == other.value || value == -other.value);

public override int GetHashCode() => value < 0 ? -value : value;
}

public struct StructWithNestedOverriddenNotBitwiseComparable
{
public StructOverriddenNotBitwiseComparable value1;
public StructOverriddenNotBitwiseComparable value2;
}
}
}
Loading