Skip to content

Commit

Permalink
Preserve synchronization context in SaveChangesAsync
Browse files Browse the repository at this point in the history
Closes #23971
  • Loading branch information
roji committed Jan 26, 2021
1 parent 7c8ff90 commit 761a95c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/EFCore/ChangeTracking/Internal/StateManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ private async Task<int> SaveChangesAsync(
{
SavingChanges = true;
var result = await SaveChangesAsync(entriesToSave, cancellationToken)
.ConfigureAwait(false);
.ConfigureAwait(acceptAllChangesOnSuccess);

if (acceptAllChangesOnSuccess)
{
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/DbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ public virtual async Task<int> SaveChangesAsync(
SavingChanges?.Invoke(this, new SavingChangesEventArgs(acceptAllChangesOnSuccess));

var interceptionResult = await DbContextDependencies.UpdateLogger
.SaveChangesStartingAsync(this, cancellationToken).ConfigureAwait(false);
.SaveChangesStartingAsync(this, cancellationToken).ConfigureAwait(acceptAllChangesOnSuccess);

TryDetectChanges();

Expand Down
105 changes: 104 additions & 1 deletion test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel.DataAnnotations;
Expand All @@ -11,6 +12,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore.Diagnostics;
Expand Down Expand Up @@ -6546,7 +6548,7 @@ public virtual async Task Can_ignore_invalid_include_path_error()
var contextFactory = await InitializeAsync<IssueContext20609>(
onConfiguring: o => o.ConfigureWarnings(x => x.Ignore(CoreEventId.InvalidIncludePathError)));

using var context = contextFactory.CreateContext();
using var context = contextFactory.CreateContext();
var result = context.Set<IssueContext20609.ClassA>().Include("SubB").ToList();
}

Expand Down Expand Up @@ -9125,6 +9127,107 @@ public SqlExpression Translate(

#endregion

#region Issue22841

[ConditionalFact]
public async Task SaveChangesAsync_accepts_changes_with_ConfigureAwait_true_22841()
{
var contextFactory = await InitializeAsync<MyContext22841>();

using var context = contextFactory.CreateContext();
var observableThing = new ObservableThing22841();

var origSynchronizationContext = SynchronizationContext.Current;
var trackingSynchronizationContext = new SingleThreadSynchronizationContext22841();
SynchronizationContext.SetSynchronizationContext(trackingSynchronizationContext);

bool? isMySyncContext = null;
Action callback = () => isMySyncContext = Thread.CurrentThread == trackingSynchronizationContext.Thread;
observableThing.Event += callback;

try
{
context.Add(observableThing);
await context.SaveChangesAsync();
}
finally
{
observableThing.Event -= callback;
SynchronizationContext.SetSynchronizationContext(origSynchronizationContext);
trackingSynchronizationContext.Dispose();
}

Assert.True(isMySyncContext);
}

protected class MyContext22841 : DbContext
{
public MyContext22841(DbContextOptions options)
: base(options)
{
}

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder
.Entity<ObservableThing22841>()
.Property(o => o.Id)
.UsePropertyAccessMode(PropertyAccessMode.Property);

public DbSet<ObservableThing22841> ObservableThings { get; set; }
}

public class ObservableThing22841
{
public int Id
{
get => _id;
set
{
_id = value;
Event?.Invoke();
}
}

private int _id;

public event Action Event;
}

class SingleThreadSynchronizationContext22841 : SynchronizationContext, IDisposable
{
private CancellationTokenSource _cancellationTokenSource;
readonly BlockingCollection<(SendOrPostCallback callback, object state)> _tasks = new();
internal Thread Thread { get; }

internal SingleThreadSynchronizationContext22841()
{
_cancellationTokenSource = new CancellationTokenSource();
Thread = new Thread(WorkLoop);
Thread.Start();
}

public override void Post(SendOrPostCallback callback, object state) => _tasks.Add((callback, state));
public void Dispose() => _tasks.CompleteAdding();

void WorkLoop()
{
try
{
while (true)
{
var (callback, state) = _tasks.Take();
callback(state);
}
}
catch (InvalidOperationException)
{
_tasks.Dispose();
}
}
}

#endregion Issue22841

protected override string StoreName => "QueryBugsTest";
protected TestSqlLoggerFactory TestSqlLoggerFactory
=> (TestSqlLoggerFactory)ListLoggerFactory;
Expand Down

0 comments on commit 761a95c

Please sign in to comment.