Mutating the expression tree of a predicate to target another type

Solution 1:

It seems you're generating the parameter expression twice, in VisitMember() here:

// Rebuilds the member access against activeRecordType. The instance expression
// is re-visited through base.Visit(node.Expression) before the property with
// the same name is looked up on activeRecordType.
var converted = Expression.MakeMemberAccess(
    base.Visit(node.Expression),
    activeRecordType.GetProperty(node.Member.Name));

...since base.Visit() will end up in VisitParameter I imagine, and in GetMany() itself:

// Builds the new lambda: the body and every declared parameter are passed
// through the visitor in two separate steps.
// NOTE(review): as quoted, the closing parenthesis of Expression.Lambda(...)
// is missing before the semicolon.
var lambda = Expression.Lambda<Func<ActiveRecord.Widget, bool>>(
    visitor.Visit(predicate.Body),
    predicate.Parameters.Select(p => visitor.Visit(p));

If you're using a ParameterExpression in the body, it has to be the same instance (not just the same type and name) as the one declared for the Lambda. I've had problems with this kind of scenario before; as far as I remember, the result was that I couldn't create the expression at all, because it would just throw an exception. In any case, you might try reusing the parameter instance to see if it helps.

Solution 2:

It turned out that the tricky part is simply that the ParameterExpression instances that exist in the expression tree of the new lambda must be the same instances as are passed in the IEnumerable<ParameterExpression> parameter of Expression.Lambda.

Note that inside TransformPredicateLambda I am giving t => typeof(TNewTarget) as the "type converter" function; that's because in this specific case, we can assume that all parameters and member accesses will be of that one specific type. More advanced scenarios may need additional logic in there.

The code:

/// <summary>
/// Holds the helper that re-targets a predicate expression tree from one
/// parameter type to another type that exposes members with the same names.
/// </summary>
internal class DbAccessLayer {
    /// <summary>
    /// Rewrites <paramref name="predicate"/> so that it takes a
    /// <typeparamref name="TNewTarget"/> instead of a
    /// <typeparamref name="TOldTarget"/>.
    /// </summary>
    /// <typeparam name="TOldTarget">Parameter type of the original predicate.</typeparam>
    /// <typeparam name="TNewTarget">Parameter type of the returned predicate.</typeparam>
    /// <param name="predicate">The predicate to transform.</param>
    /// <returns>A new lambda whose body and parameter target <typeparamref name="TNewTarget"/>.</returns>
    /// <exception cref="NotSupportedException"><paramref name="predicate"/> is null.</exception>
    private static Expression<Func<TNewTarget, bool>> 
    TransformPredicateLambda<TOldTarget, TNewTarget>(
    Expression<Func<TOldTarget, bool>> predicate)
    {
        if (predicate == null) {
            throw new NotSupportedException();
        }

        var mutator = new ExpressionTargetTypeMutator(t => typeof(TNewTarget));
        var converted = mutator.Visit(predicate.Body);

        // Declare the parameters by converting the ORIGINAL declaration rather
        // than by collecting them from the converted body. The mutator hands
        // out one replacement instance per original parameter, so this yields
        // exactly the instances used in the body, in declaration order, even
        // when the parameter is referenced several times or not at all.
        var parameters = predicate.Parameters
            .Select(p => (ParameterExpression) mutator.Visit(p))
            .ToArray();

        return Expression.Lambda<Func<TNewTarget, bool>>(
            converted,
            predicate.Name,
            predicate.TailCall,
            parameters);
    }


    /// <summary>
    /// Visitor that replaces parameters with parameters of the converted type
    /// and re-binds member accesses whose instance expression was replaced.
    /// </summary>
    private class ExpressionTargetTypeMutator : ExpressionVisitor
    {
        private readonly Func<Type, Type> typeConverter;

        // Original parameter -> its replacement. Every reference to the same
        // original parameter must map to the same replacement instance,
        // because Expression.Lambda binds parameters by instance.
        private readonly Dictionary<ParameterExpression, ParameterExpression> parameters
            = new Dictionary<ParameterExpression, ParameterExpression>();

        /// <summary>Creates a mutator.</summary>
        /// <param name="typeConverter">Maps a parameter's type to its replacement type.</param>
        public ExpressionTargetTypeMutator(Func<Type, Type> typeConverter)
        {
            this.typeConverter = typeConverter;
        }

        /// <summary>
        /// Re-binds a member access to the same-named property (or field) of
        /// the converted instance's type.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The converted type has no public property or field with that name.
        /// </exception>
        protected override Expression VisitMember(MemberExpression node)
        {
            var instance = this.Visit(node.Expression);
            if (instance == node.Expression) {
                // Nothing below this node was replaced (static member, captured
                // variable, constant...), so the original member is still valid.
                return node;
            }

            var member = (System.Reflection.MemberInfo) instance.Type.GetProperty(node.Member.Name)
                ?? instance.Type.GetField(node.Member.Name);
            if (member == null) {
                throw new InvalidOperationException(
                    "Type '" + instance.Type + "' has no public property or field named '"
                    + node.Member.Name + "'.");
            }

            return Expression.MakeMemberAccess(instance, member);
        }

        /// <summary>
        /// Returns the replacement for <paramref name="node"/>, creating it on
        /// first use and returning the same instance on every later visit.
        /// </summary>
        protected override Expression VisitParameter(ParameterExpression node)
        {
            ParameterExpression replacement;
            if (!this.parameters.TryGetValue(node, out replacement)) {
                replacement = Expression.Parameter(this.typeConverter(node.Type), node.Name);
                this.parameters.Add(node, replacement);
            }

            return replacement;
        }
    }
}

/// <summary>
/// Helper that flattens an expression tree into the list of nodes it contains.
/// </summary>
public class ExpressionTreeExplorer
{
    private readonly Visitor visitor = new Visitor();

    /// <summary>
    /// Walks the expression tree rooted at <paramref name="node"/> and returns
    /// the expression nodes encountered during the walk.
    /// </summary>
    /// <param name="node">The root of the tree to walk.</param>
    /// <returns>
    /// The recorded nodes, in the order in which the visitor reached them.
    /// </returns>
    public IEnumerable<Expression> Explore(Expression node)
    {
        return this.visitor.Explore(node);
    }

    /// <summary>
    /// Visitor that records every node it is dispatched to before letting the
    /// base class continue into the node's children.
    /// </summary>
    private class Visitor : ExpressionVisitor
    {
        private readonly List<Expression> expressions = new List<Expression>();

        /// <summary>
        /// Resets the recorded list, walks the tree and returns a snapshot of
        /// the nodes recorded during this walk.
        /// </summary>
        public IEnumerable<Expression> Explore(Expression node)
        {
            this.expressions.Clear();
            this.Visit(node);
            return expressions.ToArray();
        }

        // Records the node and hands it back so it can be passed straight on
        // to the matching base.VisitXxx call.
        private TNode Record<TNode>(TNode node) where TNode : Expression
        {
            this.expressions.Add(node);
            return node;
        }

        protected override Expression VisitBinary(BinaryExpression node)
        {
            return base.VisitBinary(this.Record(node));
        }

        protected override Expression VisitBlock(BlockExpression node)
        {
            return base.VisitBlock(this.Record(node));
        }

        protected override Expression VisitConditional(ConditionalExpression node)
        {
            return base.VisitConditional(this.Record(node));
        }

        protected override Expression VisitConstant(ConstantExpression node)
        {
            return base.VisitConstant(this.Record(node));
        }

        protected override Expression VisitDebugInfo(DebugInfoExpression node)
        {
            return base.VisitDebugInfo(this.Record(node));
        }

        protected override Expression VisitDefault(DefaultExpression node)
        {
            return base.VisitDefault(this.Record(node));
        }

        protected override Expression VisitDynamic(DynamicExpression node)
        {
            return base.VisitDynamic(this.Record(node));
        }

        protected override Expression VisitExtension(Expression node)
        {
            return base.VisitExtension(this.Record(node));
        }

        protected override Expression VisitGoto(GotoExpression node)
        {
            return base.VisitGoto(this.Record(node));
        }

        protected override Expression VisitIndex(IndexExpression node)
        {
            return base.VisitIndex(this.Record(node));
        }

        protected override Expression VisitInvocation(InvocationExpression node)
        {
            return base.VisitInvocation(this.Record(node));
        }

        protected override Expression VisitLabel(LabelExpression node)
        {
            return base.VisitLabel(this.Record(node));
        }

        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            return base.VisitLambda(this.Record(node));
        }

        protected override Expression VisitListInit(ListInitExpression node)
        {
            return base.VisitListInit(this.Record(node));
        }

        protected override Expression VisitLoop(LoopExpression node)
        {
            return base.VisitLoop(this.Record(node));
        }

        protected override Expression VisitMember(MemberExpression node)
        {
            return base.VisitMember(this.Record(node));
        }

        protected override Expression VisitMemberInit(MemberInitExpression node)
        {
            return base.VisitMemberInit(this.Record(node));
        }

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            return base.VisitMethodCall(this.Record(node));
        }

        protected override Expression VisitNew(NewExpression node)
        {
            return base.VisitNew(this.Record(node));
        }

        protected override Expression VisitNewArray(NewArrayExpression node)
        {
            return base.VisitNewArray(this.Record(node));
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return base.VisitParameter(this.Record(node));
        }

        protected override Expression VisitRuntimeVariables(RuntimeVariablesExpression node)
        {
            return base.VisitRuntimeVariables(this.Record(node));
        }

        protected override Expression VisitSwitch(SwitchExpression node)
        {
            return base.VisitSwitch(this.Record(node));
        }

        protected override Expression VisitTry(TryExpression node)
        {
            return base.VisitTry(this.Record(node));
        }

        protected override Expression VisitTypeBinary(TypeBinaryExpression node)
        {
            return base.VisitTypeBinary(this.Record(node));
        }

        protected override Expression VisitUnary(UnaryExpression node)
        {
            return base.VisitUnary(this.Record(node));
        }
    }
}

