I have a query that looks something like this.
var railcars = DbContext.Railcars
.Select(r => new
{
Number = r.RailcarNumber,
Quantity = r.InboundQuantity
- DbContext.Transfers
.Where(t => t.FromType == TransferType.Railcar && t.FromId == r.Id)
.Sum(t => t.Quantity)
+ DbContext.Transfers
.Where(t => t.ToType == TransferType.Railcar && t.ToId == r.Id)
.Sum(t => t.Quantity)
});
The calculation for Quantity is a little convoluted, and I will use it in many different places, so I would like to simplify the syntax. Ideally, I would create my own function and make it translatable to SQL, something like the following.
public static class QueryHelper
{
public static double GetCurrentQuantity(this Railcar railcar, ApplicationDbContext dbContext)
{
throw new InvalidOperationException("This method cannot be executed directly.");
}
}
var railcars = DbContext.Railcars
.Select(r => new
{
Number = r.RailcarNumber,
Quantity = QueryHelper.GetCurrentQuantity(r, DbContext)
});
I've been looking into using IMethodCallTranslator to translate the needed code to SQL, but the only examples I can find seem outdated; for example, the signature of IMethodCallTranslator.Translate() in those examples is completely different.
I'm also not sure how to turn an expression as complex as the one I need into a SqlExpression (which is what the Translate() method returns).
Questions:
… SqlExpression this complex?
The simplest way to achieve this is to add a third-party extension that rewrites the expression tree before it is passed to the LINQ translator. I would suggest using LINQKit.
If you are using EF Core, LINQKit provides an options-builder extension that enables expression expanding for every query of the context. Add the following to your DbContext configuration:
builder
.UseSqlServer(connectionString) // or any other provider
.WithExpressionExpanding(); // enabling LINQKit extension
For EF6, you can use the AsExpandable() extension method to achieve similar functionality. Simply add this call at the start of your query:
var result = query.AsExpandable()
.ToList();
With LINQKit, you can mark a method with ExpandableAttribute and name a static method that returns the expression to inject into the query in its place. Here's an example:
public static class QueryHelper
{
[Expandable(nameof(GetCurrentQuantityImpl))]
public static double GetCurrentQuantity(this Railcar railcar, ApplicationDbContext dbContext)
{
throw new InvalidOperationException("This method cannot be executed directly.");
}
static Expression<Func<Railcar, ApplicationDbContext, double>> GetCurrentQuantityImpl()
{
return (railcar, dbContext) => railcar.InboundQuantity
- dbContext.Transfers
.Where(t => t.FromType == TransferType.Railcar && t.FromId == railcar.Id)
.Sum(t => t.Quantity)
+ dbContext.Transfers
.Where(t => t.ToType == TransferType.Railcar && t.ToId == railcar.Id)
.Sum(t => t.Quantity);
}
}
Once this method is defined, you can call GetCurrentQuantity in your LINQ queries, for example Quantity = r.GetCurrentQuantity(DbContext).
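For those who cannot use third-party extensions, here is an extract from LINQKit that does the same thing. It is built on IQueryExpressionInterceptor, which was introduced in EF Core 7, so it requires EF Core 7 or later.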
ExpandableAttribute:
using System;

[AttributeUsage(AttributeTargets.Property | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public class ExpandableAttribute : Attribute
{
/// <summary>
/// Creates an instance of the attribute.
/// </summary>
/// <param name="methodName">Name of a static method in the same class that returns the substitution expression. [Optional]</param>
public ExpandableAttribute(string? methodName = null)
{
MethodName = methodName;
}
/// <summary>
/// Name of a static method in the same class that returns the substitution expression.
/// </summary>
public string? MethodName { get; set; }
}
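When MethodName is omitted, the lookup performed by GetExpandLambda in the extension falls back to the member's own name and searches for a parameterless static method, so for a method the substitution can simply be a parameterless overload. (For a property this fallback cannot work, because C# does not allow a method with the same name as a property in the same class, so a property always needs an explicit MethodName.) The sketch below reproduces that lookup with plain reflection; unlike the extension, it calls MethodInfo.Invoke instead of compiling an Expression.Call. SampleHelpers, DoubleIt, Square and ExpandableLookup are hypothetical names, and the attribute is redeclared only so the block compiles on its own.

```csharp
#nullable enable
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical copy of the attribute so this sketch compiles on its own.
[AttributeUsage(AttributeTargets.Property | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public sealed class ExpandableAttribute : Attribute
{
    public ExpandableAttribute(string? methodName = null) => MethodName = methodName;
    public string? MethodName { get; set; }
}

public static class SampleHelpers
{
    // Explicit name: the substitution lambda comes from DoubleItImpl().
    [Expandable(nameof(DoubleItImpl))]
    public static int DoubleIt(int value) =>
        throw new InvalidOperationException("This method cannot be executed directly.");

    private static Expression<Func<int, int>> DoubleItImpl() => value => value * 2;

    // No name: the lookup falls back to a parameterless overload called Square.
    [Expandable]
    public static int Square(int value) =>
        throw new InvalidOperationException("This method cannot be executed directly.");

    private static Expression<Func<int, int>> Square() => value => value * value;
}

public static class ExpandableLookup
{
    // Returns the substitution lambda for a member marked with [Expandable],
    // or null when the member carries no such attribute.
    public static LambdaExpression? FindSubstitution(MemberInfo member)
    {
        var attr = member.GetCustomAttributes<ExpandableAttribute>(inherit: true).FirstOrDefault();
        if (attr is null || member.DeclaringType is null)
            return null;

        // Same fallback as GetExpandLambda: use the member's own name if none was given.
        var name = string.IsNullOrEmpty(attr.MethodName) ? member.Name : attr.MethodName;

        // Only a parameterless static method (public or not) qualifies as the factory.
        var factory = member.DeclaringType.GetMethod(
            name, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
            binder: null, types: Type.EmptyTypes, modifiers: null);

        return factory?.Invoke(null, null) as LambdaExpression
            ?? throw new InvalidOperationException(
                $"'{member.DeclaringType}.{name}()' did not return a LambdaExpression.");
    }
}
```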
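Extension implementation:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;

public static class EFCoreLinqExtensions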
{
public static DbContextOptionsBuilder UseMemberReplacement(this DbContextOptionsBuilder optionsBuilder)
{
// FindExtension returns null (GetExtension would throw) when no core options are registered yet
var coreExtension = optionsBuilder.Options.FindExtension<CoreOptionsExtension>();
// Reuse an already registered interceptor so calling this method twice is harmless
var currentInterceptor = coreExtension?.Interceptors?
.OfType<QueryExpressionReplacementInterceptor>()
.FirstOrDefault();
if (currentInterceptor == null)
{
currentInterceptor = new QueryExpressionReplacementInterceptor();
optionsBuilder.AddInterceptors(currentInterceptor);
}
return optionsBuilder;
}
private sealed class MemberReplacementVisitor : ExpressionVisitor
{
readonly Dictionary<MemberInfo, LambdaExpression?> _expandableCache = new();
protected override Expression VisitMember(MemberExpression node)
{
if (GetExpandLambda(node.Member, out var methodLambda))
{
var newExpr = methodLambda.Parameters.Count > 0
? ReplacingExpressionVisitor.Replace(methodLambda.Parameters[0], node.Expression!,
methodLambda.Body)
: methodLambda.Body;
return Visit(newExpr);
}
return base.VisitMember(node);
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (GetExpandLambda(node.Method, out var methodLambda))
{
Expression newExpr;
if (node.Method.IsStatic)
{
// Validate the lambda shape before substituting its parameters with the call arguments
if (node.Arguments.Count != methodLambda.Parameters.Count)
throw new InvalidOperationException(
$"Expected a lambda with {node.Arguments.Count} parameters, but the expandable method returned one with {methodLambda.Parameters.Count} parameters for call {node.Method}");
var replaceVisitor = new ReplacingExpressionVisitor(methodLambda.Parameters, node.Arguments);
newExpr = replaceVisitor.Visit(methodLambda.Body);
}
else
{
List<Expression> replacements = new(methodLambda.Parameters.Count);
replacements.Add(node.Object!);
replacements.AddRange(node.Arguments);
if (replacements.Count != methodLambda.Parameters.Count)
throw new InvalidOperationException(
$"Expected a lambda with {replacements.Count} parameters (instance plus arguments), but the expandable method returned one with {methodLambda.Parameters.Count} parameters for call {node.Method}");
var replaceVisitor = new ReplacingExpressionVisitor(methodLambda.Parameters, replacements);
newExpr = replaceVisitor.Visit(methodLambda.Body);
}
return Visit(newExpr);
}
return base.VisitMethodCall(node);
}
#region Helper methods
static bool IsNullableType(Type type)
{
return !type.IsValueType || type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
}
static object? EvaluateExpression(Expression? expr)
{
if (expr == null)
return null;
switch (expr.NodeType)
{
case ExpressionType.Default:
return !IsNullableType(expr.Type) ? Activator.CreateInstance(expr.Type) : null;
case ExpressionType.Constant:
return ((ConstantExpression)expr).Value;
case ExpressionType.Convert:
case ExpressionType.ConvertChecked:
{
var unary = (UnaryExpression)expr;
var operand = EvaluateExpression(unary.Operand);
if (operand == null)
return null;
break;
}
case ExpressionType.MemberAccess:
{
var member = (MemberExpression) expr;
if (member.Member is FieldInfo fieldInfo)
return fieldInfo.GetValue(EvaluateExpression(member.Expression));
if (member.Member is PropertyInfo propertyInfo)
{
var obj = EvaluateExpression(member.Expression);
if (obj == null)
{
if (propertyInfo.IsNullableValueMember())
return null;
if (propertyInfo.IsNullableHasValueMember())
return false;
}
return propertyInfo.GetValue(obj, null);
}
break;
}
case ExpressionType.Call:
{
var mc = (MethodCallExpression)expr;
var arguments = mc.Arguments.Select(EvaluateExpression).ToArray();
var instance = EvaluateExpression(mc.Object);
if (instance == null && mc.Method.IsNullableGetValueOrDefault())
return null;
return mc.Method.Invoke(instance, arguments);
}
}
var value = Expression.Lambda(expr).Compile().DynamicInvoke();
return value;
}
bool GetExpandLambda(MemberInfo memberInfo, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out LambdaExpression? expandLambda)
{
if (_expandableCache.TryGetValue(memberInfo, out expandLambda))
{
return expandLambda != null;
}
var canExpand = memberInfo.DeclaringType != null;
if (canExpand)
{
// shortcut for standard methods
canExpand = memberInfo.DeclaringType != typeof(Enumerable) &&
memberInfo.DeclaringType != typeof(Queryable);
}
if (canExpand)
{
if (memberInfo.GetCustomAttributes(typeof(ExpandableAttribute), true).FirstOrDefault() is ExpandableAttribute attr && memberInfo.DeclaringType != null)
{
var methodName = string.IsNullOrEmpty(attr.MethodName) ? memberInfo.Name : attr.MethodName;
Expression expr;
if (memberInfo is MethodInfo method && method.IsGenericMethod)
{
var args = method.GetGenericArguments();
expr = Expression.Call(memberInfo.DeclaringType, methodName, args);
}
else
{
expr = Expression.Call(memberInfo.DeclaringType, methodName, Type.EmptyTypes);
}
expandLambda = EvaluateExpression(expr) as LambdaExpression;
if (expandLambda == null)
{
throw new InvalidOperationException(
$"Expandable method '{memberInfo.DeclaringType}.{methodName}()' did not return a LambdaExpression.");
}
_expandableCache.Add(memberInfo, expandLambda);
return true;
}
}
_expandableCache.Add(memberInfo, null);
return false;
}
#endregion
}
sealed class QueryExpressionReplacementInterceptor : IQueryExpressionInterceptor
{
public Expression QueryCompilationStarting(Expression queryExpression, QueryExpressionEventData eventData)
{
var visitor = new MemberReplacementVisitor();
var result = visitor.Visit(queryExpression);
return result;
}
}
}
Reflection helpers:
using System;
using System.Reflection;
using System.Runtime.CompilerServices;

internal static class ReflectionExtensions
{
/// <summary>
/// Returns true, if type is <see cref="Nullable{T}"/> type.
/// </summary>
/// <param name="type">A <see cref="Type"/> instance. </param>
/// <returns><c>true</c>, if <paramref name="type"/> represents <see cref="Nullable{T}"/> type; otherwise, <c>false</c>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsNullable(this Type type)
{
return type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
}
public static bool IsNullableValueMember(this MemberInfo member)
{
return
member.Name == "Value" &&
member.DeclaringType!.IsNullable();
}
public static bool IsNullableHasValueMember(this MemberInfo member)
{
return
member.Name == "HasValue" &&
member.DeclaringType!.IsNullable();
}
public static bool IsNullableGetValueOrDefault(this MemberInfo member)
{
return
member.Name == "GetValueOrDefault" &&
member.DeclaringType!.IsNullable();
}
}
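The nullable special cases in EvaluateExpression exist because a null Nullable<T> boxes to a plain null reference, so there is no instance for reflection to call HasValue or Value on. This small demo (NullableReflectionDemo is a hypothetical name) shows that failure, and uses the BCL's Nullable.GetUnderlyingType, which gives the same answer as IsNullable for closed types.

```csharp
#nullable enable
using System;
using System.Reflection;

public static class NullableReflectionDemo
{
    // Equivalent to ReflectionExtensions.IsNullable for closed types,
    // written with the BCL helper Nullable.GetUnderlyingType.
    public static bool IsNullable(Type type) => Nullable.GetUnderlyingType(type) is not null;

    // A null int? boxes to a plain null reference, so calling the instance
    // getter of HasValue through reflection has no target and throws TargetException.
    public static bool HasValueViaReflectionThrows()
    {
        object? boxed = (int?)null;
        PropertyInfo hasValue = typeof(int?).GetProperty(nameof(Nullable<int>.HasValue))!;
        try
        {
            hasValue.GetValue(boxed);
            return false;
        }
        catch (TargetException)
        {
            // This is why EvaluateExpression answers HasValue (false) and Value (null) itself.
            return true;
        }
    }
}
```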
To enable the extension, call UseMemberReplacement() in the OnConfiguring method of your DbContext class:
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
optionsBuilder
.UseSqlServer(connectionString) // or any other provider
.UseMemberReplacement(); // enabling the extension
}
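The EF Core implementation needs a database provider to run, so here is a stripped-down sketch of its core idea using only System.Linq.Expressions. It keeps just the static-method branch of VisitMethodCall: no caching, no property (VisitMember) support and no interceptor. The attribute is redeclared with a required name, Quantities.Current is a hypothetical numeric stand-in for GetCurrentQuantity, and Substitution is a hand-written replacement for EF Core's ReplacingExpressionVisitor.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical minimal attribute: the substitution method name is required here.
[AttributeUsage(AttributeTargets.Method)]
public sealed class ExpandableAttribute : Attribute
{
    public ExpandableAttribute(string methodName) => MethodName = methodName;
    public string MethodName { get; }
}

// Hypothetical stand-in for QueryHelper, using plain numbers instead of DbSets.
public static class Quantities
{
    [Expandable(nameof(CurrentImpl))]
    public static double Current(double inbound, double outgoing, double incoming) =>
        throw new InvalidOperationException("This method cannot be executed directly.");

    private static Expression<Func<double, double, double, double>> CurrentImpl() =>
        (inbound, outgoing, incoming) => inbound - outgoing + incoming;
}

// Static-method-only version of MemberReplacementVisitor.VisitMethodCall.
public sealed class ExpandingVisitor : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        var attr = node.Method.GetCustomAttribute<ExpandableAttribute>();
        if (attr is null || !node.Method.IsStatic)
            return base.VisitMethodCall(node);

        // Resolve the parameterless static factory and ask it for the substitution lambda.
        var factory = node.Method.DeclaringType!.GetMethod(
                attr.MethodName, BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
                binder: null, types: Type.EmptyTypes, modifiers: null)
            ?? throw new InvalidOperationException($"Substitution method '{attr.MethodName}' not found.");
        var lambda = factory.Invoke(null, null) as LambdaExpression
            ?? throw new InvalidOperationException($"'{attr.MethodName}()' did not return a LambdaExpression.");
        if (lambda.Parameters.Count != node.Arguments.Count)
            throw new InvalidOperationException(
                $"Expected {node.Arguments.Count} lambda parameters, got {lambda.Parameters.Count}.");

        // Bind each lambda parameter to the argument expression at the call site.
        var map = new Dictionary<Expression, Expression>();
        for (var i = 0; i < lambda.Parameters.Count; i++)
            map[lambda.Parameters[i]] = node.Arguments[i];

        // Visit the inlined body again so nested [Expandable] calls are expanded too.
        return Visit(new Substitution(map).Visit(lambda.Body))!;
    }

    private sealed class Substitution : ExpressionVisitor
    {
        private readonly Dictionary<Expression, Expression> _map;

        public Substitution(Dictionary<Expression, Expression> map) => _map = map;

        // Replace mapped nodes (the lambda's parameters); recurse into everything else.
        public override Expression? Visit(Expression? node) =>
            node is not null && _map.TryGetValue(node, out var replacement)
                ? replacement
                : base.Visit(node);
    }
}
```

For example, expanding inbound => Quantities.Current(inbound, 3.0, 1.0) * 2 yields a tree equivalent to inbound => ((inbound - 3) + 1) * 2 with no call to Current left in it, so the throwing stub never runs. The interceptor relies on the same property: by the time the query reaches EF Core's translator, GetCurrentQuantity has been replaced by subqueries the provider already knows how to turn into SQL.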