Skip to content

Commit

Permalink
Allow discriminator properties to value converter and comparer
Browse files Browse the repository at this point in the history
Fixes #19650
  • Loading branch information
ajcvickers committed Sep 6, 2020
1 parent 26c8e1c commit 99d344c
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -630,20 +630,24 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
}

var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand);
var discriminatorProperty = entityType.GetDiscriminatorProperty();
if (derivedType != null
&& TryBindMember(
entityReferenceExpression,
MemberIdentity.Create(entityType.GetDiscriminatorProperty().Name)) is SqlExpression discriminatorColumn)
MemberIdentity.Create(discriminatorProperty.Name)) is SqlExpression discriminatorColumn)
{
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();

var typeMapping = discriminatorProperty.GetTypeMapping();
return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
_sqlExpressionFactory.Constant(
concreteEntityTypes[0].GetDiscriminatorValue(), typeMapping))
: (SqlExpression)_sqlExpressionFactory.In(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()),
_sqlExpressionFactory.Constant(
concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList(), typeMapping),
negated: false);
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,16 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent
if (concreteEntityTypes.Count == 1)
{
var concreteEntityType = concreteEntityTypes[0];
if (concreteEntityType.GetDiscriminatorProperty() != null)
var discriminatorProperty = concreteEntityType.GetDiscriminatorProperty();
if (discriminatorProperty != null)
{
var discriminatorColumn = ((EntityProjectionExpression)selectExpression.GetMappedProjection(new ProjectionMember()))
.BindProperty(concreteEntityType.GetDiscriminatorProperty(), clientEval: false);
.BindProperty(discriminatorProperty, clientEval: false);

selectExpression.ApplyPredicate(
Equal((SqlExpression)discriminatorColumn, Constant(concreteEntityType.GetDiscriminatorValue())));
Equal(
(SqlExpression)discriminatorColumn,
Constant(concreteEntityType.GetDiscriminatorValue(), discriminatorProperty.GetTypeMapping())));
}
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,16 +965,17 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
{
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var boundProperty = BindProperty(entityReferenceExpression, discriminatorProperty, discriminatorProperty.ClrType);
var valueComparer = discriminatorProperty.GetKeyValueComparer();

var equals = Expression.Equal(
var equals = valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
valueComparer.ExtractEqualsBody(
boundProperty,
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ public InMemoryQueryExpression([NotNull] IEntityType entityType)
foreach (var derivedEntityType in entityType.GetDerivedTypes())
{
var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive()
.Select(e => Equal(readExpressionMap[discriminatorProperty], Constant(e.GetDiscriminatorValue())))
.Select(
e => discriminatorProperty.GetKeyValueComparer().ExtractEqualsBody(
readExpressionMap[discriminatorProperty],
Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType)))
.Aggregate((l, r) => OrElse(l, r));

foreach (var property in derivedEntityType.GetDeclaredProperties())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,14 +971,16 @@ SqlExpression GeneratePredicateTPT(EntityProjectionExpression entityProjectionEx
var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty);
if (discriminatorColumn != null)
{
var typeMapping = discriminatorProperty.GetRelationalTypeMapping();
return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
_sqlExpressionFactory.Constant(
concreteEntityTypes[0].GetDiscriminatorValue(), typeMapping))
: (Expression)_sqlExpressionFactory.In(
discriminatorColumn,
_sqlExpressionFactory.Constant(
concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()),
concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList(), typeMapping),
negated: false);
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -918,9 +918,14 @@ private bool AddDiscriminatorCondition(SelectExpression selectExpression, IEntit

var discriminatorColumn = GetMappedEntityProjectionExpression(selectExpression).BindProperty(discriminatorProperty);
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var typeMapping = discriminatorProperty.GetRelationalTypeMapping();
var predicate = concreteEntityTypes.Count == 1
? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
: In(discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), negated: false);
? (SqlExpression)Equal(
discriminatorColumn,
Constant(concreteEntityTypes[0].GetDiscriminatorValue(), typeMapping))
: In(
discriminatorColumn,
Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList(), typeMapping), negated: false);

selectExpression.ApplyPredicate(predicate);

Expand Down
3 changes: 2 additions & 1 deletion src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -27,7 +28,7 @@ namespace Microsoft.EntityFrameworkCore.ChangeTracking
/// reference.
/// </para>
/// </summary>
public abstract class ValueComparer : IEqualityComparer
public abstract class ValueComparer : IEqualityComparer, IEqualityComparer<object>
{
private protected static readonly MethodInfo _doubleEqualsMethodInfo
= typeof(double).GetRuntimeMethod(nameof(double.Equals), new[] { typeof(double) });
Expand Down
6 changes: 4 additions & 2 deletions src/EFCore/Infrastructure/ModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -638,19 +638,21 @@ protected virtual void ValidateInheritanceMapping(
/// <param name="rootEntityType"> The entity type to validate. </param>
protected virtual void ValidateDiscriminatorValues([NotNull] IEntityType rootEntityType)
{
var discriminatorValues = new Dictionary<object, IEntityType>();
var derivedTypes = rootEntityType.GetDerivedTypesInclusive().ToList();
if (derivedTypes.Count == 1)
{
return;
}

if (rootEntityType.GetDiscriminatorProperty() == null)
var discriminatorProperty = rootEntityType.GetDiscriminatorProperty();
if (discriminatorProperty == null)
{
throw new InvalidOperationException(
CoreStrings.NoDiscriminatorProperty(rootEntityType.DisplayName()));
}

var discriminatorValues = new Dictionary<object, IEntityType>(discriminatorProperty.GetKeyValueComparer());

foreach (var derivedType in derivedTypes)
{
if (derivedType.ClrType?.IsInstantiable() != true)
Expand Down
24 changes: 23 additions & 1 deletion src/EFCore/Query/EntityShaperExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

Expand Down Expand Up @@ -123,7 +125,27 @@ protected virtual LambdaExpression GenerateMaterializationCondition([NotNull] IE
Convert(discriminatorValueVariable, typeof(object)))),
Constant(null, typeof(IEntityType)));

expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));

var discriminatorComparer = discriminatorProperty.GetKeyValueComparer();
if (discriminatorComparer.IsDefault()
|| discriminatorProperty.ClrType.IsEnum)
{
expressions.Add(Switch(discriminatorValueVariable, exception, switchCases));
}
else
{
var staticComparer = typeof(StaticDiscriminatorComparer<,,>).MakeGenericType(
discriminatorProperty.DeclaringEntityType.ClrType,
discriminatorProperty.ClrType,
discriminatorProperty.GetTypeMapping().Converter.ProviderClrType);

var comparerField = staticComparer.GetField(nameof(StaticDiscriminatorComparer<int, int, int>.Comparer));
comparerField.SetValue(null, discriminatorComparer);

var equalsMethod = staticComparer.GetMethod(nameof(StaticDiscriminatorComparer<int, int, int>.DiscriminatorEquals));
expressions.Add(Switch(discriminatorValueVariable, exception, equalsMethod, switchCases));
}

body = Block(new[] { discriminatorValueVariable }, expressions);
}
else
Expand Down
35 changes: 35 additions & 0 deletions src/EFCore/Query/Internal/StaticDiscriminatorComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// ReSharper disable twice UnusedTypeParameter
public static class StaticDiscriminatorComparer<TEntity, TModel, TProvider>
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static ValueComparer<TModel> Comparer;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static bool DiscriminatorEquals([CanBeNull] TModel x, [CanBeNull] TModel y)
=> Comparer.Equals(x, y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,25 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
.HasForeignKey<OptionalSingle2>(e => e.BackId)
.OnDelete(DeleteBehavior.SetNull);

modelBuilder.Entity<OptionalSingle2Derived>();
modelBuilder.Entity<OptionalSingle2MoreDerived>();
modelBuilder.Entity<OptionalSingle2>(
b =>
{
b.HasDiscriminator(e => e.Disc)
.HasValue<OptionalSingle2>(new MyDiscriminator(1))
.HasValue<OptionalSingle2Derived>(new MyDiscriminator(2))
.HasValue<OptionalSingle2MoreDerived>(new MyDiscriminator(3));
b.Property(e => e.Disc)
.HasConversion(
v => v.Value,
v => new MyDiscriminator(v),
new ValueComparer<MyDiscriminator>(
(l, r) => l.Value == r.Value,
v => v.Value.GetHashCode(),
v => new MyDiscriminator(v.Value)))
.Metadata
.SetAfterSaveBehavior(PropertySaveBehavior.Save);
});

modelBuilder.Entity<RequiredNonPkSingle1>()
.HasOne(e => e.Single)
Expand Down Expand Up @@ -379,10 +396,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
modelBuilder.Entity<Produce>()
.HasIndex(e => e.BarCode)
.IsUnique();

modelBuilder.Entity<OptionalSingle2Derived>()
.Property<string>("Discriminator")
.Metadata.SetAfterSaveBehavior(PropertySaveBehavior.Save);
}