Solution 3:

I tried a simple (not complete) implementation for mutating the expression p => p.Id == 15 (the code is below). There is one class named "CrossMapping", which defines the mapping between the original and the "new" types and type members.

There are several methods named Mutate_XY_Expression, one for every expression type, each of which builds a new, mutated expression. Each method takes the original expression (e.g. MemberExpression originalExpression) as the model for the new expression, the list of parameter expressions (IList<ParameterExpression> parameterExpressions), which are the parameters defined by the "parent" expression and should be used by the "parent's" body, and the mapping object (CrossMapping mapping), which defines the mapping between types and members.

For a full implementation you may need more information from the parent expression than just its parameters, but the pattern should be the same.

The sample does not implement the Visitor pattern, as you can see - that is for the sake of simplicity. But there is nothing preventing you from converting it to one.

I hope it helps.

The code (C# 4.0):

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Linq.Expressions;

namespace ConsoleApplication1 {
    /// <summary>
    /// Sample source type: the demo predicate is originally written against it.
    /// </summary>
    public class Product1 {
        public int Id { get; set; }
        public string Name { get; set; }
        public decimal Weight { get; set; }
    }

    /// <summary>
    /// Sample target type: declares the same properties as <see cref="Product1"/>.
    /// </summary>
    public class Product2 {
        public int Id { get; set; }
        public string Name { get; set; }
        public decimal Weight { get; set; }
    }

    /// <summary>
    /// Demo program that rewrites an expression tree written against one type
    /// into an equivalent tree against another type, driven by a
    /// <see cref="CrossMapping"/>.
    /// </summary>
    class Program {
        /// <summary>
        /// Builds a predicate over <see cref="Product1"/>, mutates it into a
        /// predicate over <see cref="Product2"/>, compiles it and applies it
        /// to a list of <see cref="Product2"/>.
        /// </summary>
        static void Main( string[] args ) {
            // list of products typed as Product1
            var lst1 = new List<Product1> {
                new Product1{ Id = 1, Name = "One" },
                new Product1{ Id = 15, Name = "Fifteen" },
                new Product1{ Id = 9, Name = "Nine" }
            };

            // the expression for filtering products
            // typed as Product1
            Expression<Func<Product1, bool>> q1;
            q1 = p => p.Id == 15;

            // list of products typed as Product2
            var lst2 = new List<Product2> {
                new Product2{ Id = 1, Name = "One" },
                new Product2{ Id = 15, Name = "Fifteen" },
                new Product2{ Id = 9, Name = "Nine" }
            };

            // type of Product1
            var tp1 = typeof( Product1 );
            // property info of "Id" property from type Product1
            var tp1Id = tp1.GetProperty( "Id", BindingFlags.Public | BindingFlags.Instance );
            // delegate type for predicating for Product1
            var tp1FuncBool = typeof( Func<,> ).MakeGenericType( tp1, typeof( bool ) );

            // type of Product2
            var tp2 = typeof( Product2 );
            // property info of "Id" property from type Product2
            var tp21Id = tp2.GetProperty( "Id", BindingFlags.Public | BindingFlags.Instance );
            // delegate type for predicating for Product2
            var tp2FuncBool = typeof( Func<,> ).MakeGenericType( tp2, typeof( bool ) );

            // mapping object for types and type members
            var cm1 = new CrossMapping {
                TypeMapping = {
                    // Product1 -> Product2
                    { tp1, tp2 },
                    // Func<Product1, bool> -> Func<Product2, bool>
                    { tp1FuncBool, tp2FuncBool }
                },
                MemberMapping = {
                    // Product1.Id -> Product2.Id
                    { tp1Id, tp21Id }
                }
            };

            // mutate the expression from Product1's "environment" to Product2's "environment"
            var cq1_2 = MutateExpression( q1, cm1 );

            // compile lambda to delegate
            var dlg1_2 = ((LambdaExpression)cq1_2).Compile();

            // executing delegate (only the product with Id 15 satisfies the predicate)
            var rslt1_2 = lst2.Where( (Func<Product2, bool>)dlg1_2 ).ToList();
        }

        /// <summary>
        /// Describes how types and members of the original expression map to
        /// types and members of the mutated expression.
        /// </summary>
        class CrossMapping {
            /// <summary>Original type -> replacement type.</summary>
            public IDictionary<Type, Type> TypeMapping { get; private set; }
            /// <summary>Original member -> replacement member.</summary>
            public IDictionary<MemberInfo, MemberInfo> MemberMapping { get; private set; }

            public CrossMapping() {
                this.TypeMapping = new Dictionary<Type, Type>();
                this.MemberMapping = new Dictionary<MemberInfo, MemberInfo>();
            }
        }

        /// <summary>
        /// Mutates a root expression, i.e. one with no enclosing parameters.
        /// </summary>
        static Expression MutateExpression( Expression originalExpression, CrossMapping mapping ) {
            return MutateExpression(
                originalExpression: originalExpression,
                parameterExpressions: null,
                mapping: mapping
            );
        }

        /// <summary>
        /// Dispatches to the mutation routine for the concrete kind of
        /// <paramref name="originalExpression"/>.
        /// </summary>
        /// <param name="originalExpression">Expression to mutate; null yields null.</param>
        /// <param name="parameterExpressions">
        /// Already mutated parameters that are in scope (may be null).
        /// </param>
        /// <param name="mapping">Type and member mapping.</param>
        /// <exception cref="NotImplementedException">
        /// The expression kind is not handled by this sample.
        /// </exception>
        static Expression MutateExpression( Expression originalExpression, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpression ) { return null; }

            var lambda = originalExpression as LambdaExpression;
            if ( null != lambda ) { return MutateLambdaExpression( lambda, parameterExpressions, mapping ); }

            var binary = originalExpression as BinaryExpression;
            if ( null != binary ) { return MutateBinaryExpression( binary, parameterExpressions, mapping ); }

            var parameter = originalExpression as ParameterExpression;
            if ( null != parameter ) { return MutateParameterExpression( parameter, parameterExpressions, mapping ); }

            var member = originalExpression as MemberExpression;
            if ( null != member ) { return MutateMemberExpression( member, parameterExpressions, mapping ); }

            var constant = originalExpression as ConstantExpression;
            if ( null != constant ) { return MutateConstantExpression( constant, parameterExpressions, mapping ); }

            throw new NotImplementedException(
                "Mutation of " + originalExpression.NodeType + " expressions is not implemented." );
        }

        /// <summary>
        /// Returns the mapped type, or the original type when no (non-null)
        /// mapping exists.
        /// </summary>
        static Type MutateType( Type originalType, IDictionary<Type, Type> typeMapping ) {
            if ( null == originalType ) { return null; }

            Type ret;
            if ( typeMapping.TryGetValue( originalType, out ret ) && null != ret ) { return ret; }

            return originalType;
        }

        /// <summary>
        /// Returns the mapped member, or the original member when no
        /// (non-null) mapping exists.
        /// </summary>
        static MemberInfo MutateMember( MemberInfo originalMember, IDictionary<MemberInfo, MemberInfo> memberMapping ) {
            if ( null == originalMember ) { return null; }

            MemberInfo ret;
            if ( memberMapping.TryGetValue( originalMember, out ret ) && null != ret ) { return ret; }

            return originalMember;
        }

        /// <summary>
        /// Mutates a lambda: creates its new parameters, mutates the body with
        /// those parameters in scope and rebuilds the lambda with the mapped
        /// delegate type.
        /// </summary>
        static LambdaExpression MutateLambdaExpression( LambdaExpression originalExpression, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpression ) { return null; }

            // A lambda always declares its own, fresh parameters; reusing an
            // enclosing parameter that merely has the same name would bind the
            // wrong variable.
            var newParameters = originalExpression.Parameters
                .Select( p => Expression.Parameter( MutateType( p.Type, mapping.TypeMapping ), p.Name ) )
                .ToArray();

            // Scope visible to the body: own parameters first (so they win the
            // lookup), followed by the parameters of enclosing lambdas so that
            // a nested lambda can still reference them.
            var scope = new List<ParameterExpression>( newParameters );
            if ( null != parameterExpressions ) { scope.AddRange( parameterExpressions ); }

            var newBody = MutateExpression( originalExpression.Body, scope, mapping );

            var newType = MutateType( originalExpression.Type, mapping.TypeMapping );

            return Expression.Lambda(
                delegateType: newType,
                body: newBody,
                name: originalExpression.Name,
                tailCall: originalExpression.TailCall,
                parameters: newParameters
            );
        }

        /// <summary>
        /// Mutates both operands (and the optional conversion lambda and
        /// operator method) and rebuilds the binary node with the same node type.
        /// </summary>
        static BinaryExpression MutateBinaryExpression( BinaryExpression originalExpression, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpression ) { return null; }

            var newConversion = (LambdaExpression)MutateExpression( originalExpression.Conversion, parameterExpressions, mapping );
            var newLeft = MutateExpression( originalExpression.Left, parameterExpressions, mapping );
            var newRight = MutateExpression( originalExpression.Right, parameterExpressions, mapping );
            var newMethod = (MethodInfo)MutateMember( originalExpression.Method, mapping.MemberMapping );

            return Expression.MakeBinary(
                binaryType: originalExpression.NodeType,
                left: newLeft,
                right: newRight,
                liftToNull: originalExpression.IsLiftedToNull,
                method: newMethod,
                conversion: newConversion
            );
        }

        /// <summary>
        /// Resolves a parameter reference to the in-scope mutated parameter
        /// with the same name and mutated type; creates a new parameter when
        /// none is in scope.
        /// </summary>
        static ParameterExpression MutateParameterExpression( ParameterExpression originalExpresion, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpresion ) { return null; }

            var newType = MutateType( originalExpresion.Type, mapping.TypeMapping );

            ParameterExpression ret = null;
            if ( null != parameterExpressions ) {
                // Match on name AND type: the name alone is not sufficient to
                // identify a parameter.
                ret = parameterExpressions.FirstOrDefault(
                    p => p.Name == originalExpresion.Name && p.Type == newType );
            }

            if ( null == ret ) {
                ret = Expression.Parameter( newType, originalExpresion.Name );
            }

            return ret;
        }

        /// <summary>
        /// Mutates a member access. The member comes from the explicit member
        /// mapping; when it is not mapped and no longer applies to the mutated
        /// instance, the same-named member of the mutated instance type is used.
        /// </summary>
        static MemberExpression MutateMemberExpression( MemberExpression originalExpression, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpression ) { return null; }

            var newExpression = MutateExpression( originalExpression.Expression, parameterExpressions, mapping );
            var newMember = MutateMember( originalExpression.Member, mapping.MemberMapping );

            // Fallback for members without an explicit mapping: if the member
            // cannot be applied to the mutated instance, pick the unique public
            // instance member with the same name and kind on the new type. If
            // there is none, MakeMemberAccess reports the mismatch as before.
            if ( null != newExpression
                 && null != newMember.DeclaringType
                 && !newMember.DeclaringType.IsAssignableFrom( newExpression.Type ) ) {
                var candidates = newExpression.Type.GetMember(
                    newMember.Name,
                    newMember.MemberType,
                    BindingFlags.Public | BindingFlags.Instance );
                if ( 1 == candidates.Length ) { newMember = candidates[0]; }
            }

            return Expression.MakeMemberAccess(
                expression: newExpression,
                member: newMember
            );
        }

        /// <summary>
        /// Rebuilds a constant with the same value and the mapped type.
        /// </summary>
        static ConstantExpression MutateConstantExpression( ConstantExpression originalExpression, IList<ParameterExpression> parameterExpressions, CrossMapping mapping ) {
            if ( null == originalExpression ) { return null; }

            var newType = MutateType( originalExpression.Type, mapping.TypeMapping );

            return Expression.Constant(
                value: originalExpression.Value,
                type: newType
            );
        }
    }
}

