Writing custom Entity Framework function with IMethodCallTranslator

I have a query that looks something like this.

var railcars = DbContext.Railcars
    .Select(r => new
    {
        Number = r.RailcarNumber,
        Quantity = r.InboundQuantity
            - DbContext.Transfers
                .Where(t => t.FromType == TransferType.Railcar && t.FromId == r.Id)
                .Sum(t => t.Quantity)
            + DbContext.Transfers
                .Where(t => t.ToType == TransferType.Railcar && t.ToId == r.Id)
                .Sum(t => t.Quantity)
    });
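To pin down exactly what a reusable helper has to reproduce, here is the same rule evaluated in memory with LINQ to Objects. This is only a sketch: the Railcar, Transfer and TransferType shapes below (including the Truck member and the double quantities) are assumptions standing in for the real entities, and this version is not translatable to SQL.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical, simplified stand-ins for the entities in the query above.
public enum TransferType { Railcar, Truck }

public record Railcar(int Id, string RailcarNumber, double InboundQuantity);

public record Transfer(TransferType FromType, int FromId, TransferType ToType, int ToId, double Quantity);

public static class QuantityRule
{
    // Current quantity = inbound quantity
    //   - everything transferred out of this railcar
    //   + everything transferred into this railcar.
    public static double CurrentQuantity(Railcar railcar, IReadOnlyCollection<Transfer> transfers) =>
        railcar.InboundQuantity
        - transfers
            .Where(t => t.FromType == TransferType.Railcar && t.FromId == railcar.Id)
            .Sum(t => t.Quantity)
        + transfers
            .Where(t => t.ToType == TransferType.Railcar && t.ToId == railcar.Id)
            .Sum(t => t.Quantity);
}
```

With an inbound quantity of 100, one 30-unit transfer out of the railcar to a truck and one 10-unit transfer into it from another railcar, CurrentQuantity returns 80.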

The Quantity calculation is a little convoluted, and I will need it in many different places, so I'd like to simplify the syntax. Ideally, I would create my own function and make it translatable to SQL, something like the following:

public static class QueryHelper
{
    public static double GetCurrentQuantity(this Railcar railcar, ApplicationDbContext dbContext)
    {
        throw new InvalidOperationException("This method cannot be executed directly.");
    }
}

var railcars = DbContext.Railcars
    .Select(r => new
    {
        Number = r.RailcarNumber,
        Quantity = QueryHelper.GetCurrentQuantity(r, DbContext)
    });

I've been looking into using IMethodCallTranslator to translate this code to SQL, but the only examples I can find seem outdated; for example, the signature of IMethodCallTranslator.Translate() they use is completely different from the current one.

I'm also not sure how to turn an expression as complex as this one into a SqlExpression (which is what Translate() returns).

Questions:

  1. Does anyone know if it's possible? Am I on the right track?
  2. Are there any good examples of this?
  3. Any examples or documentation on building a SqlExpression this complex?
  4. Any examples for my specific logic to help get me started?
asked Oct 16 '25 at 16:10 by Jonathan Wood


1 Answer

The simplest way to achieve this is to add a third-party extension that rewrites the expression tree before it is passed to the LINQ translator. I would suggest LINQKit.

If you are using EF Core, LINQKit provides a configuration extension (in the LinqKit.Microsoft.EntityFrameworkCore package) that enables expression expanding. Add the following to your DbContext options configuration:

builder
    .UseSqlServer(connectionString) // or any other provider
    .WithExpressionExpanding();     // enabling LINQKit extension

For EF6, use the AsExpandable() extension method to get similar functionality. Add the call at the start of your query:

var result = query.AsExpandable()
    .ToList();

With LINQKit, you can mark your methods with the ExpandableAttribute and specify a static function that returns an expression representation for injection into the query. Here's an example:

public static class QueryHelper
{
    [Expandable(nameof(GetCurrentQuantityImpl))]
    public static double GetCurrentQuantity(this Railcar railcar, ApplicationDbContext dbContext)
    {
        throw new InvalidOperationException("This method cannot be executed directly.");
    }

    static Expression<Func<Railcar, ApplicationDbContext, double>> GetCurrentQuantityImpl()
    {
        return (railcar, dbContext) => railcar.InboundQuantity
            - dbContext.Transfers
                .Where(t => t.FromType == TransferType.Railcar && t.FromId == railcar.Id)
                .Sum(t => t.Quantity)
            + dbContext.Transfers
                .Where(t => t.ToType == TransferType.Railcar && t.ToId == railcar.Id)
                .Sum(t => t.Quantity);
    }
}

After defining this method, you can use GetCurrentQuantity in your LINQ queries.


Without LINQKit

If you cannot use third-party extensions, here is a standalone implementation of the same idea, extracted from LINQKit. It is built on IQueryExpressionInterceptor, so it requires EF Core 7 or later:

ExpandableAttribute

[AttributeUsage(AttributeTargets.Property | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public class ExpandableAttribute : Attribute
{
    /// <summary>
    /// Creates instance of attribute.
    /// </summary>
    /// <param name="methodName">Name of method in the same class that returns substitution expression. [Optional]</param>
    public ExpandableAttribute(string? methodName = null)
    {
        MethodName = methodName;
    }

    /// <summary>
    /// Name of method in the same class that returns substitution expression.
    /// </summary>
    public string? MethodName { get; set; }
}
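The attribute only stores a method name; the expander has to turn that name into a LambdaExpression. The sketch below shows that lookup in isolation. It is an approximation of GetExpandLambda, not the extension's actual code: it uses plain reflection (Type.GetMethod plus Invoke) instead of building and compiling an Expression.Call, ignores generic methods, and MathHelper and ExpandableResolver are hypothetical names. The attribute is repeated so the block compiles on its own.

```csharp
#nullable enable
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Minimal copy of the ExpandableAttribute above, so this sketch compiles on its own.
[AttributeUsage(AttributeTargets.Property | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public class ExpandableAttribute : Attribute
{
    public ExpandableAttribute(string? methodName = null) => MethodName = methodName;

    public string? MethodName { get; set; }
}

// Hypothetical marker/implementation pair following the QueryHelper pattern.
public static class MathHelper
{
    [Expandable(nameof(TwiceImpl))]
    public static int Twice(int x) =>
        throw new InvalidOperationException("This method cannot be executed directly.");

    static Expression<Func<int, int>> TwiceImpl() => x => x * 2;
}

public static class ExpandableResolver
{
    // Returns the substitution lambda for a member marked with [Expandable],
    // or null when the member carries no such attribute.
    public static LambdaExpression? Resolve(MemberInfo member)
    {
        var attr = member.GetCustomAttributes(typeof(ExpandableAttribute), true)
            .OfType<ExpandableAttribute>()
            .FirstOrDefault();
        var declaringType = member.DeclaringType;
        if (attr == null || declaringType == null)
            return null;

        // Fall back to the member's own name, as GetExpandLambda does.
        var methodName = string.IsNullOrEmpty(attr.MethodName) ? member.Name : attr.MethodName;

        // Parameterless static method in the same class; it may be private.
        var impl = declaringType.GetMethod(
            methodName,
            BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic,
            null,
            Type.EmptyTypes,
            null);

        return impl?.Invoke(null, null) as LambdaExpression
            ?? throw new InvalidOperationException(
                $"'{declaringType}.{methodName}()' did not return a LambdaExpression.");
    }
}
```

Resolving MathHelper.Twice yields the one-parameter lambda x => x * 2, while a member without the attribute (for example string.Trim()) yields null, which is the "not expandable" result the visitor caches.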

Extension implementation:

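using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;              // MethodImpl, used by the reflection helpers below
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;    // IQueryExpressionInterceptor, QueryExpressionEventData
using Microsoft.EntityFrameworkCore.Infrastructure; // CoreOptionsExtension
using Microsoft.EntityFrameworkCore.Query;          // ReplacingExpressionVisitor

public static class EFCoreLinqExtensions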
{
    public static DbContextOptionsBuilder UseMemberReplacement(this DbContextOptionsBuilder optionsBuilder)
    {
        // FindExtension returns null instead of throwing when no CoreOptionsExtension has been registered yet.
        var coreExtension = optionsBuilder.Options.FindExtension<CoreOptionsExtension>();

        // Reuse an already registered interceptor so calling this method twice is harmless.
        var currentInterceptor = coreExtension?.Interceptors?
            .OfType<QueryExpressionReplacementInterceptor>()
            .FirstOrDefault();

        if (currentInterceptor == null)
        {
            currentInterceptor = new QueryExpressionReplacementInterceptor();
            optionsBuilder.AddInterceptors(currentInterceptor);
        }

        return optionsBuilder;
    }

    private sealed class MemberReplacementVisitor : ExpressionVisitor
    {
        readonly Dictionary<MemberInfo, LambdaExpression?> _expandableCache = new();

        protected override Expression VisitMember(MemberExpression node)
        {
            if (GetExpandLambda(node.Member, out var methodLambda))
            {
                var newExpr = methodLambda.Parameters.Count > 0
                    ? ReplacingExpressionVisitor.Replace(methodLambda.Parameters[0], node.Expression!,
                        methodLambda.Body)
                    : methodLambda.Body;

                return Visit(newExpr);
            }

            return base.VisitMember(node);
        }

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (GetExpandLambda(node.Method, out var methodLambda))
            {
                Expression newExpr;
                if (node.Method.IsStatic)
                {
                    var replaceVisitor = new ReplacingExpressionVisitor(methodLambda.Parameters, node.Arguments);

                    if (node.Arguments.Count != methodLambda.Parameters.Count)
                        throw new InvalidOperationException(
                            $"Required {node.Arguments.Count} arguments, but returned lambda with {methodLambda.Parameters.Count} parameters during call {node.Method}");

                    newExpr = replaceVisitor.Visit(methodLambda.Body);
                }
                else
                {
                    List<Expression> replacements = new(methodLambda.Parameters.Count);

                    replacements.Add(node.Object!);
                    replacements.AddRange(node.Arguments);

                    if (replacements.Count != methodLambda.Parameters.Count)
                        throw new InvalidOperationException(
                            $"Required {replacements.Count} arguments, but returned lambda with {methodLambda.Parameters.Count} parameters during call {node.Method}");

                    var replaceVisitor = new ReplacingExpressionVisitor(methodLambda.Parameters, replacements);
                    newExpr = replaceVisitor.Visit(methodLambda.Body);
                }

                return Visit(newExpr);
            }

            return base.VisitMethodCall(node);
        }

        #region Helper methods

        static bool IsNullableType(Type type)
        {
            return !type.IsValueType || type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
        }

        static object? EvaluateExpression(Expression? expr)
        {
            if (expr == null)
                return null;

            switch (expr.NodeType)
            {
                case ExpressionType.Default:
                    return !IsNullableType(expr.Type) ? Activator.CreateInstance(expr.Type) : null;

                case ExpressionType.Constant:
                    return ((ConstantExpression)expr).Value;

                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                {
                    var unary = (UnaryExpression)expr;
                    var operand = EvaluateExpression(unary.Operand);
                    if (operand == null)
                        return null;
                    break;
                }

                case ExpressionType.MemberAccess:
                {
                    var member = (MemberExpression) expr;

                    // IsFieldEx() is not defined in the helpers below; a type pattern does the same job.
                    if (member.Member is FieldInfo fieldInfo)
                        return fieldInfo.GetValue(EvaluateExpression(member.Expression));

                    if (member.Member is PropertyInfo propertyInfo)
                    {
                        var obj = EvaluateExpression(member.Expression);
                        if (obj == null)
                        {
                            if (propertyInfo.IsNullableValueMember())
                                return null;
                            if (propertyInfo.IsNullableHasValueMember())
                                return false;
                        }
                        return propertyInfo.GetValue(obj, null);
                    }

                    break;
                }

                case ExpressionType.Call:
                {
                    var mc = (MethodCallExpression)expr;
                    var arguments = mc.Arguments.Select(EvaluateExpression).ToArray();
                    var instance  = EvaluateExpression(mc.Object);

                    if (instance == null && mc.Method.IsNullableGetValueOrDefault())
                        return null;

                    return mc.Method.Invoke(instance, arguments);
                }
            }

            var value = Expression.Lambda(expr).Compile().DynamicInvoke();
            return value;
        }

        bool GetExpandLambda(MemberInfo memberInfo, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out LambdaExpression? expandLambda)
        {
            if (_expandableCache.TryGetValue(memberInfo, out expandLambda))
            {
                return expandLambda != null;
            }

            var canExpand = memberInfo.DeclaringType != null;
            if (canExpand)
            {
                // shortcut for standard methods
                canExpand = memberInfo.DeclaringType != typeof(Enumerable) &&
                            memberInfo.DeclaringType != typeof(Queryable);
            }

            if (canExpand)
            {
                if (memberInfo.GetCustomAttributes(typeof(ExpandableAttribute), true).FirstOrDefault() is ExpandableAttribute attr && memberInfo.DeclaringType != null)
                {
                    var methodName = string.IsNullOrEmpty(attr.MethodName) ? memberInfo.Name : attr.MethodName;

                    Expression expr;

                    if (memberInfo is MethodInfo method && method.IsGenericMethod)
                    {
                        var args = method.GetGenericArguments();

                        expr = Expression.Call(memberInfo.DeclaringType, methodName, args);
                    }
                    else
                    {
                        expr = Expression.Call(memberInfo.DeclaringType, methodName, Type.EmptyTypes);
                    }

                    expandLambda = (EvaluateExpression(expr) as LambdaExpression)!;
                    if (expandLambda == null)
                    {
                        throw new InvalidOperationException(
                            $"Expandable method '{memberInfo.DeclaringType}.{methodName}()' did not return a LambdaExpression.");
                    }

                    _expandableCache.Add(memberInfo, expandLambda);
                    return true;
                }
            }

            _expandableCache.Add(memberInfo, null);

            return false;
        }

        #endregion
    }

    sealed class QueryExpressionReplacementInterceptor : IQueryExpressionInterceptor
    {
        public Expression QueryCompilationStarting(Expression queryExpression, QueryExpressionEventData eventData)
        {
            var visitor = new MemberReplacementVisitor();

            var result = visitor.Visit(queryExpression);

            return result;
        }
    }
}
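ReplacingExpressionVisitor and the interceptor pipeline come from EF Core, which makes the core rewrite hard to experiment with in isolation. The sketch below reproduces only the static-method path of VisitMethodCall using plain System.Linq.Expressions: substitute the call's arguments for the lambda's parameters inside the lambda body, then visit the result again. SubstitutingVisitor, MathHelper and StaticCallInliner are hypothetical names, and this is not a drop-in replacement for the interceptor.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Stand-in for EF Core's ReplacingExpressionVisitor: swaps each "original"
// node (matched by reference) for the corresponding "replacement" node.
public sealed class SubstitutingVisitor : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> _map = new();

    public SubstitutingVisitor(IReadOnlyList<Expression> originals, IReadOnlyList<Expression> replacements)
    {
        if (originals.Count != replacements.Count)
            throw new ArgumentException("originals and replacements must have the same length.");

        for (var i = 0; i < originals.Count; i++)
            _map[originals[i]] = replacements[i];
    }

    public override Expression? Visit(Expression? node) =>
        node != null && _map.TryGetValue(node, out var replacement) ? replacement : base.Visit(node);
}

// Hypothetical marker method and its substitution lambda.
public static class MathHelper
{
    public static int Twice(int x) =>
        throw new InvalidOperationException("This method cannot be executed directly.");

    public static Expression<Func<int, int>> TwiceImpl() => x => x * 2;
}

// Inlines calls to a single static method.
public sealed class StaticCallInliner : ExpressionVisitor
{
    private readonly MethodInfo _method;
    private readonly LambdaExpression _impl;

    public StaticCallInliner(MethodInfo method, LambdaExpression impl)
    {
        if (method.GetParameters().Length != impl.Parameters.Count)
            throw new ArgumentException("Lambda parameter count must match the method's parameter count.");
        _method = method;
        _impl = impl;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (!node.Method.Equals(_method))
            return base.VisitMethodCall(node);

        // Put the call's arguments where the lambda's parameters were...
        var inlined = new SubstitutingVisitor(_impl.Parameters, node.Arguments).Visit(_impl.Body)!;

        // ...then visit the result so calls nested in the arguments are inlined too.
        return Visit(inlined)!;
    }
}
```

Applied to x => MathHelper.Twice(x) + 1, the inliner produces x => (x * 2) + 1, which compiles and runs without ever calling the throwing marker method. In the EF Core version, that rewritten tree is what reaches the provider's translator.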

Reflection helpers:

internal static class ReflectionExtensions
{
    /// <summary>
    /// Returns true, if type is <see cref="Nullable{T}"/> type.
    /// </summary>
    /// <param name="type">A <see cref="Type"/> instance. </param>
    /// <returns><c>true</c>, if <paramref name="type"/> represents <see cref="Nullable{T}"/> type; otherwise, <c>false</c>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool IsNullable(this Type type)
    {
        return type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
    }

    public static bool IsNullableValueMember(this MemberInfo member)
    {
        return
            member.Name == "Value" &&
            member.DeclaringType!.IsNullable();
    }

    public static bool IsNullableHasValueMember(this MemberInfo member)
    {
        return
            member.Name == "HasValue" &&
            member.DeclaringType!.IsNullable();
    }

    public static bool IsNullableGetValueOrDefault(this MemberInfo member)
    {
        return
            member.Name == "GetValueOrDefault" &&
            member.DeclaringType!.IsNullable();
    }

}
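The Nullable helpers exist because of a reflection corner case in EvaluateExpression: a null int? boxes to a plain null reference, so there is no instance on which to read HasValue or Value, and PropertyInfo.GetValue throws instead of returning false. The sketch below demonstrates that behavior; NullableReflectionDemo is a hypothetical name, and IsNullable repeats the check from ReflectionExtensions so the block stands alone.

```csharp
#nullable enable
using System;
using System.Reflection;

public static class NullableReflectionDemo
{
    // Returns true if reading HasValue through reflection on a boxed null int? throws,
    // which is the case EvaluateExpression short-circuits with IsNullableHasValueMember().
    public static bool HasValueThrowsOnBoxedNull()
    {
        int? missing = null;
        object? boxed = missing; // boxing a null Nullable<int> yields a null reference
        var hasValue = typeof(int?).GetProperty(nameof(Nullable<int>.HasValue))!;

        try
        {
            hasValue.GetValue(boxed);
            return false;
        }
        catch (TargetException)
        {
            // Instance property read with a null target.
            return true;
        }
    }

    // Same check as ReflectionExtensions.IsNullable above.
    public static bool IsNullable(Type type) =>
        type.IsConstructedGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
}
```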

To enable the extension, call UseMemberReplacement() when configuring the DbContext options, for example in the OnConfiguring method of your DbContext class:

builder
    .UseSqlServer(connectionString) // or any other provider
    .UseMemberReplacement();        // enabling extension
answered Oct 18 '25 at 09:10 by Svyatoslav Danyliv