【问题标题】:How to wrap Entity Framework to intercept the LINQ expression just before execution?如何包装实体框架以在执行前拦截 LINQ 表达式?
【发布时间】:2009-12-03 13:33:58
【问题描述】:

我想在执行前重写 LINQ 表达式的某些部分。而且我在将我的重写器注入正确的位置时遇到了问题(实际上)。

查看实体框架源(在反射器中),它最终归结为 IQueryProvider.Execute,在 EF 中,ObjectContext 提供 internal IQueryProvider Provider { get; } 属性与表达式耦合。

所以我创建了一个包装类(实现IQueryProvider)来在调用 Execute 时执行表达式重写,然后将其传递给原始 Provider。

问题是,Provider 后面的字段是private ObjectQueryProvider _queryProvider;。这个ObjectQueryProvider 是一个内部密封类,这意味着不可能创建一个提供额外重写的子类。

因此,由于 ObjectContext 耦合非常紧密,这种方法让我陷入了死胡同。

如何解决这个问题?我看错方向了吗?有没有办法让自己围绕这个ObjectQueryProvider 注入?

更新:虽然提供的解决方案在您使用存储库模式“包装”ObjectContext 时都可以工作,但允许直接使用从 ObjectContext 生成的子类的解决方案会更好。特此保持与动态数据脚手架的兼容。