Solution 4:

Jon's own answer above is great, so I expanded it to handle method calls, constant expressions, etc. so that now it works also for expressions such as:

x => x.SubObjects
      .AsQueryable()
      .SelectMany(y => y.GrandChildObjects)
      .Any(z => z.Value == 3)

I also did away with the ExpressionTreeExplorer, since the only things we need are the ParameterExpressions.

Here's the code (Update: Clear the cache when done converting)

/// <summary>
/// Expression visitor that re-targets an expression tree from one set of
/// types to another: parameters, member accesses and generic method type
/// arguments are converted with the supplied type converter.
/// </summary>
public class ExpressionTargetTypeMutator : ExpressionVisitor
{
    private readonly Func<Type, Type> typeConverter;
    private readonly Dictionary<Expression, Expression> _convertedExpressions 
        = new Dictionary<Expression, Expression>();

    // True while an outermost Visit call is running; lets nested Visit calls
    // share the parameter cache, which is cleared when the outermost call ends.
    private bool _isVisiting;

    /// <summary>Creates a mutator.</summary>
    /// <param name="typeConverter">
    /// Maps a source type to its target type. Member accesses are only
    /// re-bound when the converter returns a type different from its input.
    /// </param>
    public ExpressionTargetTypeMutator(Func<Type, Type> typeConverter)
    {
        this.typeConverter = typeConverter;
    }

