diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs index 9cb8dee976e..e1220de5c6a 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs @@ -92,11 +92,24 @@ public string ToQueryString() private sealed class Enumerator : IEnumerator { private readonly QueryingEnumerable _queryingEnumerable; + private readonly CosmosQueryContext _cosmosQueryContext; + private readonly SelectExpression _selectExpression; + private readonly Func _shaper; + private readonly Type _contextType; + private readonly string _partitionKey; + private readonly IDiagnosticsLogger _logger; + private IEnumerator _enumerator; public Enumerator(QueryingEnumerable queryingEnumerable) { _queryingEnumerable = queryingEnumerable; + _cosmosQueryContext = queryingEnumerable._cosmosQueryContext; + _shaper = queryingEnumerable._shaper; + _selectExpression = queryingEnumerable._selectExpression; + _contextType = queryingEnumerable._contextType; + _partitionKey = queryingEnumerable._partitionKey; + _logger = queryingEnumerable._logger; } public T Current { get; private set; } @@ -107,16 +120,16 @@ public bool MoveNext() { try { - using (_queryingEnumerable._cosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) + using (_cosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) { if (_enumerator == null) { var sqlQuery = _queryingEnumerable.GenerateQuery(); - _enumerator = _queryingEnumerable._cosmosQueryContext.CosmosClient + _enumerator = _cosmosQueryContext.CosmosClient .ExecuteSqlQuery( - _queryingEnumerable._selectExpression.Container, - _queryingEnumerable._partitionKey, + _selectExpression.Container, + _partitionKey, sqlQuery) .GetEnumerator(); } @@ -125,7 +138,7 @@ public bool MoveNext() Current = hasNext - ? _queryingEnumerable._shaper(_queryingEnumerable._cosmosQueryContext, _enumerator.Current) + ? _shaper(_cosmosQueryContext, _enumerator.Current) : default; return hasNext; @@ -133,7 +146,7 @@ public bool MoveNext() } catch (Exception exception) { - _queryingEnumerable._logger.QueryIterationFailed(_queryingEnumerable._contextType, exception); + _logger.QueryIterationFailed(_contextType, exception); throw; } @@ -150,24 +163,23 @@ public void Dispose() private sealed class AsyncEnumerator : IAsyncEnumerator { - private IAsyncEnumerator _enumerator; + private readonly QueryingEnumerable _queryingEnumerable; private readonly CosmosQueryContext _cosmosQueryContext; private readonly SelectExpression _selectExpression; private readonly Func _shaper; - private readonly ISqlExpressionFactory _sqlExpressionFactory; - private readonly IQuerySqlGeneratorFactory _querySqlGeneratorFactory; private readonly Type _contextType; private readonly string _partitionKey; private readonly IDiagnosticsLogger _logger; private readonly CancellationToken _cancellationToken; + private IAsyncEnumerator _enumerator; + public AsyncEnumerator(QueryingEnumerable queryingEnumerable, CancellationToken cancellationToken) { + _queryingEnumerable = queryingEnumerable; _cosmosQueryContext = queryingEnumerable._cosmosQueryContext; _shaper = queryingEnumerable._shaper; _selectExpression = queryingEnumerable._selectExpression; - _sqlExpressionFactory = queryingEnumerable._sqlExpressionFactory; - _querySqlGeneratorFactory = queryingEnumerable._querySqlGeneratorFactory; _contextType = queryingEnumerable._contextType; _partitionKey = queryingEnumerable._partitionKey; _logger = queryingEnumerable._logger; @@ -184,16 +196,13 @@ public async ValueTask MoveNextAsync() { if (_enumerator == null) { - var selectExpression = (SelectExpression)new InExpressionValuesExpandingExpressionVisitor( - _sqlExpressionFactory, _cosmosQueryContext.ParameterValues).Visit(_selectExpression); + var sqlQuery = _queryingEnumerable.GenerateQuery(); _enumerator = _cosmosQueryContext.CosmosClient .ExecuteSqlQueryAsync( _selectExpression.Container, _partitionKey, - _querySqlGeneratorFactory.Create().GetSqlQuery( - selectExpression, - _cosmosQueryContext.ParameterValues)) + sqlQuery) .GetAsyncEnumerator(_cancellationToken); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs index ddffe47c6c5..0c915035e0a 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs @@ -34,7 +34,7 @@ private sealed class ReadItemQueryingEnumerable : IEnumerable, IAsyncEnume private readonly Func _shaper; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; - + public ReadItemQueryingEnumerable( CosmosQueryContext cosmosQueryContext, ReadItemExpression readItemExpression, @@ -50,7 +50,7 @@ public ReadItemQueryingEnumerable( } public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - => new AsyncEnumerator(this, cancellationToken); + => new Enumerator(this, cancellationToken); public IEnumerator GetEnumerator() => new Enumerator(this); @@ -61,22 +61,37 @@ public string ToQueryString() throw new NotImplementedException("Cosmos: ToQueryString for ReadItemQueryingEnumerable #20653"); } - private sealed class Enumerator : ReadItemBase, IEnumerator + private sealed class Enumerator : IEnumerator, IAsyncEnumerator { + private readonly CosmosQueryContext _cosmosQueryContext; + private readonly ReadItemExpression _readItemExpression; + private readonly Func _shaper; + private readonly Type _contextType; + private readonly IDiagnosticsLogger _logger; + private readonly CancellationToken _cancellationToken; + private JObject _item; private bool _hasExecuted; - public Enumerator(ReadItemQueryingEnumerable readItemEnumerable) : base(readItemEnumerable) + public Enumerator(ReadItemQueryingEnumerable readItemEnumerable, CancellationToken cancellationToken = default) { + _cosmosQueryContext = readItemEnumerable._cosmosQueryContext; + _readItemExpression = readItemEnumerable._readItemExpression; + _shaper = readItemEnumerable._shaper; + _contextType = readItemEnumerable._contextType; + _logger = readItemEnumerable._logger; + _cancellationToken = cancellationToken; } object IEnumerator.Current => Current; + public T Current { get; private set; } + public bool MoveNext() { try { - using (CosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) + using (_cosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) { if (!_hasExecuted) { @@ -84,67 +99,36 @@ public bool MoveNext() { throw new InvalidOperationException(CosmosStrings.ResourceIdMissing); } - + if (!TryGetPartitionId(out var partitionKey)) { throw new InvalidOperationException(CosmosStrings.ParitionKeyMissing); } - _item = CosmosClient.ExecuteReadItem( - ContainerId, + _item = _cosmosQueryContext.CosmosClient.ExecuteReadItem( + _readItemExpression.Container, partitionKey, resourceId); - var hasNext = !(_item is null); - - Current - = hasNext - ? Shaper(CosmosQueryContext, _item) - : default; - - _hasExecuted = true; - - return hasNext; + return ShapeResult(); } - + return false; } } catch (Exception exception) { - Logger.QueryIterationFailed(ContextType, exception); + _logger.QueryIterationFailed(_contextType, exception); throw; } } - public void Dispose() - { - _item = null; - _hasExecuted = false; - } - - public void Reset() => throw new NotImplementedException(); - } - - private sealed class AsyncEnumerator : ReadItemBase, IAsyncEnumerator - { - private JObject _item; - private readonly CancellationToken _cancellationToken; - private bool _hasExecuted; - - public AsyncEnumerator( - ReadItemQueryingEnumerable readItemEnumerable, - CancellationToken cancellationToken) : base(readItemEnumerable) - { - _cancellationToken = cancellationToken; - } - public async ValueTask MoveNextAsync() { try { - using (CosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) + using (_cosmosQueryContext.ConcurrencyDetector.EnterCriticalSection()) { if (!_hasExecuted) { @@ -159,22 +143,13 @@ public async ValueTask MoveNextAsync() throw new InvalidOperationException(CosmosStrings.ParitionKeyMissing); } - _item = await CosmosClient.ExecuteReadItemAsync( - ContainerId, + _item = await _cosmosQueryContext.CosmosClient.ExecuteReadItemAsync( + _readItemExpression.Container, partitionKey, resourceId, _cancellationToken); - var hasNext = !(_item is null); - - Current - = hasNext - ? Shaper(CosmosQueryContext, _item) - : default; - - _hasExecuted = true; - - return hasNext; + return ShapeResult(); } return false; @@ -182,62 +157,52 @@ public async ValueTask MoveNextAsync() } catch (Exception exception) { - Logger.QueryIterationFailed(ContextType, exception); + _logger.QueryIterationFailed(_contextType, exception); throw; } } - public ValueTask DisposeAsync() + public void Dispose() { _item = null; _hasExecuted = false; - return default; } - } - private abstract class ReadItemBase - { - private readonly IStateManager _stateManager; - private readonly ReadItemExpression _readItemExpression; - private readonly IEntityType _entityType; + public ValueTask DisposeAsync() + { + Dispose(); - protected readonly CosmosQueryContext CosmosQueryContext; - protected readonly CosmosClientWrapper CosmosClient; - protected readonly string ContainerId; - protected readonly Func Shaper; - protected readonly Type ContextType; - protected readonly IDiagnosticsLogger Logger; + return default; + } - public T Current { get; protected set; } + public void Reset() => throw new NotImplementedException(); - protected ReadItemBase( - ReadItemQueryingEnumerable readItemEnumerable) + private bool ShapeResult() { -#pragma warning disable EF1001 - _stateManager = readItemEnumerable._cosmosQueryContext.StateManager; -#pragma warning restore EF1001 - CosmosQueryContext = readItemEnumerable._cosmosQueryContext; - _readItemExpression = readItemEnumerable._readItemExpression; - _entityType = readItemEnumerable._readItemExpression.EntityType; - CosmosClient = readItemEnumerable._cosmosQueryContext.CosmosClient; - ContainerId = _readItemExpression.Container; - Shaper = readItemEnumerable._shaper; - ContextType = readItemEnumerable._contextType; - Logger = readItemEnumerable._logger; + var hasNext = !(_item is null); + + Current + = hasNext + ? _shaper(_cosmosQueryContext, _item) + : default; + + _hasExecuted = true; + + return hasNext; } - protected bool TryGetPartitionId(out string partitionKey) + private bool TryGetPartitionId(out string partitionKey) { partitionKey = null; - var partitionKeyPropertyName = _entityType.GetPartitionKeyPropertyName(); + var partitionKeyPropertyName = _readItemExpression.EntityType.GetPartitionKeyPropertyName(); if (partitionKeyPropertyName == null) { return true; } - var partitionKeyProperty = _entityType.FindProperty(partitionKeyPropertyName); + var partitionKeyProperty = _readItemExpression.EntityType.FindProperty(partitionKeyPropertyName); if (TryGetParameterValue(partitionKeyProperty, out var value)) { @@ -249,9 +214,9 @@ protected bool TryGetPartitionId(out string partitionKey) return false; } - protected bool TryGetResourceId(out string resourceId) + private bool TryGetResourceId(out string resourceId) { - var idProperty = _entityType.GetProperties() + var idProperty = _readItemExpression.EntityType.GetProperties() .FirstOrDefault(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyName); if (TryGetParameterValue(idProperty, out var value)) @@ -279,14 +244,16 @@ protected bool TryGetResourceId(out string resourceId) private bool TryGenerateIdFromKeys(IProperty idProperty, out object value) { - var entityEntry = Activator.CreateInstance(_entityType.ClrType); + var entityEntry = Activator.CreateInstance(_readItemExpression.EntityType.ClrType); -#pragma warning disable EF1001 - var internalEntityEntry = new InternalEntityEntryFactory().Create(_stateManager, _entityType, entityEntry); +#pragma warning disable EF1001 // Internal EF Core API usage. + var internalEntityEntry = new InternalEntityEntryFactory().Create( + _cosmosQueryContext.StateManager, _readItemExpression.EntityType, entityEntry); +#pragma warning restore EF1001 // Internal EF Core API usage. - foreach (var keyProperty in _entityType.FindPrimaryKey().Properties) + foreach (var keyProperty in _readItemExpression.EntityType.FindPrimaryKey().Properties) { - var property = _entityType.FindProperty(keyProperty.Name); + var property = _readItemExpression.EntityType.FindProperty(keyProperty.Name); if (TryGetParameterValue(property, out var parameterValue)) { @@ -294,12 +261,15 @@ private bool TryGenerateIdFromKeys(IProperty idProperty, out object value) } } +#pragma warning disable EF1001 // Internal EF Core API usage. internalEntityEntry.SetEntityState(EntityState.Added); +#pragma warning restore EF1001 // Internal EF Core API usage. value = internalEntityEntry[idProperty]; +#pragma warning disable EF1001 // Internal EF Core API usage. internalEntityEntry.SetEntityState(EntityState.Detached); -#pragma warning restore EF1001 +#pragma warning restore EF1001 // Internal EF Core API usage. return value != null; } @@ -308,7 +278,7 @@ private bool TryGetParameterValue(IProperty property, out object value) { value = null; return _readItemExpression.PropertyParameters.TryGetValue(property, out var parameterName) - && CosmosQueryContext.ParameterValues.TryGetValue(parameterName, out value); + && _cosmosQueryContext.ParameterValues.TryGetValue(parameterName, out value); } private static string GetString(IProperty property, object value) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs index 1eb85c97120..98618080d57 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs @@ -45,7 +45,7 @@ public QueryingEnumerable( } public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - => new AsyncEnumerator(this, cancellationToken); + => new Enumerator(this, cancellationToken); public IEnumerator GetEnumerator() => new Enumerator(this); @@ -53,7 +53,7 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToke public string ToQueryString() => InMemoryStrings.NoQueryStrings; - private sealed class Enumerator : IEnumerator + private sealed class Enumerator : IEnumerator, IAsyncEnumerator { private IEnumerator _enumerator; private readonly QueryContext _queryContext; @@ -61,44 +61,29 @@ private sealed class Enumerator : IEnumerator private readonly Func _shaper; private readonly Type _contextType; private readonly IDiagnosticsLogger _logger; + private readonly CancellationToken _cancellationToken; - public Enumerator(QueryingEnumerable queryingEnumerable) + public Enumerator(QueryingEnumerable queryingEnumerable, CancellationToken cancellationToken = default) { _queryContext = queryingEnumerable._queryContext; _innerEnumerable = queryingEnumerable._innerEnumerable; _shaper = queryingEnumerable._shaper; _contextType = queryingEnumerable._contextType; _logger = queryingEnumerable._logger; + _cancellationToken = cancellationToken; } public T Current { get; private set; } object IEnumerator.Current => Current; - public void Dispose() - { - _enumerator?.Dispose(); - _enumerator = null; - } - public bool MoveNext() { try { using (_queryContext.ConcurrencyDetector.EnterCriticalSection()) { - if (_enumerator == null) - { - _enumerator = _innerEnumerable.GetEnumerator(); - } - - var hasNext = _enumerator.MoveNext(); - - Current = hasNext - ? _shaper(_queryContext, _enumerator.Current) - : default; - - return hasNext; + return MoveNextHelper(); } } catch (Exception exception) @@ -109,33 +94,6 @@ public bool MoveNext() } } - public void Reset() => throw new NotImplementedException(); - } - - private sealed class AsyncEnumerator : IAsyncEnumerator - { - private IEnumerator _enumerator; - private readonly QueryContext _queryContext; - private readonly IEnumerable _innerEnumerable; - private readonly Func _shaper; - private readonly Type _contextType; - private readonly IDiagnosticsLogger _logger; - private readonly CancellationToken _cancellationToken; - - public AsyncEnumerator( - QueryingEnumerable asyncQueryingEnumerable, - CancellationToken cancellationToken) - { - _queryContext = asyncQueryingEnumerable._queryContext; - _innerEnumerable = asyncQueryingEnumerable._innerEnumerable; - _shaper = asyncQueryingEnumerable._shaper; - _contextType = asyncQueryingEnumerable._contextType; - _logger = asyncQueryingEnumerable._logger; - _cancellationToken = cancellationToken; - } - - public T Current { get; private set; } - public ValueTask MoveNextAsync() { try @@ -144,18 +102,7 @@ public ValueTask MoveNextAsync() { _cancellationToken.ThrowIfCancellationRequested(); - if (_enumerator == null) - { - _enumerator = _innerEnumerable.GetEnumerator(); - } - - var hasNext = _enumerator.MoveNext(); - - Current = hasNext - ? _shaper(_queryContext, _enumerator.Current) - : default; - - return new ValueTask(hasNext); + return new ValueTask(MoveNextHelper()); } } catch (Exception exception) @@ -166,6 +113,28 @@ public ValueTask MoveNextAsync() } } + private bool MoveNextHelper() + { + if (_enumerator == null) + { + _enumerator = _innerEnumerable.GetEnumerator(); + } + + var hasNext = _enumerator.MoveNext(); + + Current = hasNext + ? _shaper(_queryContext, _enumerator.Current) + : default; + + return hasNext; + } + + public void Dispose() + { + _enumerator?.Dispose(); + _enumerator = null; + } + public ValueTask DisposeAsync() { var enumerator = _enumerator; @@ -173,6 +142,8 @@ public ValueTask DisposeAsync() return enumerator.DisposeAsyncIfAvailable(); } + + public void Reset() => throw new NotImplementedException(); } } } diff --git a/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs b/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs index 37655170a72..a26ce4761e4 100644 --- a/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs +++ b/src/EFCore.Relational/Query/Internal/QueryingEnumerable.cs @@ -352,8 +352,7 @@ public async ValueTask MoveNextAsync() private async Task InitializeReaderAsync(DbContext _, bool result, CancellationToken cancellationToken) { - var relationalCommand = _relationalCommandCache.GetRelationalCommand( - _relationalQueryContext.ParameterValues); + var relationalCommand = _relationalCommandCache.GetRelationalCommand(_relationalQueryContext.ParameterValues); _dataReader = await relationalCommand.ExecuteReaderAsync(