diff --git a/src/EFCore.Cosmos/Diagnostics/Internal/CosmosLoggerExtensions.cs b/src/EFCore.Cosmos/Diagnostics/Internal/CosmosLoggerExtensions.cs index 786fc6d263c..00f6a84877f 100644 --- a/src/EFCore.Cosmos/Diagnostics/Internal/CosmosLoggerExtensions.cs +++ b/src/EFCore.Cosmos/Diagnostics/Internal/CosmosLoggerExtensions.cs @@ -51,6 +51,28 @@ public static void ExecutingSqlQuery( cosmosSqlQuery.Query); } + public static void ExecutingReadItem( + [NotNull] this IDiagnosticsLogger diagnosticsLogger, + [NotNull] string partitionKey, + [NotNull] string resourceId) + { + var definition = new EventDefinition( + diagnosticsLogger.Options, + CoreEventId.ProviderBaseId, + LogLevel.Debug, + "CoreEventId.ProviderBaseId", + level => LoggerMessage.Define( + level, + CoreEventId.ProviderBaseId, + "Executing Read Item [Partition Key, Resource Id=[{parameters}]]{newLine}{commandText}")); + + definition.Log( + diagnosticsLogger, + $"{partitionKey}, {resourceId}", + Environment.NewLine, + "Read Item"); + } + private static string FormatParameters(IReadOnlyList parameters) { return parameters.Count == 0 diff --git a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs index 1bcfd2bfa8f..a31f3b5291a 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs @@ -14,6 +14,7 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; diff --git a/src/EFCore.Cosmos/Metadata/Conventions/StoreKeyConvention.cs b/src/EFCore.Cosmos/Metadata/Conventions/StoreKeyConvention.cs index 68a982035ab..1fb0c8d00cb 100644 --- a/src/EFCore.Cosmos/Metadata/Conventions/StoreKeyConvention.cs +++ b/src/EFCore.Cosmos/Metadata/Conventions/StoreKeyConvention.cs @@ -4,6 +4,7 @@ using System.Linq; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Metadata; using Microsoft.EntityFrameworkCore.Cosmos.ValueGeneration.Internal; using Microsoft.EntityFrameworkCore.Metadata.Builders; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure; diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs index 68e32e66bc7..0d8b9cb88a8 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs @@ -121,7 +121,7 @@ public static string UpdateConflict([CanBeNull] object itemId) itemId); /// - /// Non-embedded IncludeExpression is not supported: {expression} + /// Non-embedded IncludeExpression is not supported: {expression} /// public static string NonEmbeddedIncludeNotSupported([CanBeNull] object expression) => string.Format( @@ -129,7 +129,7 @@ public static string NonEmbeddedIncludeNotSupported([CanBeNull] object expressio expression); /// - /// Navigation '{entityType}.{navigationName}' doesn't point to an embedded entity. + /// Navigation '{entityType}.{navigationName}' doesn't point to an embedded entity. /// public static string NavigationPropertyIsNotAnEmbeddedEntity([CanBeNull] object entityType, [CanBeNull] object navigationName) => string.Format( @@ -137,17 +137,35 @@ public static string NavigationPropertyIsNotAnEmbeddedEntity([CanBeNull] object entityType, navigationName); /// - /// Offset is not supported without Limit. + /// Offset is not supported without Limit. /// public static string OffsetRequiresLimit => GetString("OffsetRequiresLimit"); /// - /// Reverse is not supported without Limit or Offset. + /// Reverse is not supported without Limit or Offset. /// public static string ReverseRequiresOffsetOrLimit => GetString("ReverseRequiresOffsetOrLimit"); + /// + /// Invalid Resource id. Resource id cannot be null or empty and must be a string value. + /// + public static string InvalidResourceId + => GetString("InvalidResourceId"); + + /// + /// Partition key missing. + /// + public static string ParitionKeyMissing + => GetString("ParitionKeyMissing"); + + /// + /// Resource id missing or cannot be generated. + /// + public static string ResourceIdMissing + => GetString("ResourceIdMissing"); + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.resx b/src/EFCore.Cosmos/Properties/CosmosStrings.resx index 08a70ba211c..ac6c492769f 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.resx +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.resx @@ -168,4 +168,13 @@ Reverse is not supported without Limit or Offset. + + Invalid Resource id. Resource id cannot be null or empty and must be a string value. + + + Partition key missing. + + + Resource id missing or cannot be generated. + \ No newline at end of file diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryCompilationContext.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryCompilationContext.cs index e8f13debbe5..645a51c9a3e 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryCompilationContext.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryCompilationContext.cs @@ -8,7 +8,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal { public class CosmosQueryCompilationContext : QueryCompilationContext { - public virtual string PartitionKey { get; internal set; } + public virtual string PartitionKeyFromExtension { get; internal set; } public CosmosQueryCompilationContext( [NotNull] QueryCompilationContextDependencies dependencies, bool async) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs index b3ff160cec2..70276eadaad 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs @@ -5,6 +5,7 @@ using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Utilities; +using Microsoft.EntityFrameworkCore.ValueGeneration; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal { @@ -28,6 +29,7 @@ public CosmosQueryContext( : base(dependencies) { Check.NotNull(cosmosClient, nameof(cosmosClient)); + CosmosClient = cosmosClient; } @@ -38,5 +40,13 @@ public CosmosQueryContext( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual CosmosClientWrapper CosmosClient { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IValueGeneratorSelector ValueGeneratorSelector { get; } } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryMetadataExtractingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryMetadataExtractingExpressionVisitor.cs index 84868860944..4536976b8cc 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryMetadataExtractingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryMetadataExtractingExpressionVisitor.cs @@ -24,7 +24,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { var innerQueryable = Visit(methodCallExpression.Arguments[0]); - _cosmosQueryCompilationContext.PartitionKey = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; + _cosmosQueryCompilationContext.PartitionKeyFromExtension = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value; return innerQueryable; } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index e9d59f3cbea..ac3cf630e2d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -9,6 +9,7 @@ using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; @@ -68,6 +69,115 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor( _projectionBindingExpressionVisitor = new CosmosProjectionBindingExpressionVisitor(_model, _sqlTranslator); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression Visit(Expression expression) + { + if (expression is MethodCallExpression methodCallExpression + && methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.FirstOrDefaultWithoutPredicate) + { + if (methodCallExpression.Arguments[0] is MethodCallExpression queryRootMethodCallExpression + && methodCallExpression.Method.IsGenericMethod + && queryRootMethodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.Where) + { + if (queryRootMethodCallExpression.Arguments[0] is QueryRootExpression queryRootExpression) + { + var entityType = queryRootExpression.EntityType; + + if (queryRootMethodCallExpression.Arguments[1] is UnaryExpression unaryExpression + && unaryExpression.Operand is LambdaExpression lambdaExpression + && lambdaExpression.Body is BinaryExpression lambdaBodyBinaryExpression) + { + var queryProperties = new List(); + var parameterNames = new List(); + + if (ProcessJoinCondition(lambdaBodyBinaryExpression, queryProperties, parameterNames)) + { + var entityTypePrimaryKeyProperties = entityType.FindPrimaryKey().Properties; + + if (TryGetPartitionKeyProperty(out var partitionKeyProperty) + && entityTypePrimaryKeyProperties.Contains(partitionKeyProperty) + && entityTypePrimaryKeyProperties.SequenceEqual(queryProperties)) + { + var propertyParameterList = queryProperties.Zip(parameterNames, + (property, parameter) => (property, parameter)) + .ToDictionary(tuple => tuple.property, tuple => tuple.parameter); + + var readItemExpression = new ReadItemExpression(entityType, propertyParameterList); + + var shapedQueryExpression = new ShapedQueryExpression( + readItemExpression, + new EntityShaperExpression( + entityType, + new ProjectionBindingExpression( + readItemExpression, + new ProjectionMember(), + typeof(ValueBuffer)), + false)); + + shapedQueryExpression = shapedQueryExpression.UpdateResultCardinality(ResultCardinality.Single); + + return shapedQueryExpression; + } + } + + bool ProcessJoinCondition( + Expression joinCondition, ICollection properties, ICollection paramNames) + { + if (joinCondition is BinaryExpression binaryExpression) + { + switch (binaryExpression.NodeType) + { + case ExpressionType.AndAlso: + return ProcessJoinCondition(binaryExpression.Left, properties, paramNames) + && ProcessJoinCondition(binaryExpression.Right, properties, paramNames); + + case ExpressionType.Equal: + if (binaryExpression.Left is MethodCallExpression methodCallExpr + && binaryExpression.Right is ParameterExpression parameterExpr) + { + if (methodCallExpr.TryGetEFPropertyArguments(out _, out var propertyName)) + { +#pragma warning disable EF1001 + properties.Add(entityType.GetProperty(propertyName)); + paramNames.Add(parameterExpr.Name); + return true; + } + } + return false; + + default: + return false; + } + } + return false; + } + + bool TryGetPartitionKeyProperty(out IProperty partitionKeyProperty) + { + var partitionKeyPropertyName = entityType.GetPartitionKeyPropertyName(); + + if (partitionKeyPropertyName is null) + { + partitionKeyProperty = null; + return false; + } + + partitionKeyProperty = entityType.FindProperty(partitionKeyPropertyName); + return true; + } + } + } + } + } + return base.Visit(expression); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -124,7 +234,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType) { Check.NotNull(entityType, nameof(entityType)); - + var selectExpression = _sqlExpressionFactory.Select(entityType); return new ShapedQueryExpression( @@ -194,7 +304,7 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression projection = _sqlExpressionFactory.Function( "AVG", new[] { projection }, projection.Type, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); } /// @@ -551,7 +661,7 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); } /// @@ -581,7 +691,7 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwWhenEmpty: true, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); } /// @@ -796,7 +906,7 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour projection = _sqlExpressionFactory.Function( "SUM", new[] { projection }, serverOutputType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwWhenEmpty: false, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); } /// @@ -912,35 +1022,29 @@ private static Expression RemapLambdaBody(Expression shaperBody, LambdaExpressio } private ShapedQueryExpression AggregateResultShaper( - ShapedQueryExpression source, Expression projection, bool throwWhenEmpty, Type resultType) + ShapedQueryExpression source, Expression projection, bool throwOnNullResult, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; selectExpression.ReplaceProjectionMapping( new Dictionary { { new ProjectionMember(), projection } }); selectExpression.ClearOrdering(); - Expression shaper; - if (throwWhenEmpty) + var nullableResultType = resultType.MakeNullable(); + Expression shaper = new ProjectionBindingExpression( + source.QueryExpression, new ProjectionMember(), throwOnNullResult ? nullableResultType : projection.Type); + + if (throwOnNullResult) { - // Avg/Max/Min case. - // We always read nullable value - // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. - // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default - // otherwise, server would return null only if it is empty, and we throw - var nullableResultType = resultType.MakeNullable(); - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); var resultVariable = Expression.Variable(nullableResultType, "result"); var returnValueForNull = resultType.IsNullableType() - ? Expression.Constant(null, resultType) - : projection.Type.IsNullableType() - ? (Expression)Expression.Default(resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.NoElements)), - resultType); + ? (Expression)Expression.Constant(null, resultType) + : Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(CoreStrings.NoElements)), + resultType); shaper = Expression.Block( new[] { resultVariable }, @@ -952,15 +1056,9 @@ private ShapedQueryExpression AggregateResultShaper( ? Expression.Convert(resultVariable, resultType) : (Expression)resultVariable)); } - else + else if (resultType != shaper.Type) { - // Sum case. Projection is always non-null. We read non-nullable value (0 if empty) - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type); - // Cast to nullable type if required - if (resultType != shaper.Type) - { - shaper = Expression.Convert(shaper, resultType); - } + shaper = Expression.Convert(shaper, resultType); } return source.UpdateShaperExpression(shaper); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs index 517ca013c26..6d2ca88555e 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitor.cs @@ -2,754 +2,36 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; -using System.Reflection; using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; -using Microsoft.EntityFrameworkCore.Cosmos.Internal; -using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Conventions; using Microsoft.EntityFrameworkCore.Query; -using Microsoft.EntityFrameworkCore.Storage; -using Microsoft.EntityFrameworkCore.Utilities; -using Newtonsoft.Json.Linq; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal { public partial class CosmosShapedQueryCompilingExpressionVisitor { - private sealed class CosmosProjectionBindingRemovingExpressionVisitor : ExpressionVisitor + private sealed class CosmosProjectionBindingRemovingExpressionVisitor : CosmosProjectionBindingRemovingExpressionVisitorBase { - private static readonly MethodInfo _getItemMethodInfo - = typeof(JObject).GetRuntimeProperties() - .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) - .GetMethod; - - private static readonly PropertyInfo _jTokenTypePropertyInfo - = typeof(JToken).GetRuntimeProperties() - .Single(mi => mi.Name == nameof(JToken.Type)); - - private static readonly MethodInfo _jTokenToObjectMethodInfo - = typeof(JToken).GetRuntimeMethods() - .Single(mi => mi.Name == nameof(JToken.ToObject) && mi.GetParameters().Length == 0); - - private static readonly MethodInfo _toObjectMethodInfo - = typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetRuntimeMethods() - .Single(mi => mi.Name == nameof(SafeToObject)); - - private static readonly MethodInfo _collectionAccessorAddMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)); - - private static readonly MethodInfo _collectionAccessorGetOrCreateMethodInfo - = typeof(IClrCollectionAccessor).GetTypeInfo() - .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)); - private readonly SelectExpression _selectExpression; - private readonly ParameterExpression _jObjectParameter; - private readonly bool _trackQueryResults; - - private readonly IDictionary _materializationContextBindings - = new Dictionary(); - - private readonly IDictionary _projectionBindings - = new Dictionary(); - - private readonly IDictionary _ownerMappings - = new Dictionary(); - - private readonly IDictionary _ordinalParameterBindings - = new Dictionary(); - - private List _pendingIncludes - = new List(); public CosmosProjectionBindingRemovingExpressionVisitor( [NotNull] SelectExpression selectExpression, [NotNull] ParameterExpression jObjectParameter, - bool trackQueryResults) + bool trackQueryResults) : base(jObjectParameter, trackQueryResults) { _selectExpression = selectExpression; - _jObjectParameter = jObjectParameter; - _trackQueryResults = trackQueryResults; - } - - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - - if (binaryExpression.NodeType == ExpressionType.Assign) - { - if (binaryExpression.Left is ParameterExpression parameterExpression) - { - if (parameterExpression.Type == typeof(JObject) - || parameterExpression.Type == typeof(JArray)) - { - string storeName = null; - - // Values injected by JObjectInjectingExpressionVisitor - var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; - if (projectionExpression is ProjectionBindingExpression projectionBindingExpression) - { - var projection = GetProjection(projectionBindingExpression); - projectionExpression = projection.Expression; - storeName = projection.Alias; - } - else if (projectionExpression is UnaryExpression convertExpression - && convertExpression.NodeType == ExpressionType.Convert) - { - // Unwrap EntityProjectionExpression when the root entity is not projected - projectionExpression = ((UnaryExpression)convertExpression.Operand).Operand; - } - - Expression innerAccessExpression; - if (projectionExpression is ObjectArrayProjectionExpression objectArrayProjectionExpression) - { - innerAccessExpression = objectArrayProjectionExpression.AccessExpression; - _projectionBindings[objectArrayProjectionExpression] = parameterExpression; - storeName ??= objectArrayProjectionExpression.Name; - } - else - { - var entityProjectionExpression = (EntityProjectionExpression)projectionExpression; - var accessExpression = entityProjectionExpression.AccessExpression; - _projectionBindings[accessExpression] = parameterExpression; - storeName ??= entityProjectionExpression.Name; - - switch (accessExpression) - { - case ObjectAccessExpression innerObjectAccessExpression: - innerAccessExpression = innerObjectAccessExpression.AccessExpression; - _ownerMappings[accessExpression] = - (innerObjectAccessExpression.Navigation.DeclaringEntityType, innerAccessExpression); - break; - case RootReferenceExpression _: - innerAccessExpression = _jObjectParameter; - break; - default: - throw new InvalidOperationException( - CoreStrings.QueryFailed(binaryExpression.Print(), GetType().Name)); - } - } - - var valueExpression = CreateGetValueExpression(innerAccessExpression, storeName, parameterExpression.Type); - - return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, valueExpression); - } - - if (parameterExpression.Type == typeof(MaterializationContext)) - { - var newExpression = (NewExpression)binaryExpression.Right; - - EntityProjectionExpression entityProjectionExpression; - if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) - { - var projection = GetProjection(projectionBindingExpression); - entityProjectionExpression = (EntityProjectionExpression)projection.Expression; - } - else - { - var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; - entityProjectionExpression = (EntityProjectionExpression)projection; - } - - _materializationContextBindings[parameterExpression] = entityProjectionExpression.AccessExpression; - - var updatedExpression = Expression.New( - newExpression.Constructor, - Expression.Constant(ValueBuffer.Empty), - newExpression.Arguments[1]); - - return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); - } - } - - if (binaryExpression.Left is MemberExpression memberExpression - && memberExpression.Member is FieldInfo fieldInfo - && fieldInfo.IsInitOnly) - { - return memberExpression.Assign(Visit(binaryExpression.Right)); - } - } - - return base.VisitBinary(binaryExpression); - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - var method = methodCallExpression.Method; - var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; - if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) - { - var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value; - Expression innerExpression; - if (methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) - { - var projection = GetProjection(projectionBindingExpression); - - innerExpression = Expression.Convert( - CreateReadJTokenExpression(_jObjectParameter, projection.Alias), - typeof(JObject)); - } - else - { - innerExpression = _materializationContextBindings[ - (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; - } - - return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); - } - - if (method.DeclaringType == typeof(Enumerable) - && method.Name == nameof(Enumerable.Select) - && genericMethod == EnumerableMethods.Select) - { - var lambda = (LambdaExpression)methodCallExpression.Arguments[1]; - if (lambda.Body is IncludeExpression includeExpression) - { - if (includeExpression.Navigation.IsOnDependent - || includeExpression.Navigation.ForeignKey.DeclaringEntityType.IsDocumentRoot()) - { - throw new InvalidOperationException(CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Print())); - } - - _pendingIncludes.Add(includeExpression); - - Visit(includeExpression.EntityExpression); - - // Includes on collections are processed when visiting CollectionShaperExpression - return Visit(methodCallExpression.Arguments[0]); - } - } - - return base.VisitMethodCall(methodCallExpression); - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - switch (extensionExpression) - { - case ProjectionBindingExpression projectionBindingExpression: - { - var projection = GetProjection(projectionBindingExpression); - - return CreateGetValueExpression( - _jObjectParameter, - projection.Alias, - projectionBindingExpression.Type, (projection.Expression as SqlExpression)?.TypeMapping); - } - - case CollectionShaperExpression collectionShaperExpression: - { - ObjectArrayProjectionExpression objectArrayProjection; - switch (collectionShaperExpression.Projection) - { - case ProjectionBindingExpression projectionBindingExpression: - var projection = GetProjection(projectionBindingExpression); - objectArrayProjection = (ObjectArrayProjectionExpression)projection.Expression; - break; - case ObjectArrayProjectionExpression objectArrayProjectionExpression: - objectArrayProjection = objectArrayProjectionExpression; - break; - default: - throw new InvalidOperationException(CoreStrings.QueryFailed(extensionExpression.Print(), GetType().Name)); - } - - var jArray = _projectionBindings[objectArrayProjection]; - var jObjectParameter = Expression.Parameter(typeof(JObject), jArray.Name + "Object"); - var ordinalParameter = Expression.Parameter(typeof(int), jArray.Name + "Ordinal"); - - var accessExpression = objectArrayProjection.InnerProjection.AccessExpression; - _projectionBindings[accessExpression] = jObjectParameter; - _ownerMappings[accessExpression] = ( - objectArrayProjection.Navigation.DeclaringEntityType, objectArrayProjection.AccessExpression); - _ordinalParameterBindings[accessExpression] = ordinalParameter; - - var innerShaper = (BlockExpression)Visit(collectionShaperExpression.InnerShaper); - - innerShaper = AddIncludes(innerShaper); - - var entities = Expression.Call( - EnumerableMethods.SelectWithOrdinal.MakeGenericMethod(typeof(JObject), innerShaper.Type), - Expression.Call( - EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), - jArray), - Expression.Lambda(innerShaper, jObjectParameter, ordinalParameter)); - - var navigation = collectionShaperExpression.Navigation; - return Expression.Call( - _populateCollectionMethodInfo.MakeGenericMethod(navigation.TargetEntityType.ClrType, navigation.ClrType), - Expression.Constant(navigation.GetCollectionAccessor()), - entities); - } - - case IncludeExpression includeExpression: - { - if (includeExpression.Navigation.IsOnDependent - || includeExpression.Navigation.ForeignKey.DeclaringEntityType.IsDocumentRoot()) - { - throw new InvalidOperationException(CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Print())); - } - - var isFirstInclude = _pendingIncludes.Count == 0; - _pendingIncludes.Add(includeExpression); - - var jObjectBlock = Visit(includeExpression.EntityExpression) as BlockExpression; - - if (!isFirstInclude) - { - return jObjectBlock; - } - - Check.DebugAssert(jObjectBlock != null, "The first include must end up on a valid shaper block"); - - // These are the expressions added by JObjectInjectingExpressionVisitor - var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[jObjectBlock.Expressions.Count - 1]; - - var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; - shaperBlock = AddIncludes(shaperBlock); - - var jObjectExpressions = new List(jObjectBlock.Expressions); - jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); - - jObjectExpressions.Add( - jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); - - return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); - } - } - - return base.VisitExtension(extensionExpression); - } - - private BlockExpression AddIncludes(BlockExpression shaperBlock) - { - if (_pendingIncludes.Count == 0) - { - return shaperBlock; - } - - var shaperExpressions = new List(shaperBlock.Expressions); - var instanceVariable = shaperExpressions[shaperExpressions.Count - 1]; - shaperExpressions.RemoveAt(shaperExpressions.Count - 1); - - var includesToProcess = _pendingIncludes; - _pendingIncludes = new List(); - - foreach (var include in includesToProcess) - { - AddInclude(shaperExpressions, include, shaperBlock, instanceVariable); - } - - shaperExpressions.Add(instanceVariable); - shaperBlock = shaperBlock.Update(shaperBlock.Variables, shaperExpressions); - return shaperBlock; - } - - private void AddInclude( - List shaperExpressions, - IncludeExpression includeExpression, - BlockExpression shaperBlock, - Expression instanceVariable) - { - var navigation = includeExpression.Navigation; - var includeMethod = navigation.IsCollection ? _includeCollectionMethodInfo : _includeReferenceMethodInfo; - var includingClrType = navigation.DeclaringEntityType.ClrType; - var relatedEntityClrType = navigation.TargetEntityType.ClrType; -#pragma warning disable EF1001 // Internal EF Core API usage. - // #16707 - var entityEntryVariable = _trackQueryResults - ? shaperBlock.Variables.Single(v => v.Type == typeof(InternalEntityEntry)) - : (Expression)Expression.Constant(null, typeof(InternalEntityEntry)); -#pragma warning restore EF1001 // Internal EF Core API usage. - var concreteEntityTypeVariable = shaperBlock.Variables.Single(v => v.Type == typeof(IEntityType)); - var inverseNavigation = navigation.Inverse; - var fixup = GenerateFixup( - includingClrType, relatedEntityClrType, navigation, inverseNavigation); - var initialize = GenerateInitialize(includingClrType, navigation); - - var navigationExpression = Visit(includeExpression.NavigationExpression); - - shaperExpressions.Add( - Expression.Call( - includeMethod.MakeGenericMethod(includingClrType, relatedEntityClrType), - entityEntryVariable, - instanceVariable, - concreteEntityTypeVariable, - navigationExpression, - Expression.Constant(navigation), - Expression.Constant(inverseNavigation, typeof(INavigation)), - Expression.Constant(fixup), - Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType)))); - } - - private static readonly MethodInfo _includeReferenceMethodInfo - = typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeReference)); - - private static void IncludeReference( -#pragma warning disable EF1001 // Internal EF Core API usage. - // #16707 - InternalEntityEntry entry, -#pragma warning restore EF1001 // Internal EF Core API usage. - object entity, - IEntityType entityType, - TIncludedEntity relatedEntity, - INavigation navigation, - INavigation inverseNavigation, - Action fixup, - Action _) - { - if (entity == null - || !navigation.DeclaringEntityType.IsAssignableFrom(entityType)) - { - return; - } - - if (entry == null) - { - var includingEntity = (TIncludingEntity)entity; - SetIsLoadedNoTracking(includingEntity, navigation); - if (relatedEntity != null) - { - fixup(includingEntity, relatedEntity); - if (inverseNavigation != null - && !inverseNavigation.IsCollection) - { - SetIsLoadedNoTracking(relatedEntity, inverseNavigation); - } - } - } - // For non-null relatedEntity StateManager will set the flag - else if (relatedEntity == null) - { -#pragma warning disable EF1001 // Internal EF Core API usage. - // #16707 - entry.SetIsLoaded(navigation); -#pragma warning restore EF1001 // Internal EF Core API usage. - } - } - - private static readonly MethodInfo _includeCollectionMethodInfo - = typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetTypeInfo() - .GetDeclaredMethod(nameof(IncludeCollection)); - - private static void IncludeCollection( -#pragma warning disable EF1001 // Internal EF Core API usage. - // #16707 - InternalEntityEntry entry, -#pragma warning restore EF1001 // Internal EF Core API usage. - object entity, - IEntityType entityType, - IEnumerable relatedEntities, - INavigation navigation, - INavigation inverseNavigation, - Action fixup, - Action initialize) - { - if (entity == null - || !navigation.DeclaringEntityType.IsAssignableFrom(entityType)) - { - return; - } - - if (entry == null) - { - var includingEntity = (TIncludingEntity)entity; - SetIsLoadedNoTracking(includingEntity, navigation); - - if (relatedEntities != null) - { - foreach (var relatedEntity in relatedEntities) - { - fixup(includingEntity, relatedEntity); - if (inverseNavigation != null) - { - SetIsLoadedNoTracking(relatedEntity, inverseNavigation); - } - } - } - else - { - initialize(includingEntity); - } - } - else - { -#pragma warning disable EF1001 // Internal EF Core API usage. - // #16707 - entry.SetIsLoaded(navigation); -#pragma warning restore EF1001 // Internal EF Core API usage. - if (relatedEntities != null) - { - using var enumerator = relatedEntities.GetEnumerator(); - while (enumerator.MoveNext()) - { - } - } - else - { - initialize((TIncludingEntity)entity); - } - } - } - - private static void SetIsLoadedNoTracking(object entity, INavigation navigation) - => ((ILazyLoader)(navigation - .DeclaringEntityType - .GetServiceProperties() - .FirstOrDefault(p => p.ClrType == typeof(ILazyLoader))) - ?.GetGetter().GetClrValue(entity)) - ?.SetLoaded(entity, navigation.Name); - - private static Delegate GenerateFixup( - Type entityType, - Type relatedEntityType, - INavigation navigation, - INavigation inverseNavigation) - { - var entityParameter = Expression.Parameter(entityType); - var relatedEntityParameter = Expression.Parameter(relatedEntityType); - var expressions = new List - { - navigation.IsCollection - ? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) - : AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation) - }; - - if (inverseNavigation != null) - { - expressions.Add( - inverseNavigation.IsCollection - ? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) - : AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); - } - - return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter) - .Compile(); - } - - private static Delegate GenerateInitialize( - Type entityType, - INavigation navigation) - { - if (!navigation.IsCollection) - { - return null; - } - - var entityParameter = Expression.Parameter(entityType); - - var getOrCreateExpression = Expression.Call( - Expression.Constant(navigation.GetCollectionAccessor()), - _collectionAccessorGetOrCreateMethodInfo, - entityParameter, - Expression.Constant(true)); - - return Expression.Lambda(Expression.Block(typeof(void), getOrCreateExpression), entityParameter) - .Compile(); - } - - private static Expression AssignReferenceNavigation( - ParameterExpression entity, - ParameterExpression relatedEntity, - INavigation navigation) - => entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); - - private static Expression AddToCollectionNavigation( - ParameterExpression entity, - ParameterExpression relatedEntity, - INavigation navigation) - => Expression.Call( - Expression.Constant(navigation.GetCollectionAccessor()), - _collectionAccessorAddMethodInfo, - entity, - relatedEntity, - Expression.Constant(true)); - - private static readonly MethodInfo _populateCollectionMethodInfo - = typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetTypeInfo() - .GetDeclaredMethod(nameof(PopulateCollection)); - - private static TCollection PopulateCollection( - IClrCollectionAccessor accessor, - IEnumerable entities) - { - // TODO: throw a better exception for non ICollection navigations - var collection = (ICollection)accessor.Create(); - foreach (var entity in entities) - { - collection.Add(entity); - } - - return (TCollection)collection; } + + protected override ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression) + => _selectExpression.Projection[GetProjectionIndex(projectionBindingExpression)]; private int GetProjectionIndex(ProjectionBindingExpression projectionBindingExpression) => projectionBindingExpression.ProjectionMember != null ? (int)((ConstantExpression)_selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember)).Value : projectionBindingExpression.Index ?? throw new InvalidOperationException(CoreStrings.QueryFailed(projectionBindingExpression.Print(), GetType().Name)); - - private ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression) - { - var index = GetProjectionIndex(projectionBindingExpression); - return _selectExpression.Projection[index]; - } - - private static Expression CreateReadJTokenExpression(Expression jObjectExpression, string propertyName) - => Expression.Call(jObjectExpression, _getItemMethodInfo, Expression.Constant(propertyName)); - - private Expression CreateGetValueExpression( - Expression jObjectExpression, - IProperty property, - Type clrType) - { - if (property.Name == StoreKeyConvention.JObjectPropertyName) - { - return _projectionBindings[jObjectExpression]; - } - - var storeName = property.GetJsonPropertyName(); - if (storeName.Length == 0) - { - var entityType = property.DeclaringEntityType; - if (!entityType.IsDocumentRoot()) - { - var ownership = entityType.FindOwnership(); - if (!ownership.IsUnique - && property.IsOrdinalKeyProperty()) - { - Expression readExpression = _ordinalParameterBindings[jObjectExpression]; - if (readExpression.Type != clrType) - { - readExpression = Expression.Convert(readExpression, clrType); - } - - return readExpression; - } - - var principalProperty = property.FindFirstPrincipal(); - if (principalProperty != null) - { - Expression ownerJObjectExpression = null; - if (_ownerMappings.TryGetValue(jObjectExpression, out var ownerInfo)) - { - Check.DebugAssert( - principalProperty.DeclaringEntityType.IsAssignableFrom(ownerInfo.EntityType), - $"{principalProperty.DeclaringEntityType} is not assignable from {ownerInfo.EntityType}"); - - ownerJObjectExpression = ownerInfo.JObjectExpression; - } - else if (jObjectExpression is RootReferenceExpression rootReferenceExpression) - { - ownerJObjectExpression = rootReferenceExpression; - } - else if (jObjectExpression is ObjectAccessExpression objectAccessExpression) - { - ownerJObjectExpression = objectAccessExpression.AccessExpression; - } - - if (ownerJObjectExpression != null) - { - return CreateGetValueExpression(ownerJObjectExpression, principalProperty, clrType); - } - } - } - - return Expression.Default(clrType); - } - - return CreateGetValueExpression(jObjectExpression, storeName, clrType, property.GetTypeMapping()); - } - - private Expression CreateGetValueExpression( - Expression jObjectExpression, - string storeName, - Type clrType, - CoreTypeMapping typeMapping = null) - { - var innerExpression = jObjectExpression; - if (_projectionBindings.TryGetValue(jObjectExpression, out var innerVariable)) - { - innerExpression = innerVariable; - } - else if (jObjectExpression is RootReferenceExpression rootReferenceExpression) - { - innerExpression = CreateGetValueExpression( - _jObjectParameter, rootReferenceExpression.Alias, typeof(JObject)); - } - else if (jObjectExpression is ObjectAccessExpression objectAccessExpression) - { - var innerAccessExpression = objectAccessExpression.AccessExpression; - - innerExpression = CreateGetValueExpression( - innerAccessExpression, ((IAccessExpression)innerAccessExpression).Name, typeof(JObject)); - } - - var jTokenExpression = CreateReadJTokenExpression(innerExpression, storeName); - - Expression valueExpression; - var converter = typeMapping?.Converter; - if (converter != null) - { - var jTokenParameter = Expression.Parameter(typeof(JToken)); - - var body - = ReplacingExpressionVisitor.Replace( - converter.ConvertFromProviderExpression.Parameters.Single(), - Expression.Call( - jTokenParameter, - _jTokenToObjectMethodInfo.MakeGenericMethod(converter.ProviderClrType)), - converter.ConvertFromProviderExpression.Body); - - if (body.Type != clrType) - { - body = Expression.Convert(body, clrType); - } - - body = Expression.Condition( - Expression.OrElse( - Expression.Equal(jTokenParameter, Expression.Default(typeof(JToken))), - Expression.Equal( - Expression.MakeMemberAccess(jTokenParameter, _jTokenTypePropertyInfo), - Expression.Constant(JTokenType.Null))), - Expression.Default(clrType), - body); - - valueExpression = Expression.Invoke(Expression.Lambda(body, jTokenParameter), jTokenExpression); - } - else - { - valueExpression = ConvertJTokenToType(jTokenExpression, typeMapping?.ClrType.MakeNullable() ?? clrType); - - if (valueExpression.Type != clrType) - { - valueExpression = Expression.Convert(valueExpression, clrType); - } - } - - return valueExpression; - } - - private static Expression ConvertJTokenToType(Expression jTokenExpression, Type type) - => type == typeof(JToken) - ? jTokenExpression - : Expression.Call( - _toObjectMethodInfo.MakeGenericMethod(type), - jTokenExpression); - - private static T SafeToObject(JToken token) - => token == null || token.Type == JTokenType.Null ? default : token.ToObject(); } } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs new file mode 100644 index 00000000000..242264f28fe --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -0,0 +1,728 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Utilities; +using Newtonsoft.Json.Linq; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + public partial class CosmosShapedQueryCompilingExpressionVisitor + { + private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase : ExpressionVisitor + { + private static readonly MethodInfo _getItemMethodInfo + = typeof(JObject).GetRuntimeProperties() + .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) + .GetMethod; + + private static readonly PropertyInfo _jTokenTypePropertyInfo + = typeof(JToken).GetRuntimeProperties() + .Single(mi => mi.Name == nameof(JToken.Type)); + + private static readonly MethodInfo _jTokenToObjectMethodInfo + = typeof(JToken).GetRuntimeMethods() + .Single(mi => mi.Name == nameof(JToken.ToObject) && mi.GetParameters().Length == 0); + + private static readonly MethodInfo _collectionAccessorAddMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.Add)); + + private static readonly MethodInfo _collectionAccessorGetOrCreateMethodInfo + = typeof(IClrCollectionAccessor).GetTypeInfo() + .GetDeclaredMethod(nameof(IClrCollectionAccessor.GetOrCreate)); + + private readonly ParameterExpression _jObjectParameter; + private readonly bool _trackQueryResults; + + private readonly IDictionary _materializationContextBindings + = new Dictionary(); + + private readonly IDictionary _projectionBindings + = new Dictionary(); + + private readonly IDictionary _ownerMappings + = new Dictionary(); + + private readonly IDictionary _ordinalParameterBindings + = new Dictionary(); + + private List _pendingIncludes + = new List(); + + private static readonly MethodInfo _toObjectMethodInfo + = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase) + .GetRuntimeMethods().Single(mi => mi.Name == nameof(SafeToObject)); + + public CosmosProjectionBindingRemovingExpressionVisitorBase( + [NotNull] ParameterExpression jObjectParameter, + bool trackQueryResults) + { + _jObjectParameter = jObjectParameter; + _trackQueryResults = trackQueryResults; + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + Check.NotNull(binaryExpression, nameof(binaryExpression)); + + if (binaryExpression.NodeType == ExpressionType.Assign) + { + if (binaryExpression.Left is ParameterExpression parameterExpression) + { + if (parameterExpression.Type == typeof(JObject) + || parameterExpression.Type == typeof(JArray)) + { + string storeName = null; + + // Values injected by JObjectInjectingExpressionVisitor + var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand; + if (projectionExpression is ProjectionBindingExpression projectionBindingExpression) + { + var projection = GetProjection(projectionBindingExpression); + projectionExpression = projection.Expression; + storeName = projection.Alias; + } + else if (projectionExpression is UnaryExpression convertExpression + && convertExpression.NodeType == ExpressionType.Convert) + { + // Unwrap EntityProjectionExpression when the root entity is not projected + projectionExpression = ((UnaryExpression)convertExpression.Operand).Operand; + } + + Expression innerAccessExpression; + if (projectionExpression is ObjectArrayProjectionExpression objectArrayProjectionExpression) + { + innerAccessExpression = objectArrayProjectionExpression.AccessExpression; + _projectionBindings[objectArrayProjectionExpression] = parameterExpression; + storeName ??= objectArrayProjectionExpression.Name; + } + else + { + var entityProjectionExpression = (EntityProjectionExpression)projectionExpression; + var accessExpression = entityProjectionExpression.AccessExpression; + _projectionBindings[accessExpression] = parameterExpression; + storeName ??= entityProjectionExpression.Name; + + switch (accessExpression) + { + case ObjectAccessExpression innerObjectAccessExpression: + innerAccessExpression = innerObjectAccessExpression.AccessExpression; + _ownerMappings[accessExpression] = + (innerObjectAccessExpression.Navigation.DeclaringEntityType, innerAccessExpression); + break; + case RootReferenceExpression _: + innerAccessExpression = _jObjectParameter; + break; + default: + throw new InvalidOperationException( + CoreStrings.QueryFailed(binaryExpression.Print(), GetType().Name)); + } + } + + var valueExpression = CreateGetValueExpression(innerAccessExpression, storeName, parameterExpression.Type); + + return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, valueExpression); + } + + if (parameterExpression.Type == typeof(MaterializationContext)) + { + var newExpression = (NewExpression)binaryExpression.Right; + + EntityProjectionExpression entityProjectionExpression; + if (newExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) + { + var projection = GetProjection(projectionBindingExpression); + entityProjectionExpression = (EntityProjectionExpression)projection.Expression; + } + else + { + var projection = ((UnaryExpression)((UnaryExpression)newExpression.Arguments[0]).Operand).Operand; + entityProjectionExpression = (EntityProjectionExpression)projection; + } + + _materializationContextBindings[parameterExpression] = entityProjectionExpression.AccessExpression; + + var updatedExpression = Expression.New( + newExpression.Constructor, + Expression.Constant(ValueBuffer.Empty), + newExpression.Arguments[1]); + + return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); + } + } + + if (binaryExpression.Left is MemberExpression memberExpression + && memberExpression.Member is FieldInfo fieldInfo + && fieldInfo.IsInitOnly) + { + return memberExpression.Assign(Visit(binaryExpression.Right)); + } + } + + return base.VisitBinary(binaryExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + var method = methodCallExpression.Method; + var genericMethod = method.IsGenericMethod ? method.GetGenericMethodDefinition() : null; + if (genericMethod == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) + { + var property = (IProperty)((ConstantExpression)methodCallExpression.Arguments[2]).Value; + Expression innerExpression; + if (methodCallExpression.Arguments[0] is ProjectionBindingExpression projectionBindingExpression) + { + var projection = GetProjection(projectionBindingExpression); + + innerExpression = Expression.Convert( + CreateReadJTokenExpression(_jObjectParameter, projection.Alias), + typeof(JObject)); + } + else + { + innerExpression = _materializationContextBindings[ + (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; + } + + return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); + } + + if (method.DeclaringType == typeof(Enumerable) + && method.Name == nameof(Enumerable.Select) + && genericMethod == EnumerableMethods.Select) + { + var lambda = (LambdaExpression)methodCallExpression.Arguments[1]; + if (lambda.Body is IncludeExpression includeExpression) + { + if (includeExpression.Navigation.IsOnDependent + || includeExpression.Navigation.ForeignKey.DeclaringEntityType.IsDocumentRoot()) + { + throw new InvalidOperationException(CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Print())); + } + + _pendingIncludes.Add(includeExpression); + + Visit(includeExpression.EntityExpression); + + // Includes on collections are processed when visiting CollectionShaperExpression + return Visit(methodCallExpression.Arguments[0]); + } + } + + return base.VisitMethodCall(methodCallExpression); + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + Check.NotNull(extensionExpression, nameof(extensionExpression)); + + switch (extensionExpression) + { + case ProjectionBindingExpression projectionBindingExpression: + { + var projection = GetProjection(projectionBindingExpression); + + return CreateGetValueExpression( + _jObjectParameter, + projection.Alias, + projectionBindingExpression.Type, (projection.Expression as SqlExpression)?.TypeMapping); + } + + case CollectionShaperExpression collectionShaperExpression: + { + ObjectArrayProjectionExpression objectArrayProjection; + switch (collectionShaperExpression.Projection) + { + case ProjectionBindingExpression projectionBindingExpression: + var projection = GetProjection(projectionBindingExpression); + objectArrayProjection = (ObjectArrayProjectionExpression)projection.Expression; + break; + case ObjectArrayProjectionExpression objectArrayProjectionExpression: + objectArrayProjection = objectArrayProjectionExpression; + break; + default: + throw new InvalidOperationException(CoreStrings.QueryFailed(extensionExpression.Print(), GetType().Name)); + } + + var jArray = _projectionBindings[objectArrayProjection]; + var jObjectParameter = Expression.Parameter(typeof(JObject), jArray.Name + "Object"); + var ordinalParameter = Expression.Parameter(typeof(int), jArray.Name + "Ordinal"); + + var accessExpression = objectArrayProjection.InnerProjection.AccessExpression; + _projectionBindings[accessExpression] = jObjectParameter; + _ownerMappings[accessExpression] = ( + objectArrayProjection.Navigation.DeclaringEntityType, objectArrayProjection.AccessExpression); + _ordinalParameterBindings[accessExpression] = ordinalParameter; + + var innerShaper = (BlockExpression)Visit(collectionShaperExpression.InnerShaper); + + innerShaper = AddIncludes(innerShaper); + + var entities = Expression.Call( + EnumerableMethods.SelectWithOrdinal.MakeGenericMethod(typeof(JObject), innerShaper.Type), + Expression.Call( + EnumerableMethods.Cast.MakeGenericMethod(typeof(JObject)), + jArray), + Expression.Lambda(innerShaper, jObjectParameter, ordinalParameter)); + + var navigation = collectionShaperExpression.Navigation; + return Expression.Call( + _populateCollectionMethodInfo.MakeGenericMethod(navigation.TargetEntityType.ClrType, navigation.ClrType), + Expression.Constant(navigation.GetCollectionAccessor()), + entities); + } + + case IncludeExpression includeExpression: + { + if (includeExpression.Navigation.IsOnDependent + || includeExpression.Navigation.ForeignKey.DeclaringEntityType.IsDocumentRoot()) + { + throw new InvalidOperationException(CosmosStrings.NonEmbeddedIncludeNotSupported(includeExpression.Print())); + } + + var isFirstInclude = _pendingIncludes.Count == 0; + _pendingIncludes.Add(includeExpression); + + var jObjectBlock = Visit(includeExpression.EntityExpression) as BlockExpression; + + if (!isFirstInclude) + { + return jObjectBlock; + } + + Check.DebugAssert(jObjectBlock != null, "The first include must end up on a valid shaper block"); + + // These are the expressions added by JObjectInjectingExpressionVisitor + var jObjectCondition = (ConditionalExpression)jObjectBlock.Expressions[jObjectBlock.Expressions.Count - 1]; + + var shaperBlock = (BlockExpression)jObjectCondition.IfFalse; + shaperBlock = AddIncludes(shaperBlock); + + var jObjectExpressions = new List(jObjectBlock.Expressions); + jObjectExpressions.RemoveAt(jObjectExpressions.Count - 1); + + jObjectExpressions.Add( + jObjectCondition.Update(jObjectCondition.Test, jObjectCondition.IfTrue, shaperBlock)); + + return jObjectBlock.Update(jObjectBlock.Variables, jObjectExpressions); + } + } + + return base.VisitExtension(extensionExpression); + } + + private BlockExpression AddIncludes(BlockExpression shaperBlock) + { + if (_pendingIncludes.Count == 0) + { + return shaperBlock; + } + + var shaperExpressions = new List(shaperBlock.Expressions); + var instanceVariable = shaperExpressions[shaperExpressions.Count - 1]; + shaperExpressions.RemoveAt(shaperExpressions.Count - 1); + + var includesToProcess = _pendingIncludes; + _pendingIncludes = new List(); + + foreach (var include in includesToProcess) + { + AddInclude(shaperExpressions, include, shaperBlock, instanceVariable); + } + + shaperExpressions.Add(instanceVariable); + shaperBlock = shaperBlock.Update(shaperBlock.Variables, shaperExpressions); + return shaperBlock; + } + + private void AddInclude( + List shaperExpressions, + IncludeExpression includeExpression, + BlockExpression shaperBlock, + Expression instanceVariable) + { + var navigation = includeExpression.Navigation; + var includeMethod = navigation.IsCollection ? _includeCollectionMethodInfo : _includeReferenceMethodInfo; + var includingClrType = navigation.DeclaringEntityType.ClrType; + var relatedEntityClrType = navigation.TargetEntityType.ClrType; + var entityEntryVariable = _trackQueryResults + ? shaperBlock.Variables.Single(v => v.Type == typeof(InternalEntityEntry)) + : (Expression)Expression.Constant(null, typeof(InternalEntityEntry)); + var concreteEntityTypeVariable = shaperBlock.Variables.Single(v => v.Type == typeof(IEntityType)); + var inverseNavigation = navigation.Inverse; + var fixup = GenerateFixup( + includingClrType, relatedEntityClrType, navigation, inverseNavigation); + var initialize = GenerateInitialize(includingClrType, navigation); + + var navigationExpression = Visit(includeExpression.NavigationExpression); + + shaperExpressions.Add( + Expression.Call( + includeMethod.MakeGenericMethod(includingClrType, relatedEntityClrType), + entityEntryVariable, + instanceVariable, + concreteEntityTypeVariable, + navigationExpression, + Expression.Constant(navigation), + Expression.Constant(inverseNavigation, typeof(INavigation)), + Expression.Constant(fixup), + Expression.Constant(initialize, typeof(Action<>).MakeGenericType(includingClrType)))); + } + + private static readonly MethodInfo _includeReferenceMethodInfo + = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() + .GetDeclaredMethod(nameof(IncludeReference)); + + private static void IncludeReference( + InternalEntityEntry entry, + object entity, + IEntityType entityType, + TIncludedEntity relatedEntity, + INavigation navigation, + INavigation inverseNavigation, + Action fixup, + Action _) + { + if (entity == null + || !navigation.DeclaringEntityType.IsAssignableFrom(entityType)) + { + return; + } + + if (entry == null) + { + var includingEntity = (TIncludingEntity)entity; + SetIsLoadedNoTracking(includingEntity, navigation); + if (relatedEntity != null) + { + fixup(includingEntity, relatedEntity); + if (inverseNavigation != null + && !inverseNavigation.IsCollection) + { + SetIsLoadedNoTracking(relatedEntity, inverseNavigation); + } + } + } + // For non-null relatedEntity StateManager will set the flag + else if (relatedEntity == null) + { +#pragma warning disable EF1001 + entry.SetIsLoaded(navigation); + } + } + + private static readonly MethodInfo _includeCollectionMethodInfo + = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() + .GetDeclaredMethod(nameof(IncludeCollection)); + + private static void IncludeCollection( + InternalEntityEntry entry, + object entity, + IEntityType entityType, + IEnumerable relatedEntities, + INavigation navigation, + INavigation inverseNavigation, + Action fixup, + Action initialize) + { + if (entity == null + || !navigation.DeclaringEntityType.IsAssignableFrom(entityType)) + { + return; + } + + if (entry == null) + { + var includingEntity = (TIncludingEntity)entity; + SetIsLoadedNoTracking(includingEntity, navigation); + + if (relatedEntities != null) + { + foreach (var relatedEntity in relatedEntities) + { + fixup(includingEntity, relatedEntity); + if (inverseNavigation != null) + { + SetIsLoadedNoTracking(relatedEntity, inverseNavigation); + } + } + } + else + { + initialize(includingEntity); + } + } + else + { + entry.SetIsLoaded(navigation); + if (relatedEntities != null) + { + using var enumerator = relatedEntities.GetEnumerator(); + while (enumerator.MoveNext()) + { + } + } + else + { + initialize((TIncludingEntity)entity); + } + } + } + + private static void SetIsLoadedNoTracking(object entity, INavigation navigation) + => ((ILazyLoader)(navigation + .DeclaringEntityType + .GetServiceProperties() + .FirstOrDefault(p => p.ClrType == typeof(ILazyLoader))) + ?.GetGetter().GetClrValue(entity)) + ?.SetLoaded(entity, navigation.Name); + + private static Delegate GenerateFixup( + Type entityType, + Type relatedEntityType, + INavigation navigation, + INavigation inverseNavigation) + { + var entityParameter = Expression.Parameter(entityType); + var relatedEntityParameter = Expression.Parameter(relatedEntityType); + var expressions = new List + { + navigation.IsCollection + ? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) + : AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation) + }; + + if (inverseNavigation != null) + { + expressions.Add( + inverseNavigation.IsCollection + ? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) + : AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); + } + + return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter) + .Compile(); + } + + private static Delegate GenerateInitialize( + Type entityType, + INavigation navigation) + { + if (!navigation.IsCollection) + { + return null; + } + + var entityParameter = Expression.Parameter(entityType); + + var getOrCreateExpression = Expression.Call( + Expression.Constant(navigation.GetCollectionAccessor()), + _collectionAccessorGetOrCreateMethodInfo, + entityParameter, + Expression.Constant(true)); + + return Expression.Lambda(Expression.Block(typeof(void), getOrCreateExpression), entityParameter) + .Compile(); + } + + private static Expression AssignReferenceNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigation navigation) + => entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); + + private static Expression AddToCollectionNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigation navigation) + => Expression.Call( + Expression.Constant(navigation.GetCollectionAccessor()), + _collectionAccessorAddMethodInfo, + entity, + relatedEntity, + Expression.Constant(true)); + + private static readonly MethodInfo _populateCollectionMethodInfo + = typeof(CosmosProjectionBindingRemovingExpressionVisitorBase).GetTypeInfo() + .GetDeclaredMethod(nameof(PopulateCollection)); + + private static TCollection PopulateCollection( + IClrCollectionAccessor accessor, + IEnumerable entities) + { + // TODO: throw a better exception for non ICollection navigations + var collection = (ICollection)accessor.Create(); + foreach (var entity in entities) + { + collection.Add(entity); + } + + return (TCollection)collection; + } + + protected abstract ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression); + + private static Expression CreateReadJTokenExpression(Expression jObjectExpression, string propertyName) + => Expression.Call(jObjectExpression, _getItemMethodInfo, Expression.Constant(propertyName)); + + private Expression CreateGetValueExpression( + Expression jObjectExpression, + IProperty property, + Type clrType) + { + if (property.Name == StoreKeyConvention.JObjectPropertyName) + { + return _projectionBindings[jObjectExpression]; + } + + var storeName = property.GetJsonPropertyName(); + if (storeName.Length == 0) + { + var entityType = property.DeclaringEntityType; + if (!entityType.IsDocumentRoot()) + { + var ownership = entityType.FindOwnership(); + if (!ownership.IsUnique + && property.IsOrdinalKeyProperty()) + { + Expression readExpression = _ordinalParameterBindings[jObjectExpression]; + if (readExpression.Type != clrType) + { + readExpression = Expression.Convert(readExpression, clrType); + } + + return readExpression; + } + + var principalProperty = property.FindFirstPrincipal(); + if (principalProperty != null) + { + Expression ownerJObjectExpression = null; + if (_ownerMappings.TryGetValue(jObjectExpression, out var ownerInfo)) + { + Check.DebugAssert( + principalProperty.DeclaringEntityType.IsAssignableFrom(ownerInfo.EntityType), + $"{principalProperty.DeclaringEntityType} is not assignable from {ownerInfo.EntityType}"); + + ownerJObjectExpression = ownerInfo.JObjectExpression; + } + else if (jObjectExpression is RootReferenceExpression rootReferenceExpression) + { + ownerJObjectExpression = rootReferenceExpression; + } + else if (jObjectExpression is ObjectAccessExpression objectAccessExpression) + { + ownerJObjectExpression = objectAccessExpression.AccessExpression; + } + + if (ownerJObjectExpression != null) + { + return CreateGetValueExpression(ownerJObjectExpression, principalProperty, clrType); + } + } + } + + return Expression.Default(clrType); + } + + return CreateGetValueExpression(jObjectExpression, storeName, clrType, property.GetTypeMapping()); + } + + private Expression CreateGetValueExpression( + Expression jObjectExpression, + string storeName, + Type clrType, + CoreTypeMapping typeMapping = null) + { + var innerExpression = jObjectExpression; + if (_projectionBindings.TryGetValue(jObjectExpression, out var innerVariable)) + { + innerExpression = innerVariable; + } + else if (jObjectExpression is RootReferenceExpression rootReferenceExpression) + { + innerExpression = CreateGetValueExpression( + _jObjectParameter, rootReferenceExpression.Alias, typeof(JObject)); + } + else if (jObjectExpression is ObjectAccessExpression objectAccessExpression) + { + var innerAccessExpression = objectAccessExpression.AccessExpression; + + innerExpression = CreateGetValueExpression( + innerAccessExpression, ((IAccessExpression)innerAccessExpression).Name, typeof(JObject)); + } + + var jTokenExpression = CreateReadJTokenExpression(innerExpression, storeName); + + Expression valueExpression; + var converter = typeMapping?.Converter; + if (converter != null) + { + var jTokenParameter = Expression.Parameter(typeof(JToken)); + + var body + = ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + Expression.Call( + jTokenParameter, + _jTokenToObjectMethodInfo.MakeGenericMethod(converter.ProviderClrType)), + converter.ConvertFromProviderExpression.Body); + + if (body.Type != clrType) + { + body = Expression.Convert(body, clrType); + } + + body = Expression.Condition( + Expression.OrElse( + Expression.Equal(jTokenParameter, Expression.Default(typeof(JToken))), + Expression.Equal( + Expression.MakeMemberAccess(jTokenParameter, _jTokenTypePropertyInfo), + Expression.Constant(JTokenType.Null))), + Expression.Default(clrType), + body); + + valueExpression = Expression.Invoke(Expression.Lambda(body, jTokenParameter), jTokenExpression); + } + else + { + valueExpression = ConvertJTokenToType(jTokenExpression, typeMapping?.ClrType.MakeNullable() ?? clrType); + + if (valueExpression.Type != clrType) + { + valueExpression = Expression.Convert(valueExpression, clrType); + } + } + + return valueExpression; + } + + private Expression ConvertJTokenToType(Expression jTokenExpression, Type type) + => type == typeof(JToken) + ? jTokenExpression + : Expression.Call( + _toObjectMethodInfo.MakeGenericMethod(type), + jTokenExpression); + + private static T SafeToObject(JToken token) + => token == null || token.Type == JTokenType.Null ? default : token.ToObject(); + } + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingReadItemExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingReadItemExpressionVisitor.cs new file mode 100644 index 00000000000..3c95dd9364a --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingReadItemExpressionVisitor.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Query; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + public partial class CosmosShapedQueryCompilingExpressionVisitor + { + private sealed class CosmosProjectionBindingRemovingReadItemExpressionVisitor : CosmosProjectionBindingRemovingExpressionVisitorBase + { + private readonly ReadItemExpression _readItemExpression; + + public CosmosProjectionBindingRemovingReadItemExpressionVisitor( + [NotNull] ReadItemExpression readItemExpression, + [NotNull] ParameterExpression jObjectParameter, + bool trackQueryResults) + : base(jObjectParameter, trackQueryResults) + { + _readItemExpression = readItemExpression; + } + + protected override ProjectionExpression GetProjection(ProjectionBindingExpression _) + => _readItemExpression.ProjectionExpression; + } + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs index 81db2208fc9..9cb8dee976e 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs @@ -27,7 +27,7 @@ private sealed class QueryingEnumerable : IEnumerable, IAsyncEnumerable private readonly CosmosQueryContext _cosmosQueryContext; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly SelectExpression _selectExpression; - private readonly Func _shaper; + private readonly Func _shaper; private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory; private readonly Type _contextType; private readonly string _partitionKey; @@ -38,9 +38,9 @@ public QueryingEnumerable( ISqlExpressionFactory sqlExpressionFactory, IQuerySqlGeneratorFactory querySqlGeneratorFactory, SelectExpression selectExpression, - Func shaper, + Func shaper, Type contextType, - string partitionKey, + string partitionKeyFromExtension, IDiagnosticsLogger logger) { _cosmosQueryContext = cosmosQueryContext; @@ -49,8 +49,8 @@ public QueryingEnumerable( _selectExpression = selectExpression; _shaper = shaper; _contextType = contextType; - _partitionKey = partitionKey; _logger = logger; + _partitionKey = partitionKeyFromExtension; } public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) @@ -153,7 +153,7 @@ private sealed class AsyncEnumerator : IAsyncEnumerator private IAsyncEnumerator _enumerator; private readonly CosmosQueryContext _cosmosQueryContext; private readonly SelectExpression _selectExpression; - private readonly Func _shaper; + private readonly Func _shaper; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory; private readonly Type _contextType; @@ -192,7 +192,8 @@ public async ValueTask MoveNextAsync() _selectExpression.Container, _partitionKey, _querySqlGeneratorFactory.Create().GetSqlQuery( - selectExpression, _cosmosQueryContext.ParameterValues)) + selectExpression, + _cosmosQueryContext.ParameterValues)) .GetAsyncEnumerator(_cancellationToken); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs new file mode 100644 index 00000000000..0cf5922aa10 --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs @@ -0,0 +1,323 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.Query; +using Newtonsoft.Json.Linq; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public partial class CosmosShapedQueryCompilingExpressionVisitor + { + private sealed class ReadItemQueryingEnumerable : IEnumerable, IAsyncEnumerable, IQueryingEnumerable + { + private readonly CosmosQueryContext _cosmosQueryContext; + private readonly ReadItemExpression _readItemExpression; + private readonly Func _shaper; + private readonly Type _contextType; + private readonly IDiagnosticsLogger _logger; + + public ReadItemQueryingEnumerable( + CosmosQueryContext cosmosQueryContext, + ReadItemExpression readItemExpression, + Func shaper, + Type contextType, + IDiagnosticsLogger logger) + { + _cosmosQueryContext = cosmosQueryContext; + _readItemExpression = readItemExpression; + _shaper = shaper; + _contextType = contextType; + _logger = logger; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + => new AsyncEnumerator(this, cancellationToken); + + public IEnumerator GetEnumerator() => new Enumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public string ToQueryString() + { + throw new NotImplementedException("Cosmos: ToQueryString for ReadItemQueryingEnumerable #20653"); + } + + private sealed class Enumerator : ReadItemBase, IEnumerator + { + private JObject _item; + private bool _hasExecuted; + + public Enumerator(ReadItemQueryingEnumerable readItemEnumerable) : base(readItemEnumerable) + { + } + + object IEnumerator.Current => Current; + + public bool MoveNext() + { + try + { + using (CosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) + { + if (!_hasExecuted) + { + if (!TryGetResourceId(out var resourceId)) + { + throw new InvalidOperationException(CosmosStrings.ResourceIdMissing); + } + + if (!TryGetPartitionId(out var partitionKey)) + { + throw new InvalidOperationException(CosmosStrings.ParitionKeyMissing); + } + + _item = CosmosClient.ExecuteReadItem( + ContainerId, + partitionKey, + resourceId); + + var hasNext = !(_item is null); + + Current + = hasNext + ? Shaper(CosmosQueryContext, _item) + : default; + + _hasExecuted = true; + + return hasNext; + } + + return false; + } + } + catch (Exception exception) + { + Logger.QueryIterationFailed(ContextType, exception); + + throw; + } + } + + public void Dispose() + { + _item = null; + _hasExecuted = false; + } + + public void Reset() => throw new NotImplementedException(); + } + + private sealed class AsyncEnumerator : ReadItemBase, IAsyncEnumerator + { + private JObject _item; + private readonly CancellationToken _cancellationToken; + private bool _hasExecuted; + + public AsyncEnumerator( + ReadItemQueryingEnumerable readItemEnumerable, + CancellationToken cancellationToken) : base(readItemEnumerable) + { + _cancellationToken = cancellationToken; + } + + public async ValueTask MoveNextAsync() + { + try + { + using (CosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) + { + if (!_hasExecuted) + { + + if (!TryGetResourceId(out var resourceId)) + { + throw new InvalidOperationException(CosmosStrings.ResourceIdMissing); + } + + if (!TryGetPartitionId(out var partitionKey)) + { + throw new InvalidOperationException(CosmosStrings.ParitionKeyMissing); + } + + _item = await CosmosClient.ExecuteReadItemAsync( + ContainerId, + partitionKey, + resourceId, + _cancellationToken); + + var hasNext = !(_item is null); + + Current + = hasNext + ? Shaper(CosmosQueryContext, _item) + : default; + + _hasExecuted = true; + + return hasNext; + } + + return false; + } + } + catch (Exception exception) + { + Logger.QueryIterationFailed(ContextType, exception); + + throw; + } + } + + public ValueTask DisposeAsync() + { + _item = null; + _hasExecuted = false; + return default; + } + } + + private abstract class ReadItemBase + { + private readonly IStateManager _stateManager; + private readonly ReadItemExpression _readItemExpression; + private readonly IEntityType _entityType; + + protected readonly CosmosQueryContext CosmosQueryContext; + protected readonly CosmosClientWrapper CosmosClient; + protected readonly string ContainerId; + protected readonly Func Shaper; + protected readonly Type ContextType; + protected readonly IDiagnosticsLogger Logger; + + public T Current { get; protected set; } + + protected ReadItemBase( + ReadItemQueryingEnumerable readItemEnumerable) + { +#pragma warning disable EF1001 + _stateManager = readItemEnumerable._cosmosQueryContext.StateManager; + CosmosQueryContext = readItemEnumerable._cosmosQueryContext; + _readItemExpression = readItemEnumerable._readItemExpression; + _entityType = readItemEnumerable._readItemExpression.EntityType; + CosmosClient = readItemEnumerable._cosmosQueryContext.CosmosClient; + ContainerId = _readItemExpression.Container; + Shaper = readItemEnumerable._shaper; + ContextType = readItemEnumerable._contextType; + Logger = readItemEnumerable._logger; + } + + protected bool TryGetPartitionId(out string partitionKey) + { + partitionKey = null; + + var partitionKeyProperty = _entityType.FindProperty(_entityType.GetPartitionKeyPropertyName()); + + if (TryGetParameterValue(partitionKeyProperty, out var value)) + { + partitionKey = GetString(partitionKeyProperty, value); + + return !string.IsNullOrEmpty(partitionKey); + } + + return false; + } + + protected bool TryGetResourceId(out string resourceId) + { + var resourceIdProperty = _entityType.GetProperties() + .FirstOrDefault(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyName); + + if (TryGetParameterValue(resourceIdProperty, out var value)) + { + resourceId = GetString(resourceIdProperty, value); + + if (string.IsNullOrEmpty(resourceId)) + { + throw new InvalidOperationException(CosmosStrings.InvalidResourceId); + } + + return true; + } + + if (TryGenerateResourceIdFromKeys(out var generatedValue)) + { + resourceId = GetString(resourceIdProperty, generatedValue); + + return true; + } + + resourceId = null; + return false; + } + + private bool TryGenerateResourceIdFromKeys(out object value) + { + var entityEntry = Activator.CreateInstance(_entityType.ClrType); + + var entityProperties = entityEntry.GetType().GetProperties(); + +#pragma warning disable EF1001 + var internalEntityEntry = new InternalEntityEntryFactory().Create(_stateManager, _entityType, entityEntry); + + foreach (var entityProperty in entityProperties) + { + var property = _entityType.FindProperty(entityProperty.Name); + + if (TryGetParameterValue(property, out var parameterValue)) + { + internalEntityEntry[property] = parameterValue; + } + } + +#pragma warning disable EF1001 + var entry = new EntityEntry(internalEntityEntry) { State = EntityState.Added }; + + value = entry.Properties + .FirstOrDefault( + propertyEntry => propertyEntry.Metadata.GetJsonPropertyName() == StoreKeyConvention.IdPropertyName) + .CurrentValue; + + entry.State = EntityState.Detached; + + return !(value is null); + } + + private bool TryGetParameterValue(IProperty property, out object value) + { + value = null; + return _readItemExpression.PropertyParameters.TryGetValue(property, out var parameterName) + && CosmosQueryContext.ParameterValues.TryGetValue(parameterName, out value); + } + + private static string GetString(IProperty property, object value) + { + var converter = property.GetTypeMapping().Converter; + + return converter is null + ? (string)value + : (string)converter.ConvertToProvider(value); + } + } + } + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 510fd2414dc..e3c1c43d1a4 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -23,8 +23,8 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor : ShapedQueryCo private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; - private readonly string _partitionKey; - + private readonly string _partitionKeyFromExtension; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -42,7 +42,7 @@ public CosmosShapedQueryCompilingExpressionVisitor( _querySqlGeneratorFactory = querySqlGeneratorFactory; _contextType = cosmosQueryCompilationContext.ContextType; _logger = cosmosQueryCompilationContext.Logger; - _partitionKey = cosmosQueryCompilationContext.PartitionKey; + _partitionKeyFromExtension = cosmosQueryCompilationContext.PartitionKeyFromExtension; } /// @@ -55,32 +55,63 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery { Check.NotNull(shapedQueryExpression, nameof(shapedQueryExpression)); - var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression; - selectExpression.ApplyProjection(); var jObjectParameter = Expression.Parameter(typeof(JObject), "jObject"); var shaperBody = shapedQueryExpression.ShaperExpression; - shaperBody = new JObjectInjectingExpressionVisitor() - .Visit(shaperBody); + shaperBody = new JObjectInjectingExpressionVisitor().Visit(shaperBody); shaperBody = InjectEntityMaterializers(shaperBody); - shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor(selectExpression, jObjectParameter, IsTracking) - .Visit(shaperBody); - var shaperLambda = Expression.Lambda( - shaperBody, - QueryCompilationContext.QueryContextParameter, - jObjectParameter); + switch (shapedQueryExpression.QueryExpression) + { + case SelectExpression selectExpression: + + selectExpression.ApplyProjection(); + + shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor(selectExpression, jObjectParameter, IsTracking) + .Visit(shaperBody); + + var shaperLambda = Expression.Lambda( + shaperBody, + QueryCompilationContext.QueryContextParameter, + jObjectParameter); + + return Expression.New( + typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0], + Expression.Convert( + QueryCompilationContext.QueryContextParameter, + typeof(CosmosQueryContext)), + Expression.Constant(_sqlExpressionFactory), + Expression.Constant(_querySqlGeneratorFactory), + Expression.Constant(selectExpression), + Expression.Constant(shaperLambda.Compile()), + Expression.Constant(_contextType), + Expression.Constant(_partitionKeyFromExtension, typeof(string)), + Expression.Constant(_logger)); + + case ReadItemExpression readItemExpression: + + shaperBody = + new CosmosProjectionBindingRemovingReadItemExpressionVisitor(readItemExpression, jObjectParameter, IsTracking) + .Visit(shaperBody); + + var shaperReadItemLambda = Expression.Lambda( + shaperBody, + QueryCompilationContext.QueryContextParameter, + jObjectParameter); - return Expression.New( - typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0], - Expression.Convert(QueryCompilationContext.QueryContextParameter, typeof(CosmosQueryContext)), - Expression.Constant(_sqlExpressionFactory), - Expression.Constant(_querySqlGeneratorFactory), - Expression.Constant(selectExpression), - Expression.Constant(shaperLambda.Compile()), - Expression.Constant(_contextType), - Expression.Constant(_partitionKey, typeof(string)), - Expression.Constant(_logger)); + return Expression.New( + typeof(ReadItemQueryingEnumerable<>).MakeGenericType(shaperReadItemLambda.ReturnType).GetConstructors()[0], + Expression.Convert( + QueryCompilationContext.QueryContextParameter, + typeof(CosmosQueryContext)), + Expression.Constant(readItemExpression), + Expression.Constant(shaperReadItemLambda.Compile()), + Expression.Constant(_contextType), + Expression.Constant(_logger)); + + default: + throw new NotImplementedException(); + } } } } diff --git a/src/EFCore.Cosmos/Query/Internal/ReadItemExpression.cs b/src/EFCore.Cosmos/Query/Internal/ReadItemExpression.cs new file mode 100644 index 00000000000..96639ae3823 --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/ReadItemExpression.cs @@ -0,0 +1,98 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Utilities; +using Newtonsoft.Json.Linq; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class ReadItemExpression : Expression + { + private const string RootAlias = "c"; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type => typeof(JObject); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string Container { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ProjectionExpression ProjectionExpression { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IEntityType EntityType { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IDictionary PropertyParameters { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public ReadItemExpression( + [NotNull] IEntityType entityType, + [NotNull] IDictionary propertyParameters) + { + Check.NotNull(entityType, nameof(entityType)); + Check.NotNull(propertyParameters, nameof(propertyParameters)); + + Container = entityType.GetContainer(); + + ProjectionExpression = new ProjectionExpression( + new EntityProjectionExpression( + entityType, + new RootReferenceExpression(entityType, RootAlias)), + RootAlias); + + EntityType = entityType; + + PropertyParameters = propertyParameters; + } + } +} diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs index 3d2e8a4d490..e4c66aa5a79 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs @@ -5,6 +5,7 @@ using System.Collections; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net; using System.Runtime.CompilerServices; using System.Text; @@ -16,8 +17,6 @@ using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure.Internal; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Conventions; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Update; using Microsoft.EntityFrameworkCore.Utilities; @@ -467,20 +466,6 @@ private static void ProcessResponse(ResponseMessage response, IUpdateEntry entry { entry.SetStoreGeneratedValue(etagProperty, response.Headers.ETag); } - - var jObjectProperty = entry.EntityType.FindProperty(StoreKeyConvention.JObjectPropertyName); - if (jObjectProperty != null - && jObjectProperty.ValueGenerated == ValueGenerated.OnAddOrUpdate - && response.Content != null) - { - using var responseStream = response.Content; - using var reader = new StreamReader(responseStream); - using var jsonReader = new JsonTextReader(reader); - - var createdDocument = new JsonSerializer().Deserialize(jsonReader); - - entry.SetStoreGeneratedValue(jObjectProperty, createdDocument); - } } /// @@ -515,6 +500,58 @@ public virtual IAsyncEnumerable ExecuteSqlQueryAsync( return new DocumentAsyncEnumerable(this, containerId, partitionKey, query); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual JObject ExecuteReadItem( + [NotNull] string containerId, + [NotNull] string partitionKey, + [NotNull] string resourceId) + { + _commandLogger.ExecutingReadItem(partitionKey, resourceId); + + var responseMessage = CreateSingleItemQuery( + containerId, partitionKey, resourceId).GetAwaiter().GetResult(); + + return JObjectFromReadItemResponseMessage(responseMessage); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + internal virtual async Task ExecuteReadItemAsync( + [NotNull] string containerId, + [NotNull] string partitionKey, + [NotNull] string resourceId, + CancellationToken cancellationToken = default) + { + _commandLogger.ExecutingReadItem(partitionKey, resourceId); + + var responseMessage = await CreateSingleItemQuery( + containerId, partitionKey, resourceId, cancellationToken); + + return JObjectFromReadItemResponseMessage(responseMessage); + } + + private static JObject JObjectFromReadItemResponseMessage(ResponseMessage responseMessage) + { + responseMessage.EnsureSuccessStatusCode(); + + var responseStream = responseMessage.Content; + var reader = new StreamReader(responseStream); + var jsonReader = new JsonTextReader(reader); + + var jObject = new JsonSerializer().Deserialize(jsonReader); + + return new JObject(new JProperty("c", jObject)); + } + private FeedIterator CreateQuery( string containerId, string partitionKey, @@ -522,10 +559,10 @@ private FeedIterator CreateQuery( { var container = Client.GetDatabase(_databaseId).GetContainer(containerId); var queryDefinition = new QueryDefinition(query.Query); - foreach (var parameter in query.Parameters) - { - queryDefinition = queryDefinition.WithParameter(parameter.Name, parameter.Value); - } + + queryDefinition = query.Parameters + .Aggregate(queryDefinition, + (current, parameter) => current.WithParameter(parameter.Name, parameter.Value)); if (string.IsNullOrEmpty(partitionKey)) { @@ -535,7 +572,59 @@ private FeedIterator CreateQuery( var queryRequestOptions = new QueryRequestOptions { PartitionKey = new PartitionKey(partitionKey) }; return container.GetItemQueryStreamIterator(queryDefinition, requestOptions: queryRequestOptions); + } + private async Task CreateSingleItemQuery( + string containerId, + string partitionKey, + string resourceId, + CancellationToken cancellationToken = default) + { + var container = Client.GetDatabase(_databaseId).GetContainer(containerId); + + return await container.ReadItemStreamAsync( + resourceId, + string.IsNullOrEmpty(partitionKey) ? PartitionKey.None : new PartitionKey(partitionKey), + cancellationToken: cancellationToken); + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static JsonTextReader CreateJsonReader(TextReader reader) + { + var jsonReader = new JsonTextReader(reader); + + while (jsonReader.Read()) + { + if (jsonReader.TokenType == JsonToken.StartObject) + { + while (jsonReader.Read()) + { + if (jsonReader.TokenType == JsonToken.StartArray) + { + return jsonReader; + } + } + } + } + + return jsonReader; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryReadJObject(JsonTextReader jsonReader, out JObject jObject) + { + jObject = null; + + while (jsonReader.Read()) + { + if (jsonReader.TokenType == JsonToken.StartObject) + { + jObject = new JsonSerializer().Deserialize(jsonReader); + return true; + } + } + return false; } private sealed class DocumentEnumerable : IEnumerable @@ -563,24 +652,25 @@ public DocumentEnumerable( private sealed class Enumerator : IEnumerator { - private FeedIterator _query; + private readonly CosmosClientWrapper _cosmosClientWrapper; + private readonly string _containerId; + private readonly string _partitionKey; + private readonly CosmosSqlQuery _cosmosSqlQuery; + private ResponseMessage _responseMessage; private Stream _responseStream; private StreamReader _reader; private JsonTextReader _jsonReader; - private readonly CosmosClientWrapper _cosmosClient; - private readonly string _containerId; - private readonly string _partitionKey; - private readonly CosmosSqlQuery _cosmosSqlQuery; - + + private FeedIterator _query; + public Enumerator(DocumentEnumerable documentEnumerable) { - _cosmosClient = documentEnumerable._cosmosClient; + _cosmosClientWrapper = documentEnumerable._cosmosClient; _containerId = documentEnumerable._containerId; _partitionKey = documentEnumerable._partitionKey; _cosmosSqlQuery = documentEnumerable._cosmosSqlQuery; } - public JObject Current { get; private set; } object IEnumerator.Current => Current; @@ -588,12 +678,9 @@ public Enumerator(DocumentEnumerable documentEnumerable) [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool MoveNext() { - if (_jsonReader == null) + if (_jsonReader is null) { - if (_query == null) - { - _query = _cosmosClient.CreateQuery(_containerId, _partitionKey, _cosmosSqlQuery); - } + _query ??= _cosmosClientWrapper.CreateQuery(_containerId, _partitionKey, _cosmosSqlQuery); if (!_query.HasMoreResults) { @@ -606,46 +693,21 @@ public bool MoveNext() _responseStream = _responseMessage.Content; _reader = new StreamReader(_responseStream); - _jsonReader = new JsonTextReader(_reader); - - while (_jsonReader.Read()) - { - if (_jsonReader.TokenType == JsonToken.StartObject) - { - while (_jsonReader.Read()) - { - if (_jsonReader.TokenType == JsonToken.StartArray) - { - goto ObjectFound; - } - } - } - } - - ObjectFound: ; + _jsonReader = CreateJsonReader(_reader); } - while (_jsonReader.Read()) + if (TryReadJObject(_jsonReader, out var jObject)) { - if (_jsonReader.TokenType == JsonToken.StartObject) - { - Current = new JsonSerializer().Deserialize(_jsonReader); - - return true; - } + Current = jObject; + return true; } - _jsonReader.Close(); - _jsonReader = null; - _reader.Dispose(); - _reader = null; - _responseStream.Dispose(); - _responseStream = null; + ResetRead(); return MoveNext(); } - public void Dispose() + private void ResetRead() { _jsonReader?.Close(); _jsonReader = null; @@ -653,6 +715,12 @@ public void Dispose() _reader = null; _responseStream?.Dispose(); _responseStream = null; + } + + public void Dispose() + { + ResetRead(); + _responseMessage?.Dispose(); _responseMessage = null; } @@ -665,7 +733,7 @@ private sealed class DocumentAsyncEnumerable : IAsyncEnumerable { private readonly CosmosClientWrapper _cosmosClient; private readonly string _containerId; - private readonly string _patitionKey; + private readonly string _partitionKey; private readonly CosmosSqlQuery _cosmosSqlQuery; public DocumentAsyncEnumerable( @@ -676,7 +744,7 @@ public DocumentAsyncEnumerable( { _cosmosClient = cosmosClient; _containerId = containerId; - _patitionKey = partitionKey; + _partitionKey = partitionKey; _cosmosSqlQuery = cosmosSqlQuery; } @@ -685,39 +753,38 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellati private sealed class AsyncEnumerator : IAsyncEnumerator { - private FeedIterator _query; - private ResponseMessage _responseMessage; - private Stream _responseStream; - private StreamReader _reader; - private JsonTextReader _jsonReader; - private readonly CosmosClientWrapper _cosmosClient; + private readonly CosmosClientWrapper _cosmosClientWrapper; private readonly string _containerId; private readonly string _partitionKey; private readonly CosmosSqlQuery _cosmosSqlQuery; private readonly CancellationToken _cancellationToken; + private ResponseMessage _responseMessage; + private Stream _responseStream; + private StreamReader _reader; + private JsonTextReader _jsonReader; + + private FeedIterator _query; + + public JObject Current { get; private set; } + public AsyncEnumerator(DocumentAsyncEnumerable documentEnumerable, CancellationToken cancellationToken) { - _cosmosClient = documentEnumerable._cosmosClient; + _cosmosClientWrapper = documentEnumerable._cosmosClient; _containerId = documentEnumerable._containerId; - _partitionKey = documentEnumerable._patitionKey; + _partitionKey = documentEnumerable._partitionKey; _cosmosSqlQuery = documentEnumerable._cosmosSqlQuery; _cancellationToken = cancellationToken; } - public JObject Current { get; private set; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public async ValueTask MoveNextAsync() { _cancellationToken.ThrowIfCancellationRequested(); - if (_jsonReader == null) + if (_jsonReader is null) { - if (_query == null) - { - _query = _cosmosClient.CreateQuery(_containerId, _partitionKey, _cosmosSqlQuery); - } + _query ??= _cosmosClientWrapper.CreateQuery(_containerId, _partitionKey, _cosmosSqlQuery); if (!_query.HasMoreResults) { @@ -730,47 +797,34 @@ public async ValueTask MoveNextAsync() _responseStream = _responseMessage.Content; _reader = new StreamReader(_responseStream); - _jsonReader = new JsonTextReader(_reader); - - while (_jsonReader.Read()) - { - if (_jsonReader.TokenType == JsonToken.StartObject) - { - while (_jsonReader.Read()) - { - if (_jsonReader.TokenType == JsonToken.StartArray) - { - goto ObjectFound; - } - } - } - } - - ObjectFound: ; + _jsonReader = CreateJsonReader(_reader); } - while (_jsonReader.Read()) + if (TryReadJObject(_jsonReader, out var jObject)) { - if (_jsonReader.TokenType == JsonToken.StartObject) - { - Current = new JsonSerializer().Deserialize(_jsonReader); - return true; - } + Current = jObject; + return true; } - await DisposeAsync(); + await ResetReadAsync(); return await MoveNextAsync(); } - public async ValueTask DisposeAsync() + private async Task ResetReadAsync() { _jsonReader?.Close(); _jsonReader = null; await _reader.DisposeAsyncIfAvailable(); _reader = null; - await _responseStream.DisposeAsync(); + await _responseStream.DisposeAsyncIfAvailable(); _responseStream = null; + } + + public async ValueTask DisposeAsync() + { + await ResetReadAsync(); + await _responseMessage.DisposeAsyncIfAvailable(); _responseMessage = null; } diff --git a/test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs index 0ea79252113..30b2f2cb4b2 100644 --- a/test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs @@ -182,274 +182,449 @@ public async Task Can_add_update_delete_detached_entity_end_to_end_async() } } - [ConditionalFact] - public void Can_add_update_untracked_properties() + private class Customer { - var options = Fixture.CreateOptions(); - - var customer = new Customer { Id = 42, Name = "Theon" }; - - using (var context = new CustomerContext(options)) - { - context.Database.EnsureCreated(); - - var entry = context.Add(customer); - - context.SaveChanges(); + public int Id { get; set; } + public string Name { get; set; } + public int PartitionKey { get; set; } + } - var document = entry.Property("__jObject").CurrentValue; - Assert.NotNull(document); - Assert.Equal("Theon", document["Name"]); + private class Customer_WithResourceId + { + public int Id { get; set; } + public string id { get; set; } + public string Name { get; set; } + public int PartitionKey { get; set; } + } - context.Remove(customer); + private class Customer_NoPartitionKey + { + public int Id { get; set; } + public string Name { get; set; } + } - context.SaveChanges(); + private class CustomerContext : DbContext + { + public CustomerContext(DbContextOptions dbContextOptions) + : base(dbContextOptions) + { } - using (var context = new CustomerContext(options)) + protected override void OnModelCreating(ModelBuilder modelBuilder) { - Assert.Empty(context.Set().ToList()); + modelBuilder.Entity(); + } + } - var entry = context.Add(customer); + [ConditionalFact] + public async Task Can_add_update_delete_end_to_end_with_partition_key() + { + var options = Fixture.CreateOptions(); - entry.Property("__jObject").CurrentValue = new JObject - { - ["key1"] = "value1" - }; + var customer = new Customer + { + Id = 42, + Name = "Theon", + PartitionKey = 1 + }; - context.SaveChanges(); + using (var context = new PartitionKeyContext(options)) + { + await context.Database.EnsureCreatedAsync(); - var document = entry.Property("__jObject").CurrentValue; - Assert.NotNull(document); - Assert.Equal("Theon", document["Name"]); - Assert.Equal("value1", document["key1"]); + context.Add(customer); + context.Add( + new Customer + { + Id = 42, + Name = "Theon Twin", + PartitionKey = 2 + }); - document["key2"] = "value2"; - entry.State = EntityState.Modified; - context.SaveChanges(); + await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + using (var context = new PartitionKeyContext(options)) { - var customerFromStore = context.Set().Single(); + var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon", customerFromStore.Name); + Assert.Equal(1, customerFromStore.PartitionKey); - var entry = context.Entry(customerFromStore); - var document = entry.Property("__jObject").CurrentValue; - Assert.Equal("value1", document["key1"]); - Assert.Equal("value2", document["key2"]); - - document["key1"] = "value1.1"; customerFromStore.Name = "Theon Greyjoy"; - context.SaveChanges(); + await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + using (var context = new PartitionKeyContext(options)) { - var customerFromStore = context.Set().Single(); + var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); + customerFromStore.PartitionKey = 2; - Assert.Equal("Theon Greyjoy", customerFromStore.Name); + Assert.Equal( + CoreStrings.KeyReadOnly(nameof(Customer.PartitionKey), nameof(Customer)), + Assert.Throws(() => context.SaveChanges()).Message); + } - var entry = context.Entry(customerFromStore); - var document = entry.Property("__jObject").CurrentValue; - Assert.Equal("value1.1", document["key1"]); - Assert.Equal("value2", document["key2"]); + using (var context = new PartitionKeyContext(options)) + { + var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon Greyjoy", customerFromStore.Name); + Assert.Equal(1, customerFromStore.PartitionKey); context.Remove(customerFromStore); - context.SaveChanges(); + context.Remove(await context.Set().OrderBy(c => c.PartitionKey).LastAsync()); + + await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + using (var context = new PartitionKeyContext(options)) { - Assert.Empty(context.Set().ToList()); + Assert.Empty(await context.Set().ToListAsync()); } } [ConditionalFact] - public async Task Can_add_update_untracked_properties_async() + public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extension() { var options = Fixture.CreateOptions(); + const int pk1 = 1; + const int pk2 = 2; - var customer = new Customer { Id = 42, Name = "Theon" }; + var customer = new Customer + { + Id = 42, + Name = "Theon", + PartitionKey = pk1 + }; - using (var context = new CustomerContext(options)) + using (var context = new PartitionKeyContext(options)) { await context.Database.EnsureCreatedAsync(); - var entry = context.Add(customer); + context.Add(customer); + context.Add( + new Customer + { + Id = 42, + Name = "Theon Twin", + PartitionKey = pk2 + }); await context.SaveChangesAsync(); + } - var document = entry.Property("__jObject").CurrentValue; - Assert.NotNull(document); - Assert.Equal("Theon", document["Name"]); + using (var context = new PartitionKeyContext(options)) + { + var customerFromStore = await context.Set() + .WithPartitionKey(partitionKey: pk1.ToString()) + .FirstAsync(); - context.Remove(customer); + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); + + customerFromStore.Name = "Theon Greyjoy"; await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + using (var context = new PartitionKeyContext(options)) { - Assert.Empty(await context.Set().ToListAsync()); + var customerFromStore = await context.Set().WithPartitionKey(partitionKey:pk1.ToString()).FirstAsync(); - var entry = context.Add(customer); + customerFromStore.PartitionKey = pk2; - entry.Property("__jObject").CurrentValue = new JObject - { - ["key1"] = "value1" - }; + Assert.Equal( + CoreStrings.KeyReadOnly(nameof(Customer.PartitionKey), nameof(Customer)), + Assert.Throws(() => context.SaveChanges()).Message); + } + + using (var context = new PartitionKeyContext(options)) + { + var customerFromStore = await context.Set() + .WithPartitionKey(partitionKey: pk1.ToString()) + .FirstAsync(); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon Greyjoy", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); + + context.Remove(customerFromStore); + + context.Remove(await context.Set() + .WithPartitionKey(pk2.ToString()) + .LastAsync()); await context.SaveChangesAsync(); + } + + using (var context = new PartitionKeyContext(options)) + { + Assert.Empty(await context.Set() + .WithPartitionKey(partitionKey: pk2.ToString()) + .ToListAsync()); + } + } - var document = entry.Property("__jObject").CurrentValue; - Assert.NotNull(document); - Assert.Equal("Theon", document["Name"]); - Assert.Equal("value1", document["key1"]); + [ConditionalFact] + public async Task Can_read_with_find_with_partitionkey_async() + { + var options = Fixture.CreateOptions(); + const int pk1 = 1; + const int pk2 = 2; + + var customer = new Customer_WithResourceId + { + Id = 42, + Name = "Theon", + PartitionKey = pk1 + }; + + await using (var context = new PartitionKeyContext_WithResourceId(options)) + { + await context.Database.EnsureCreatedAsync(); + + context.Add(customer); + context.Add( + new Customer_WithResourceId + { + Id = 42, + Name = "Theon Twin", + PartitionKey = pk2 + }); - document["key2"] = "value2"; - entry.State = EntityState.Modified; await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + await using (var context = new PartitionKeyContext_WithResourceId(options)) { - var customerFromStore = await context.Set().SingleAsync(); + var customerFromStore = await context.Set() + .FindAsync( + pk1, 42); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); - var entry = context.Entry(customerFromStore); - var document = entry.Property("__jObject").CurrentValue; - Assert.Equal("value1", document["key1"]); - Assert.Equal("value2", document["key2"]); - - document["key1"] = "value1.1"; customerFromStore.Name = "Theon Greyjoy"; await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + await using (var context = new PartitionKeyContext_WithResourceId(options)) { - var customerFromStore = await context.Set().SingleAsync(); + var customerFromStore = await context.Set() + .FindAsync( + pk1, 42); - Assert.Equal("Theon Greyjoy", customerFromStore.Name); + customerFromStore.PartitionKey = pk2; - var entry = context.Entry(customerFromStore); - var document = entry.Property("__jObject").CurrentValue; - Assert.Equal("value1.1", document["key1"]); - Assert.Equal("value2", document["key2"]); + Assert.Equal( + CoreStrings.KeyReadOnly(nameof(Customer_WithResourceId.PartitionKey), nameof(Customer_WithResourceId)), + Assert.Throws(() => context.SaveChanges()).Message); + } + + await using (var context = new PartitionKeyContext_WithResourceId(options)) + { + var customerFromStore = await context.Set() + .WithPartitionKey(partitionKey: pk1.ToString()) + .FirstAsync(); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon Greyjoy", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); context.Remove(customerFromStore); + context.Remove(await context.Set() + .WithPartitionKey(pk2.ToString()) + .LastAsync()); + await context.SaveChangesAsync(); } - using (var context = new CustomerContext(options)) + await using (var context = new PartitionKeyContext_WithResourceId(options)) { - Assert.Empty(await context.Set().ToListAsync()); + Assert.Empty(await context.Set() + .WithPartitionKey(partitionKey: pk2.ToString()) + .ToListAsync()); } } - private class Customer + [ConditionalFact] + public async Task Can_read_with_find_with_partitionkey_and_value_generator_async() { - public int Id { get; set; } - public string Name { get; set; } - public int PartitionKey { get; set; } - } + var options = Fixture.CreateOptions(); + const int pk1 = 1; + const int pk2 = 2; - private class CustomerContext : DbContext - { - public CustomerContext(DbContextOptions dbContextOptions) - : base(dbContextOptions) + var customer = new Customer + { + Id = 42, + Name = "Theon", + PartitionKey = pk1 + }; + + await using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) + { + await context.Database.EnsureCreatedAsync(); + + context.Add(customer); + context.Add( + new Customer + { + Id = 42, + Name = "Theon Twin", + PartitionKey = pk2 + }); + + await context.SaveChangesAsync(); + } + + await using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { + var customerFromStore = await context.Set() + .FindAsync(pk1, 42); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); + + customerFromStore.Name = "Theon Greyjoy"; + + await context.SaveChangesAsync(); } - protected override void OnModelCreating(ModelBuilder modelBuilder) + + await using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - modelBuilder.Entity(); + var customerFromStore = await context.Set() + .FindAsync(pk1, 42); + + customerFromStore.PartitionKey = pk2; + + Assert.Equal( + CoreStrings.KeyReadOnly(nameof(Customer.PartitionKey), nameof(Customer)), + Assert.Throws(() => context.SaveChanges()).Message); + } + + await using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) + { + var customerFromStore = await context.Set() + .WithPartitionKey(partitionKey: pk1.ToString()) + .FirstAsync(); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon Greyjoy", customerFromStore.Name); + Assert.Equal(pk1, customerFromStore.PartitionKey); + + context.Remove(customerFromStore); + + context.Remove(await context.Set() + .WithPartitionKey(pk2.ToString()) + .LastAsync()); + + await context.SaveChangesAsync(); + } + + await using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) + { + Assert.Empty(await context.Set() + .WithPartitionKey(partitionKey: pk2.ToString()) + .ToListAsync()); } } [ConditionalFact] - public async Task Can_add_update_delete_end_to_end_with_partition_key() + public void Can_read_with_find_with_partitionkey() { var options = Fixture.CreateOptions(); + const int pk1 = 1; + const int pk2 = 2; - var customer = new Customer + var customer = new Customer_WithResourceId { Id = 42, Name = "Theon", - PartitionKey = 1 + PartitionKey = pk1 }; - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithResourceId(options)) { - await context.Database.EnsureCreatedAsync(); + context.Database.EnsureCreated(); context.Add(customer); context.Add( - new Customer + new Customer_WithResourceId { Id = 42, Name = "Theon Twin", - PartitionKey = 2 + PartitionKey = pk2 }); - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithResourceId(options)) { - var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); + var customerFromStore = context.Set() + .Find(pk1, 42); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon", customerFromStore.Name); - Assert.Equal(1, customerFromStore.PartitionKey); + Assert.Equal(pk1, customerFromStore.PartitionKey); customerFromStore.Name = "Theon Greyjoy"; - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithResourceId(options)) { - var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); - customerFromStore.PartitionKey = 2; + var customerFromStore = context.Set() + .Find(pk1, 42); + + customerFromStore.PartitionKey = pk2; Assert.Equal( - CoreStrings.KeyReadOnly(nameof(Customer.PartitionKey), nameof(Customer)), + CoreStrings.KeyReadOnly(nameof(Customer_WithResourceId.PartitionKey), nameof(Customer_WithResourceId)), Assert.Throws(() => context.SaveChanges()).Message); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithResourceId(options)) { - var customerFromStore = await context.Set().OrderBy(c => c.PartitionKey).FirstAsync(); + var customerFromStore = context.Set() + .WithPartitionKey(partitionKey: pk1.ToString()) + .First(); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon Greyjoy", customerFromStore.Name); - Assert.Equal(1, customerFromStore.PartitionKey); + Assert.Equal(pk1, customerFromStore.PartitionKey); context.Remove(customerFromStore); - context.Remove(await context.Set().OrderBy(c => c.PartitionKey).LastAsync()); + context.Remove(context.Set() + .WithPartitionKey(pk2.ToString()) + .Last()); - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithResourceId(options)) { - Assert.Empty(await context.Set().ToListAsync()); + Assert.Empty(context.Set() + .WithPartitionKey(partitionKey: pk2.ToString()) + .ToList()); } } [ConditionalFact] - public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extension() + public void Can_read_with_find_with_partitionkey_and_value_generator() { var options = Fixture.CreateOptions(); const int pk1 = 1; @@ -462,9 +637,9 @@ public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extensi PartitionKey = pk1 }; - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - await context.Database.EnsureCreatedAsync(); + context.Database.EnsureCreated(); context.Add(customer); context.Add( @@ -475,14 +650,13 @@ public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extensi PartitionKey = pk2 }); - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - var customerFromStore = await context.Set() - .WithPartitionKey(partitionKey:pk1.ToString()) - .FirstAsync(); + var customerFromStore = context.Set() + .Find(pk1, 42); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon", customerFromStore.Name); @@ -490,14 +664,14 @@ public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extensi customerFromStore.Name = "Theon Greyjoy"; - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + + using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - var customerFromStore = await context.Set() - .WithPartitionKey(partitionKey:pk1.ToString()) - .FirstAsync(); + var customerFromStore = context.Set() + .Find(pk1, 42); customerFromStore.PartitionKey = pk2; @@ -506,11 +680,11 @@ public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extensi Assert.Throws(() => context.SaveChanges()).Message); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - var customerFromStore = await context.Set() + var customerFromStore = context.Set() .WithPartitionKey(partitionKey: pk1.ToString()) - .FirstAsync(); + .First(); Assert.Equal(42, customerFromStore.Id); Assert.Equal("Theon Greyjoy", customerFromStore.Name); @@ -518,21 +692,89 @@ public async Task Can_add_update_delete_end_to_end_with_withpartitionkey_extensi context.Remove(customerFromStore); - context.Remove(await context.Set() + context.Remove(context.Set() .WithPartitionKey(pk2.ToString()) - .LastAsync()); + .Last()); - await context.SaveChangesAsync(); + context.SaveChanges(); } - using (var context = new PartitionKeyContext(options)) + using (var context = new PartitionKeyContext_WithCustomValueGenerator(options)) { - Assert.Empty(await context.Set() + Assert.Empty(context.Set() .WithPartitionKey(partitionKey: pk2.ToString()) - .ToListAsync()); + .ToList()); } } + [ConditionalFact] + public async Task Entity_type_with_partitionkey_not_part_of_primary_keys() + { + var options = Fixture.CreateOptions(); + + var customer = new Customer + { + Id = 42, + Name = "Theon", + PartitionKey = 1 + }; + + await using (var context = new Context_No_PartitionKey_In_PK(options)) + { + await context.Database.EnsureCreatedAsync(); + + context.Add(customer); + + await context.SaveChangesAsync(); + } + + await using (var context = new Context_No_PartitionKey_In_PK(options)) + { + var customerFromStore = context.Set().Find(42); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon", customerFromStore.Name); + + context.Remove(customerFromStore); + + await context.SaveChangesAsync(); + } + } + + [ConditionalFact] + public async Task Entity_type_has_no_primiary_keys() + { + var options = Fixture.CreateOptions(); + + var customer = new Customer_NoPartitionKey + { + Id = 42, + Name = "Theon" + }; + + await using (var context = new PartitionKeyContext_EntityWithNoPartitionKey(options)) + { + await context.Database.EnsureCreatedAsync(); + + context.Add(customer); + + await context.SaveChangesAsync(); + } + + await using (var context = new PartitionKeyContext_EntityWithNoPartitionKey(options)) + { + var customerFromStore = context.Set().Find(42); + + Assert.Equal(42, customerFromStore.Id); + Assert.Equal("Theon", customerFromStore.Name); + + context.Remove(customerFromStore); + + await context.SaveChangesAsync(); + } + } + + private class PartitionKeyContext : DbContext { public PartitionKeyContext(DbContextOptions dbContextOptions) @@ -552,6 +794,92 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) } } + private class PartitionKeyContext_EntityWithNoPartitionKey : DbContext + { + public PartitionKeyContext_EntityWithNoPartitionKey(DbContextOptions dbContextOptions) + : base(dbContextOptions) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + cb => + { + + }); + } + } + + private class PartitionKeyContext_WithCustomValueGenerator : DbContext + { + public PartitionKeyContext_WithCustomValueGenerator(DbContextOptions dbContextOptions) + : base(dbContextOptions) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + cb => + { + var valueGeneratorFactory = new CustomPartitionKeyIdValueGeneratorFactory(); + + // Create Shadow property for 'id' (Cosmos Resource Id). + // and attach a value generator to it. + cb.Property("id") + .HasValueGenerator((p, e) => valueGeneratorFactory.Create(p)); + + cb.Property(c => c.Id) + .HasConversion(); + + cb.HasPartitionKey(c => c.PartitionKey); + cb.Property(c => c.PartitionKey).HasConversion(); + cb.Property(c => c.PartitionKey); + cb.HasKey(c => new { c.PartitionKey, c.Id}); + }); + } + } + + private class Context_No_PartitionKey_In_PK : DbContext + { + public Context_No_PartitionKey_In_PK(DbContextOptions dbContextOptions) + : base(dbContextOptions) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + cb => + { + var valueGeneratorFactory = new CustomPartitionKeyIdValueGeneratorFactory(); + + cb.Property(c => c.Id).HasConversion(); + cb.HasKey(c => new { c.Id }); + }); + } + } + + private class PartitionKeyContext_WithResourceId : DbContext + { + public PartitionKeyContext_WithResourceId(DbContextOptions dbContextOptions) + : base(dbContextOptions) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity( + cb => + { + cb.HasPartitionKey(c => c.PartitionKey); + cb.Property(c => c.PartitionKey).HasConversion(); + cb.HasKey(c => new { c.PartitionKey, c.Id }); + }); + } + } + [ConditionalFact] public async Task Can_use_detached_entities_without_discriminators() { diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdGenerator.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdGenerator.cs new file mode 100644 index 00000000000..745b3b7d90b --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdGenerator.cs @@ -0,0 +1,90 @@ +using System.Collections; +using System.Linq; +using System.Text; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.EntityFrameworkCore.Metadata.Conventions; +using Microsoft.EntityFrameworkCore.ValueGeneration; + +namespace Microsoft.EntityFrameworkCore.Cosmos.TestUtilities +{ + public class CustomPartitionKeyIdGenerator : ValueGenerator + { + public override bool GeneratesTemporaryValues => false; + + public override T Next(EntityEntry entry) + { + return (T)NextValue(entry); + } + protected override object NextValue(EntityEntry entry) + { + var builder = new StringBuilder(); + var entityType = entry.Metadata; + + var pk = entityType.FindPrimaryKey(); + var discriminator = entityType.GetDiscriminatorValue(); + if (discriminator != null + && !pk.Properties.Contains(entityType.GetDiscriminatorProperty())) + { + AppendString(builder, discriminator); + builder.Append("-"); + } + + var partitionKey = entityType.GetPartitionKeyPropertyName() ?? CosmosClientWrapper.DefaultPartitionKey; + foreach (var property in pk.Properties.Where(p => p.Name != StoreKeyConvention.IdPropertyName)) + { + if (property.Name == partitionKey) + { + continue; + } + + var converter = property.GetValueConverter() + ?? property.GetTypeMapping().Converter; + + var value = entry.Property(property.Name).CurrentValue; + if (converter != null) + { + value = converter.ConvertToProvider(value); + } + + if (value is int x) + { + // We don't allow the Id to be zero for our custom generator. + if (x == 0) + { + return default; + } + } + + AppendString(builder, value); + + builder.Append("-"); + } + + builder.Remove(builder.Length - 1, 1); + + return builder.ToString(); + } + + private static void AppendString(StringBuilder builder, object propertyValue) + { + switch (propertyValue) + { + case string stringValue: + builder.Append(stringValue.Replace("-", "/-")); + return; + case IEnumerable enumerable: + foreach (var item in enumerable) + { + builder.Append(item.ToString().Replace("-", "/-")); + builder.Append("|"); + } + + return; + default: + builder.Append(propertyValue == null ? "null" : propertyValue.ToString().Replace("-", "/-")); + return; + } + } + } +} diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdValueGeneratorFactory.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdValueGeneratorFactory.cs new file mode 100644 index 00000000000..1b312bccca2 --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CustomPartitionKeyIdValueGeneratorFactory.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.ValueGeneration; + +namespace Microsoft.EntityFrameworkCore.Cosmos.TestUtilities +{ + class CustomPartitionKeyIdValueGeneratorFactory : ValueGeneratorFactory + { + public override EntityFrameworkCore.ValueGeneration.ValueGenerator Create(IProperty property) + { + return new CustomPartitionKeyIdGenerator(); + } + } +}