    // Clear the ParameterExpression cache between calls to Visit.
    // Not thread safe, but you can probably fix it easily.
    /// <summary>
    /// Visits <paramref name="node"/>. The parameter cache lives for the
    /// duration of the outermost call, so every reference to one original
    /// parameter within a single converted tree maps to the same new instance.
    /// </summary>
    public override Expression Visit(Expression node)
    {
        bool outermostCall = false;
        if (false == _isVisiting)
        {
            this._isVisiting = true;
            outermostCall = true;
        }
        try
        {
            return base.Visit(node);
        }
        finally
        {
            if (outermostCall)
            {
                this._isVisiting = false;
                _convertedExpressions.Clear();
            }
        }
    }


    /// <summary>
    /// Re-binds a member access to the same-named property or field of the
    /// converted type. Accesses on types the converter leaves unchanged
    /// (captured variables, static members, ...) get the default treatment.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The target type has no public property or field with that name.
    /// </exception>
    protected override Expression VisitMember(MemberExpression node)
    {
        var sourceType = node.Member.ReflectedType;
        var targetType = this.typeConverter(sourceType);

        if (targetType == sourceType)
        {
            // Same type: the original member is still valid, only the
            // instance expression may need converting.
            return base.VisitMember(node);
        }

        var targetMember = (System.Reflection.MemberInfo)targetType.GetProperty(node.Member.Name)
            ?? targetType.GetField(node.Member.Name);
        if (null == targetMember)
        {
            throw new InvalidOperationException(
                "Type '" + targetType + "' has no public property or field named '"
                + node.Member.Name + "'.");
        }

        return Expression.MakeMemberAccess(this.Visit(node.Expression), targetMember);
    }