【问题讨论】:

    标签: c# linq entity-framework expression-trees


    【解决方案1】:

    根据 Arthur 的回答,我创建了一个工作包装器。

    提供的 sn-ps 提供了一种使用您自己的 QueryProvider 和 IQueryable 根包装每个 LINQ 查询的方法。这意味着您必须控制初始查询的开始(因为您大部分时间都使用任何类型的模式)。

    这种方法的问题是它不透明,更理想的情况是在构造函数级别的实体容器中注入一些东西。

    我创建了一个可编译的实现,让它与实体框架一起工作,并添加了对 ObjectQuery.Include 方法的支持。表达式访问者类可以从MSDN复制。

    public class QueryTranslator<T> : IOrderedQueryable<T>
    {
        private Expression expression = null;
        private QueryTranslatorProvider<T> provider = null;
    
        public QueryTranslator(IQueryable source)
        {
            expression = Expression.Constant(this);
            provider = new QueryTranslatorProvider<T>(source);
        }
    
        public QueryTranslator(IQueryable source, Expression e)
        {
            if (e == null) throw new ArgumentNullException("e");
            expression = e;
            provider = new QueryTranslatorProvider<T>(source);
        }
    
        public IEnumerator<T> GetEnumerator()
        {
            return ((IEnumerable<T>)provider.ExecuteEnumerable(this.expression)).GetEnumerator();
        }
    
        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return provider.ExecuteEnumerable(this.expression).GetEnumerator();
        }
    
        public QueryTranslator<T> Include(String path)
        {
            ObjectQuery<T> possibleObjectQuery = provider.source as ObjectQuery<T>;
            if (possibleObjectQuery != null)
            {
                return new QueryTranslator<T>(possibleObjectQuery.Include(path));
            }
            else
            {
                throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
            }
        }
    
        public Type ElementType
        {
            get { return typeof(T); }
        }
    
        public Expression Expression
        {
            get { return expression; }
        }
    
        public IQueryProvider Provider
        {
            get { return provider; }
        }
    }
    
    public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
    {
        internal IQueryable source;
    
        public QueryTranslatorProvider(IQueryable source)
        {
            if (source == null) throw new ArgumentNullException("source");
            this.source = source;
        }
    
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            return new QueryTranslator<TElement>(source, expression) as IQueryable<TElement>;
        }
    
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            Type elementType = expression.Type.GetGenericArguments().First();
            IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
                new object[] { source, expression });
            return result;
        }
    
        public TResult Execute<TResult>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            object result = (this as IQueryProvider).Execute(expression);
            return (TResult)result;
        }
    
        public object Execute(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            Expression translated = this.Visit(expression);
            return source.Provider.Execute(translated);
        }
    
        internal IEnumerable ExecuteEnumerable(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            Expression translated = this.Visit(expression);
            return source.Provider.CreateQuery(translated);
        }
    
        #region Visitors
        protected override Expression VisitConstant(ConstantExpression c)
        {
            // fix up the Expression tree to work with EF again
            if (c.Type == typeof(QueryTranslator<T>))
            {
                return source.Expression;
            }
            else
            {
                return base.VisitConstant(c);
            }
        }
        #endregion
    }
    

    您的存储库中的示例用法:

    public IQueryable<User> List()
    {
        return new QueryTranslator<User>(entities.Users).Include("Department");
    }
    

    【讨论】:

    • 您现在拥有所需的一切吗?我应该提供更多帮助方法还是在我的代码中查找一些东西?
    • 不,我让它工作了,但奇怪的是你遗漏了将其修复回 EF 查询的部分。
    【解决方案2】:

    我有你需要的源代码 - 但不知道如何附加文件。

    这里有一些sn-ps(sn-ps!我不得不修改这段代码,所以它可能无法编译):

    可查询:

    public class QueryTranslator<T> : IOrderedQueryable<T>
    {
        private Expression _expression = null;
        private QueryTranslatorProvider<T> _provider = null;
    
        public QueryTranslator(IQueryable source)
        {
            _expression = Expression.Constant(this);
            _provider = new QueryTranslatorProvider<T>(source);
        }
    
        public QueryTranslator(IQueryable source, Expression e)
        {
            if (e == null) throw new ArgumentNullException("e");
            _expression = e;
            _provider = new QueryTranslatorProvider<T>(source);
        }
    
        public IEnumerator<T> GetEnumerator()
        {
            return ((IEnumerable<T>)_provider.ExecuteEnumerable(this._expression)).GetEnumerator();
        }
    
        IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return _provider.ExecuteEnumerable(this._expression).GetEnumerator();
        }
    
        public Type ElementType
        {
            get { return typeof(T); }
        }
    
        public Expression Expression
        {
            get { return _expression; }
        }
    
        public IQueryProvider Provider
        {
            get { return _provider; }
        }
    }
    

    IQueryProvider:

    public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
    {
        IQueryable _source;
    
        public QueryTranslatorProvider(IQueryable source)
        {
            if (source == null) throw new ArgumentNullException("source");
            _source = source;
        }
    
        #region IQueryProvider Members
    
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            return new QueryTranslator<TElement>(_source, expression) as IQueryable<TElement>;
        }
    
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            Type elementType = expression.Type.FindElementTypes().First();
            IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
                new object[] { _source, expression });
            return result;
        }
    
        public TResult Execute<TResult>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            object result = (this as IQueryProvider).Execute(expression);
            return (TResult)result;
        }
    
        public object Execute(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            Expression translated = this.Visit(expression);
    
            return _source.Provider.Execute(translated);            
        }
    
        internal IEnumerable ExecuteEnumerable(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
    
            Expression translated = this.Visit(expression);
    
            return _source.Provider.CreateQuery(translated);
        }
    
        #endregion        
    
        #region Visits
        protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
        {
            return m;
        }
    
        protected override Expression VisitUnary(UnaryExpression u)
        {
             return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
        }
        #endregion
    }
    

    用法(警告:改编代码!可能无法编译):

    private Dictionary<Type, object> _table = new Dictionary<Type, object>();
    public override IQueryable<T> GetObjectQuery<T>()
    {
        if (!_table.ContainsKey(type))
        {
            _table[type] = new QueryTranslator<T>(
                _ctx.CreateQuery<T>("[" + typeof(T).Name + "]"));
        }
    
        return (IQueryable<T>)_table[type];
    }
    

    表达访问者/翻译者:

    http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

    http://msdn.microsoft.com/en-us/library/bb882521.aspx

    编辑:添加 FindElementTypes()。希望现在所有方法都存在。

        /// <summary>
        /// Finds all implemented IEnumerables of the given Type
        /// </summary>
        public static IQueryable<Type> FindIEnumerables(this Type seqType)
        {
            if (seqType == null || seqType == typeof(object) || seqType == typeof(string))
                return new Type[] { }.AsQueryable();
    
            if (seqType.IsArray || seqType == typeof(IEnumerable))
                return new Type[] { typeof(IEnumerable) }.AsQueryable();
    
            if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            {
                return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable();
            }
    
            var result = new List<Type>();
    
            foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { }))
            {
                result.AddRange(FindIEnumerables(iface));
            }
    
            return FindIEnumerables(seqType.BaseType).Union(result);
        }
    
        /// <summary>
        /// Finds all element types provided by a specified sequence type.
        /// "Element types" are T for IEnumerable&lt;T&gt; and object for IEnumerable.
        /// </summary>
        public static IQueryable<Type> FindElementTypes(this Type seqType)
        {
            return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object));
        }
    

    【讨论】:

    • 如果我没记错的话,我必须从生成的 ObjectContext 子类中更改 Set 以使用此包装器而不是 base.CreateQuery 调用?这不是一个很好的解决方案,因为重新生成会破坏我的更改?还是我误解了您的用法示例?
    • 嗨,你能提供“ExpressionTreeTranslator”吗?我猜这是表达式树访问者模式的实现?
    • @First 评论:对,你包装了 CreateQuery 调用。我有自己的发电机,所以我没有麻烦。我也有一个自己的通用 GetQuery 方法,它创建正确的 EF 查询并包装它。我会发布那个方法。 @Second:您可以在此处找到查询翻译器:msdn.microsoft.com/en-us/library/bb882521.aspx 或此处blogs.msdn.com/mattwar/archive/2007/07/31/…
    • 好的,所以这只是这些类的重命名,这就是我的想法。但是你能提供名为 FindElementTypes() 的扩展方法吗?使用谷歌也找不到那个。
    • 很抱歉,但我无法让它真正起作用,EF 提供者不喜欢带有 QueryTranslater 的查询。 -- System.NotSupportedException:无法创建类型为“QueryTranslator`1”的常量值。在此上下文中仅支持原始类型(“例如 Int32、String 和 Guid”)..
    【解决方案3】:

    只是想添加到 Arthur 的示例中。

    正如 Arthur 警告的那样,他的 GetObjectQuery() 方法中的一个错误。

    它使用 typeof(T).Name 作为 EntitySet 的名称创建基本查询。

    EntitySet 名称与类型名称完全不同。

    如果您使用的是 EF 4,您应该这样做:

    public override IQueryable<T> GetObjectQuery<T>()
    {
        if (!_table.ContainsKey(type))
        {
            _table[type] = new QueryTranslator<T>(
                _ctx.CreateObjectSet<T>();
        }
    
        return (IQueryable<T>)_table[type];
    }
    

    只要您没有每个类型的多个实体集 (MEST),这非常有效。

    如果您使用的是 3.5,您可以使用我的 Tip 13 中的代码来获取 EntitySet 名称并将其输入如下:

    public override IQueryable<T> GetObjectQuery<T>()
    {
        if (!_table.ContainsKey(type))
        {
            _table[type] = new QueryTranslator<T>(
                _ctx.CreateQuery<T>("[" + GetEntitySetName<T>() + "]"));
    
        } 
        return (IQueryable<T>)_table[type];
    }
    

    希望对你有帮助

    亚历克斯

    Entity Framework Tips

    【讨论】:

    • tnx 修复错误 ;) 但也许更改非 url 短服务的链接以保持 StackOverflow 数据库清洁?
    • 当然...如果我的博客有良好的推荐统计功能。
    • 哦,我能理解。我不知道 blogs.msdn.com 上是否允许使用 Google Analytics,但我认为它们可以很好地概述您的访问者。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-12-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多