Query filter

This scenario demonstrate how to get rows from database with Entity Framework Core's query filter feature. This is for Multi-Tenancy scenarios.

Scenario Prototype

/// <summary>
/// Scenario contract: read tenant-scoped student rows by relying on the ORM's query-filter support
/// rather than an explicit filter in each query.
/// </summary>
/// <typeparam name="TStudent">The student entity type returned by the scenario.</typeparam>
public interface IQueryFilterScenario<TStudent> where TStudent : class, IStudent, new()
{
    /// <summary>Gets the students belonging to the specified school (tenant).</summary>
    /// <param name="schoolId">Identifier of the school whose students are returned.</param>
    /// <returns>The matching students.</returns>
    IList<TStudent> GetStudents(int schoolId);
}

Entity Framework Core

A query filter is setup in the OnModelCreating event of the DB Context.

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
    // Fail fast with a descriptive message rather than a NullReferenceException further down.
    if (modelBuilder is null)
    {
        throw new ArgumentNullException(nameof(modelBuilder), $"{nameof(modelBuilder)} is null.");
    }

    base.OnModelCreating(modelBuilder);

    // Every entity implementing the ISchool marker interface gets the same tenant filter.
    // NOTE(review): SchoolId is presumably a member of this context; EF Core evaluates context
    // members referenced in a query filter per context instance, which is what scopes each
    // context to a single school.
    modelBuilder.SetQueryFilterOnAllEntities<ISchool>(school => school.SchoolId == SchoolId);
}

Normally this would be done for each table. But by placing a marker interface (ISchool in this case) and some helper methods, it can be automatically applied to all tables.

// Copyright (c) 2020 Phil Haack, GitHub: haacked
// https://gist.github.com/haacked/febe9e88354fb2f4a4eb11ba88d64c24

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Recipes.EntityFrameworkCore.QueryFilter.Helpers
{
    public static class ModelBuilderExtensions
    {
        // Open generic definition of SetQueryFilter<TEntity, TEntityInterface>. It is closed over each concrete
        // entity type at runtime, because those CLR types are only known from the model metadata.
        private static readonly MethodInfo SetQueryFilterMethod = typeof(ModelBuilderExtensions)
            .GetMethods(BindingFlags.NonPublic | BindingFlags.Static)
            .Single(t => t.IsGenericMethod && t.Name == nameof(SetQueryFilter));

        /// <summary>
        /// Applies <paramref name="filterExpression"/> as a query filter to every root entity type in the model
        /// whose CLR type implements <typeparamref name="TEntityInterface"/>.
        /// </summary>
        /// <typeparam name="TEntityInterface">Marker interface identifying the entities to filter.</typeparam>
        /// <param name="modelBuilder">The model builder being configured.</param>
        /// <param name="filterExpression">The filter, written against the marker interface.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="modelBuilder"/> or <paramref name="filterExpression"/> is null.
        /// </exception>
        public static void SetQueryFilterOnAllEntities<TEntityInterface>(this ModelBuilder modelBuilder,
            Expression<Func<TEntityInterface, bool>> filterExpression)
        {
            if (modelBuilder == null)
                throw new ArgumentNullException(nameof(modelBuilder), $"{nameof(modelBuilder)} is null.");
            if (filterExpression == null)
                throw new ArgumentNullException(nameof(filterExpression), $"{nameof(filterExpression)} is null.");

            // Query filters must be declared on the root type of an inheritance hierarchy, so derived types are
            // skipped. The list is materialized first so configuring entities cannot disturb the enumeration.
            var filteredTypes = modelBuilder.Model.GetEntityTypes()
                .Where(t => t.BaseType == null)
                .Select(t => t.ClrType)
                .Where(t => typeof(TEntityInterface).IsAssignableFrom(t))
                .ToList();

            foreach (var type in filteredTypes)
            {
                modelBuilder.SetEntityQueryFilter<TEntityInterface>(type, filterExpression);
            }
        }

        // Invokes SetQueryFilter<entityType, TEntityInterface> via reflection, since entityType is a runtime Type.
        private static void SetEntityQueryFilter<TEntityInterface>(this ModelBuilder builder,
            Type entityType, Expression<Func<TEntityInterface, bool>> filterExpression)
        {
            SetQueryFilterMethod
                .MakeGenericMethod(entityType, typeof(TEntityInterface))
                .Invoke(null, new object[] { builder, filterExpression });
        }

        // Retypes the interface-based filter to the concrete entity type and adds it to that entity.
        private static void SetQueryFilter<TEntity, TEntityInterface>(this ModelBuilder builder,
            Expression<Func<TEntityInterface, bool>> filterExpression)
            where TEntityInterface : class
            where TEntity : class, TEntityInterface
        {
            var concreteExpression = filterExpression.Convert<TEntityInterface, TEntity>();
            builder.Entity<TEntity>().AddQueryFilter(concreteExpression);
        }

        // Adds expression to the entity's query filter. HasQueryFilter replaces any previous filter, so an
        // existing filter is AND-combined with the new one before being set.
        private static void AddQueryFilter<T>(this EntityTypeBuilder entityTypeBuilder, Expression<Func<T, bool>> expression)
        {
            // One shared parameter so both the existing and the new filter bodies bind to the same lambda.
            var parameter = Expression.Parameter(entityTypeBuilder.Metadata.ClrType);
            var expressionFilter = ReplacingExpressionVisitor.Replace(
                expression.Parameters.Single(), parameter, expression.Body);

            // Read the current filter through the public metadata API. The previous reflection into EF Core's
            // private "Builder" property silently yielded null if that internal member changed, which caused
            // an existing filter to be overwritten instead of combined.
            var existingFilter = entityTypeBuilder.Metadata.GetQueryFilter();
            if (existingFilter != null)
            {
                var existingBody = ReplacingExpressionVisitor.Replace(
                    existingFilter.Parameters.Single(), parameter, existingFilter.Body);
                expressionFilter = Expression.AndAlso(existingBody, expressionFilter);
            }

            entityTypeBuilder.HasQueryFilter(Expression.Lambda(expressionFilter, parameter));
        }
    }
}

