Skip to content

Commit

Permalink
Adds more async extension methods
Browse files Browse the repository at this point in the history
Fix #256
  • Loading branch information
anpete committed Jun 10, 2014
1 parent 70b29b8 commit 706e862
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 63 deletions.
210 changes: 172 additions & 38 deletions src/Microsoft.Data.Entity/Extensions/QueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.Data.Entity.Utilities;

// ReSharper disable once CheckNamespace

namespace System.Linq
{
public static class QueryableExtensions
Expand Down Expand Up @@ -42,6 +43,40 @@ public static Task<bool> AnyAsync<TSource>(
throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

private static readonly MethodInfo _anyPredicate
= GetMethod("Any",
t => new[]
{
typeof(IQueryable<>).MakeGenericType(t),
typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(t, typeof(bool)))
});

public static Task<bool> AnyAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, bool>> predicate,
CancellationToken cancellationToken = default(CancellationToken))
{
Check.NotNull(source, "source");
Check.NotNull(predicate, "predicate");

cancellationToken.ThrowIfCancellationRequested();

var provider = source.Provider as IAsyncQueryProvider;

if (provider != null)
{
return provider.ExecuteAsync<bool>(
Expression.Call(
null,
_anyPredicate.MakeGenericMethod(typeof(TSource)),
new[] { source.Expression, Expression.Quote(predicate) }
),
cancellationToken);
}

throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

private static readonly MethodInfo _count
= GetMethod("Count", t => new[] { typeof(IQueryable<>).MakeGenericType(t) });

Expand Down Expand Up @@ -94,90 +129,189 @@ public static Task<TSource> SingleAsync<TSource>(
throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<List<TSource>> ToListAsync<TSource>(
private static readonly MethodInfo _singlePredicate
= GetMethod("Single",
t => new[]
{
typeof(IQueryable<>).MakeGenericType(t),
typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(t, typeof(bool)))
});

public static Task<TSource> SingleAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, bool>> predicate,
CancellationToken cancellationToken = default(CancellationToken))
{
Check.NotNull(source, "source");
Check.NotNull(predicate, "predicate");

cancellationToken.ThrowIfCancellationRequested();

return source.ToAsyncEnumerable().ToList(cancellationToken);
var provider = source.Provider as IAsyncQueryProvider;

if (provider != null)
{
return provider.ExecuteAsync<TSource>(
Expression.Call(
null,
_singlePredicate.MakeGenericMethod(typeof(TSource)),
new[] { source.Expression, Expression.Quote(predicate) }
),
cancellationToken);
}

throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<TSource[]> ToArrayAsync<TSource>(
private static readonly MethodInfo _singleOrDefault
= GetMethod("SingleOrDefault", t => new[] { typeof(IQueryable<>).MakeGenericType(t) });

public static Task<TSource> SingleOrDefaultAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
CancellationToken cancellationToken = default(CancellationToken))
{
Check.NotNull(source, "source");

cancellationToken.ThrowIfCancellationRequested();

return source.ToAsyncEnumerable().ToArray(cancellationToken);
}
var provider = source.Provider as IAsyncQueryProvider;

#region TODO
if (provider != null)
{
return provider.ExecuteAsync<TSource>(
Expression.Call(
null,
_singleOrDefault.MakeGenericMethod(typeof(TSource)),
new[] { source.Expression }
),
cancellationToken);
}

public static IQueryable<T> Include<T, TProperty>(
[NotNull] this IQueryable<T> source,
[NotNull] Expression<Func<T, TProperty>> path)
{
// TODO
return source;
throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<bool> AnyAsync<TSource>(
private static readonly MethodInfo _singleOrDefaultPredicate
= GetMethod(
"SingleOrDefault",
t => new[]
{
typeof(IQueryable<>).MakeGenericType(t),
typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(t, typeof(bool)))
});

public static Task<TSource> SingleOrDefaultAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, bool>> predicate,
CancellationToken cancellationToken = default(CancellationToken))
{
// TODO
return Task.FromResult(false);
Check.NotNull(source, "source");
Check.NotNull(predicate, "predicate");

cancellationToken.ThrowIfCancellationRequested();

var provider = source.Provider as IAsyncQueryProvider;

if (provider != null)
{
return provider.ExecuteAsync<TSource>(
Expression.Call(
null,
_singleOrDefaultPredicate.MakeGenericMethod(typeof(TSource)),
new[] { source.Expression, Expression.Quote(predicate) }
),
cancellationToken);
}

throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<TSource> SingleAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, bool>> predicate,
private static readonly MethodInfo _sumDecimal
= GetMethod("Sum", () => new[] { typeof(IQueryable<decimal>) });

public static Task<decimal> SumAsync(
[NotNull] this IQueryable<decimal> source,
CancellationToken cancellationToken = default(CancellationToken))
{
// TODO
return Task.FromResult(default(TSource));
Check.NotNull(source, "source");

cancellationToken.ThrowIfCancellationRequested();

var provider = source.Provider as IAsyncQueryProvider;

if (provider != null)
{
return provider.ExecuteAsync<decimal>(
Expression.Call(
null,
_sumDecimal,
new[] { source.Expression }
),
cancellationToken);
}

throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<TSource> SingleOrDefaultAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
private static readonly MethodInfo _sumInt
= GetMethod("Sum", () => new[] { typeof(IQueryable<int>) });

public static Task<int> SumAsync(
[NotNull] this IQueryable<int> source,
CancellationToken cancellationToken = default(CancellationToken))
{
// TODO
return Task.FromResult(default(TSource));
Check.NotNull(source, "source");

var provider = source.Provider as IAsyncQueryProvider;

if (provider != null)
{
return provider.ExecuteAsync<int>(
Expression.Call(
null,
_sumInt,
new[] { source.Expression }
),
cancellationToken);
}

throw new InvalidOperationException(Strings.FormatIQueryableProviderNotAsync());
}

public static Task<TSource> SingleOrDefaultAsync<TSource>(
public static Task<List<TSource>> ToListAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
[NotNull] Expression<Func<TSource, bool>> predicate,
CancellationToken cancellationToken = default(CancellationToken))
{
// TODO
return Task.FromResult(default(TSource));
Check.NotNull(source, "source");

cancellationToken.ThrowIfCancellationRequested();

return source.ToAsyncEnumerable().ToList(cancellationToken);
}

public static Task<decimal> SumAsync(
[NotNull] this IQueryable<decimal> source,
public static Task<TSource[]> ToArrayAsync<TSource>(
[NotNull] this IQueryable<TSource> source,
CancellationToken cancellationToken = default(CancellationToken))
{
// TODO
return Task.FromResult(default(decimal));
Check.NotNull(source, "source");

cancellationToken.ThrowIfCancellationRequested();

return source.ToAsyncEnumerable().ToArray(cancellationToken);
}

public static Task<int> SumAsync(
[NotNull] this IQueryable<int> source,
CancellationToken cancellationToken = default(CancellationToken))
public static IQueryable<T> Include<T, TProperty>(
[NotNull] this IQueryable<T> source,
[NotNull] Expression<Func<T, TProperty>> path)
{
// TODO
return Task.FromResult(0);
Check.NotNull(source, "source");

throw new NotImplementedException();
}

#endregion
private static MethodInfo GetMethod(string methodName, Func<Type[]> getParameterTypes)
{
return GetMethod(methodName, getParameterTypes.GetMethodInfo(), 0);
}

private static MethodInfo GetMethod(string methodName, Func<Type, Type[]> getParameterTypes)
{
Expand Down
Loading

0 comments on commit 706e862

Please sign in to comment.