    /// <summary>
    /// Returns the converted parameter for <paramref name="node"/>, creating
    /// and caching it on first use.
    /// </summary>
    protected override Expression VisitParameter(ParameterExpression node)
    {
        Expression converted;
        if (false == _convertedExpressions.TryGetValue(node, out converted))
        {
            var sourceType = node.Type;
            var targetType = this.typeConverter(sourceType);
            converted = Expression.Parameter(targetType, node.Name);
            _convertedExpressions.Add(node, converted);
        }
        return converted;
    }

    /// <summary>
    /// For generic methods, converts the type arguments and re-creates the
    /// call with the converted instance and arguments; other calls get the
    /// default treatment.
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.IsGenericMethod)
        {
            var convertedTypeArguments = node.Method.GetGenericArguments()
                                                    .Select(this.typeConverter)
                                                    .ToArray();
            var genericMethodDefinition = node.Method.GetGenericMethodDefinition();
            var newMethod = genericMethodDefinition.MakeGenericMethod(convertedTypeArguments);

            // node.Object is null for static methods and must be carried over
            // for instance methods, otherwise the call loses its target.
            return Expression.Call(
                this.Visit(node.Object),
                newMethod,
                node.Arguments.Select(argument => this.Visit(argument)));
        }
        return base.VisitMethodCall(node);
    }

    /// <summary>
    /// Converts constants whose value is itself an expression; all other
    /// constants are left to the base implementation.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        var valueExpression = node.Value as Expression;
        if (null != valueExpression)
        {
            return Expression.Constant(this.Visit(valueExpression));
        }
        return base.VisitConstant(node);
    }

    /// <summary>
    /// Rebuilds the lambda from its converted body and converted parameters;
    /// the delegate type is left for Expression.Lambda to determine.
    /// </summary>
    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        return Expression.Lambda(this.Visit(node.Body), node.Name, node.TailCall, node.Parameters.Select(x => (ParameterExpression)this.VisitParameter(x)));
    }
}