protected virtual object CreateFullGraph()
Expand Down Expand Up @@ -1692,6 +1705,7 @@ protected class OptionalSingle2 : NotifyingEntity
{
private int _id;
private int? _backId;
private MyDiscriminator _disc;
private OptionalSingle1 _back;

public int Id
Expand All @@ -1706,6 +1720,12 @@ public int? BackId
set => SetWithNotify(value, ref _backId);
}

public MyDiscriminator Disc
{
get => _disc;
set => SetWithNotify(value, ref _disc);
}

public OptionalSingle1 Back
{
get => _back;
Expand All @@ -1722,6 +1742,20 @@ public override int GetHashCode()
=> _id;
}

protected class MyDiscriminator
{
public MyDiscriminator(int value)
=> Value = value;

public int Value { get; }

public override bool Equals(object obj)
=> throw new InvalidOperationException();

public override int GetHashCode()
=> throw new InvalidOperationException();
}

protected class OptionalSingle2Derived : OptionalSingle2
{
public override bool Equals(object obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ public virtual void Mutating_discriminator_value_can_be_configured_to_allow_muta
context =>
{
var instance = context.Set<OptionalSingle2Derived>().First();
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
id = instance.Id;
Assert.IsType<OptionalSingle2Derived>(instance);
Assert.Equal(nameof(OptionalSingle2Derived), propertyEntry.CurrentValue);
Assert.Equal(2, propertyEntry.CurrentValue.Value);
propertyEntry.CurrentValue = nameof(OptionalSingle2);
propertyEntry.CurrentValue = new MyDiscriminator(1);
context.SaveChanges();
},
context =>
{
var instance = context.Set<OptionalSingle2>().First(e => e.Id == id);
var propertyEntry = context.Entry(instance).Property("Discriminator");
var propertyEntry = context.Entry(instance).Property(e => e.Disc);
Assert.IsType<OptionalSingle2>(instance);
Assert.Equal(nameof(OptionalSingle2), propertyEntry.CurrentValue);
Assert.Equal(1, propertyEntry.CurrentValue.Value);
});
}

Expand Down

0 comments on commit 99d344c

Please sign in to comment.