// Copyright (c) 2020 Phil Haack, GitHub: haacked
// https://gist.github.com/haacked/febe9e88354fb2f4a4eb11ba88d64c24

using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;

namespace Recipes.EntityFrameworkCore.QueryFilter.Helpers
{
    public static class ExpressionExtensions
    {
        /// <summary>
        /// Rewrites a predicate written against <typeparamref name="TSource"/> into an equivalent predicate over
        /// <typeparamref name="TTarget"/> by substituting the lambda's parameter.
        /// </summary>
        /// <remarks>
        /// Only the root lambda's own parameter is replaced, and it is matched by reference rather than by name.
        /// Nested lambdas such as <c>s =&gt; s.Items.Any(i =&gt; ...)</c> therefore keep their own parameters and
        /// signatures. Every member accessed on the parameter must be valid for <typeparamref name="TTarget"/>
        /// (typically TTarget implements TSource); otherwise building the expression throws an ArgumentException.
        /// </remarks>
        /// <param name="root">The predicate to convert.</param>
        /// <returns>A new lambda whose single parameter has type <typeparamref name="TTarget"/> and the same name.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="root"/> is null.</exception>
        public static Expression<Func<TTarget, bool>> Convert<TSource, TTarget>(this Expression<Func<TSource, bool>> root)
        {
            if (root == null)
                throw new ArgumentNullException(nameof(root), $"{nameof(root)} is null.");

            var sourceParameter = root.Parameters[0];
            var targetParameter = Expression.Parameter(typeof(TTarget), sourceParameter.Name);
            var body = new ParameterReplacingVisitor(sourceParameter, targetParameter).Visit(root.Body);
            return Expression.Lambda<Func<TTarget, bool>>(body, targetParameter);
        }

        // Replaces one specific ParameterExpression instance with another and leaves all other parameters
        // (including those of nested lambdas) untouched.
        private sealed class ParameterReplacingVisitor : ExpressionVisitor
        {
            private readonly ParameterExpression _source;
            private readonly ParameterExpression _target;

            public ParameterReplacingVisitor(ParameterExpression source, ParameterExpression target)
            {
                _source = source;
                _target = target;
            }

            protected override Expression VisitParameter(ParameterExpression node)
                => node == _source ? _target : base.VisitParameter(node);
        }
    }
}

The following example shows the query filter in use.

/// <summary>
/// Reads students for a single school through a context whose query filter scopes queries to that school,
/// so no explicit SchoolId predicate is written here.
/// </summary>
public class QueryFilterScenario : IQueryFilterScenario<Student>
{
    // Creates a context for the given school id; the id is what the context's query filter uses.
    private readonly Func<int, OrmCookbookContextWithQueryFilter> _createFilteredDbContext;

    /// <summary>Creates the scenario.</summary>
    /// <param name="dBContextFactory">Factory that creates a context for a given school id.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dBContextFactory"/> is null.</exception>
    public QueryFilterScenario(Func<int, OrmCookbookContextWithQueryFilter> dBContextFactory)
    {
        // Fail at construction rather than with a NullReferenceException on the first query.
        _createFilteredDbContext = dBContextFactory
            ?? throw new ArgumentNullException(nameof(dBContextFactory), $"{nameof(dBContextFactory)} is null.");
    }

    /// <summary>Gets the students of <paramref name="schoolId"/>, ordered by name.</summary>
    /// <param name="schoolId">The school (tenant) to read.</param>
    /// <returns>The students visible to a context created for that school.</returns>
    public IList<Student> GetStudents(int schoolId)
    {
        using var context = _createFilteredDbContext(schoolId);

        // SchoolId filter is automatically applied
        return context.Students.OrderBy(s => s.Name).ToList();
    }
}