Chasing DevOps

A blog about software development, DevOps, and delivering value.

Extending SQL Generation in Entity Framework Core

I’ve been working on a project where we’re migrating a data access layer from an old ORM to Entity Framework Core. The old ORM has some features that EF Core doesn’t support, so I was tasked with finding out whether the SQL generation in Entity Framework Core’s SQL Server provider could be extended to support what we needed. It turns out it can, but nobody seems to have posted details on how to do it.

We needed an extension method that could change the SQL EF Core generates for a query. Queries against the DbContext would look something like this:

context.Foo.Where(f => f.Bar == "Test").WithSqlTweaks().ToList()

Here, WithSqlTweaks() modifies the SQL that EF Core generates in some way. Apart from a vaguely worded GitHub issue, there really wasn’t much guidance on how to do this. I ended up cloning the source code, building a simple console app to experiment with, and stepping through the EF Core internals to figure out how to extend it. The good news is that it can be done without building a custom version of EF Core. The bad news? Well… you’ll see.

Step 0: Don’t follow this guide!

Seriously. Don’t. Run away.

This is a bad idea. To do what I’m trying to do, you have to extend a ton of classes that were never meant to be extended. On practically every public method, the EF Core team has left a nice warning for prying eyes:

This API supports the Entity Framework Core infrastructure and is not intended to be used directly from your code. This API may change or be removed in future releases.

That’s the developer equivalent of “Here be dragons.” This will be a maintenance nightmare: your code will almost certainly break every time a new version of EF Core is released. In fact, the specific tweaks I made work fine with the latest release (2.0.1), but they’re already broken in the dev branch. And that release is only three months old.

There is almost certainly a better way to accomplish what you need. But if you are absolutely convinced this is necessary (it’s not), read on.

Step 1: Building the Extension Method

The first thing you need is an IQueryable extension method (WithSqlTweaks() in my example above). This method doesn’t contain any logic that manipulates the query. Calling it simply places an annotation on the query, which we can check for later when the SQL is generated and then make changes as necessary. Four things are needed to get the extension method to work:

  • The IQueryable extension method.
  • A re-linq result operator.
  • A re-linq result operator expression node.
  • A node type registry factory.

It sounds like a lot, but the good news is that these classes have almost no logic in them. I copied most of the implementation from the classes that enable another extension method (Include(), in this case) and just changed the names. They contain a lot of re-linq and expression-tree plumbing, but that’s beyond the scope of this post. Here is the static class that contains the extension method, plus a static field holding the extension method’s MethodInfo, which will be used later on:

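using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Query.Internal; // EntityQueryProvider

public static class IQueryableExtensions
{
    // The open generic method definition; the node type registry works with this form.
    internal static readonly MethodInfo WithSqlTweaksMethodInfo
        = typeof(IQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(WithSqlTweaks));

    public static IQueryable<TEntity> WithSqlTweaks<TEntity>(this IQueryable<TEntity> source)
        where TEntity : class
    {
        // Only EF Core queries get the marker call; other providers get the source back unchanged.
        return
            source.Provider is EntityQueryProvider
                ? source.Provider.CreateQuery<TEntity>(
                    Expression.Call(
                        instance: null,
                        method: WithSqlTweaksMethodInfo.MakeGenericMethod(typeof(TEntity)),
                        arguments: source.Expression))
                : source;
    }
}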
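The extension method carries no query logic; all it does is leave a recognizable method call in the expression tree. The following stdlib-only sketch (no EF Core or re-linq involved) reproduces that pattern in isolation. `MarkerExtensions`, `WithMarker` and `MarkerFinder` are hypothetical names, and the `ExpressionVisitor` only stands in for the detection step: EF Core actually recognizes the call through the re-linq pieces covered in the rest of this step.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class MarkerExtensions
{
    // Hypothetical: open generic definition, same trick as WithSqlTweaksMethodInfo.
    internal static readonly MethodInfo WithMarkerMethodInfo =
        typeof(MarkerExtensions).GetTypeInfo().GetDeclaredMethod(nameof(WithMarker));

    // Hypothetical marker method: it only records itself in the expression tree.
    public static IQueryable<T> WithMarker<T>(this IQueryable<T> source) =>
        source.Provider.CreateQuery<T>(
            Expression.Call(WithMarkerMethodInfo.MakeGenericMethod(typeof(T)), source.Expression));
}

// Hypothetical visitor that reports whether the marker call appears anywhere in a tree.
public sealed class MarkerFinder : ExpressionVisitor
{
    public bool Found { get; private set; }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Closed calls like WithMarker<string> are compared via their generic definition.
        if (node.Method.IsGenericMethod &&
            node.Method.GetGenericMethodDefinition() == MarkerExtensions.WithMarkerMethodInfo)
        {
            Found = true;
        }
        return base.VisitMethodCall(node);
    }

    public static bool HasMarker(IQueryable query)
    {
        var finder = new MarkerFinder();
        finder.Visit(query.Expression);
        return finder.Found;
    }
}

public static class Program
{
    public static void Main()
    {
        // EnumerableQuery builds query trees in memory; nothing is executed below.
        IQueryable<string> names = new[] { "Test", "Other" }.AsQueryable();

        var plain = names.Where(n => n == "Test");
        var marked = names.Where(n => n == "Test").WithMarker();

        Console.WriteLine(MarkerFinder.HasMarker(plain));  // False
        Console.WriteLine(MarkerFinder.HasMarker(marked)); // True
    }
}
```

Building the query only assembles an expression tree, so the in-memory provider never has to understand `WithMarker`; the marker simply rides along until something inspects the tree.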

Here’s the result operator. This is once again a copy/paste with no custom logic. Notice it implements the IQueryAnnotation interface. This is the annotation we’ll look for in future steps when we actually generate the SQL.

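using System;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.ResultOperators; // IQueryAnnotation
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Clauses.StreamedData;

internal class WithSqlTweaksResultOperator : SequenceTypePreservingResultOperatorBase, IQueryAnnotation
{
    public IQuerySource QuerySource { get; set; }
    public QueryModel QueryModel { get; set; }

    public override ResultOperatorBase Clone(CloneContext cloneContext)
        => new WithSqlTweaksResultOperator();

    // Pass-through: the operator never changes the results, it only marks the query.
    public override StreamedSequence ExecuteInMemory<T>(StreamedSequence input) => input;

    public override void TransformExpressions(Func<Expression, Expression> transformation)
    {
    }
}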

The expression node registers the result operator with EF Core’s LINQ provider. Its SupportedMethods field holds the MethodInfo for our extension method above, which is what gets registered with the LINQ provider. The CreateResultOperator() method returns an instance of our result operator.

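using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using Remotion.Linq.Clauses;
using Remotion.Linq.Parsing.Structure.IntermediateModel;

internal class WithSqlTweaksExpressionNode : ResultOperatorExpressionNodeBase
{
    // The methods this node type handles; registered by the node type registry factory.
    public static readonly IReadOnlyCollection<MethodInfo> SupportedMethods = new[]
    {
        IQueryableExtensions.WithSqlTweaksMethodInfo
    };

    public WithSqlTweaksExpressionNode(MethodCallExpressionParseInfo parseInfo)
        : base(parseInfo, null, null) // no predicate, no selector
    {
    }

    protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext)
        => new WithSqlTweaksResultOperator();

    // The operator doesn't change the shape of the data, so resolution is delegated to the source node.
    public override Expression Resolve(
        ParameterExpression inputParameter,
        Expression expressionToBeResolved,
        ClauseGenerationContext clauseGenerationContext)
        => Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext);
}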

The node type registry factory is what registers the standard query operators, such as Distinct() and OrderBy(). The SQL Server provider uses one called DefaultMethodInfoBasedNodeTypeRegistryFactory, so that’s the class you need to extend. I overrode its Create() method to register my extension method with the base class’s RegisterMethods() before calling the base implementation.

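using Microsoft.EntityFrameworkCore.Query.Internal; // DefaultMethodInfoBasedNodeTypeRegistryFactory
using Remotion.Linq.Parsing.Structure;              // INodeTypeProvider

internal class CustomMethodInfoBasedNodeTypeRegistryFactory : DefaultMethodInfoBasedNodeTypeRegistryFactory
{
    public override INodeTypeProvider Create()
    {
        // Register WithSqlTweaks() alongside the built-in operators, then let the base class build the provider.
        RegisterMethods(WithSqlTweaksExpressionNode.SupportedMethods, typeof(WithSqlTweaksExpressionNode));
        return base.Create();
    }
}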
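Conceptually, the registry is a lookup from method definitions to the node type that should handle them, which is why WithSqlTweaksMethodInfo is stored as the open generic method. The sketch below is a hypothetical, stdlib-only stand-in for that idea; `MiniNodeTypeRegistry` and `DistinctNode` are invented for illustration and are not re-linq’s actual types. It registers a generic method definition and resolves a closed generic call found in a real query tree back to that definition.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical stand-in for a MethodInfo-based node type registry.
public sealed class MiniNodeTypeRegistry
{
    private readonly Dictionary<MethodInfo, Type> _nodeTypes = new Dictionary<MethodInfo, Type>();

    // Same shape as the RegisterMethods(methods, nodeType) call above.
    public void RegisterMethods(IEnumerable<MethodInfo> methods, Type nodeType)
    {
        foreach (var method in methods)
        {
            _nodeTypes[method] = nodeType; // in this sketch, later registrations overwrite earlier ones
        }
    }

    // Closed generic calls such as Distinct<int> are normalized to their definition Distinct<>.
    public Type GetNodeType(MethodInfo method)
    {
        var key = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method;
        return _nodeTypes.TryGetValue(key, out var nodeType) ? nodeType : null;
    }
}

// Hypothetical node type used only as a registration target.
public sealed class DistinctNode { }

public static class Program
{
    public static void Main()
    {
        // Queryable.Distinct<T>(IQueryable<T>) as an open generic method definition.
        MethodInfo distinctDefinition = typeof(Queryable).GetMethods()
            .Single(m => m.Name == nameof(Queryable.Distinct) && m.GetParameters().Length == 1);

        var registry = new MiniNodeTypeRegistry();
        registry.RegisterMethods(new[] { distinctDefinition }, typeof(DistinctNode));

        // Build a query tree and pull out the closed Distinct<int> call.
        var call = (MethodCallExpression)new[] { 1, 1, 2 }.AsQueryable().Distinct().Expression;

        Console.WriteLine(registry.GetNodeType(call.Method)?.Name ?? "(none)"); // DistinctNode
    }
}
```

Registering the closed method (for example WithSqlTweaks<Person>) instead of the definition would only ever match that one entity type, so keying on the definition is what lets a single registration cover every query.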

After that, we just need to replace the node type registry factory service with our custom implementation so the LINQ provider knows how to handle our extension method. This is done with the ReplaceService() method on the DbContextOptionsBuilder.

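using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query.Internal; // INodeTypeProviderFactory

public class MyDbContext : DbContext
{
    public DbSet<Person> People { get; set; }

    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        optionsBuilder
            .UseSqlServer(connectionString) // connectionString: your SQL Server connection string
            .ReplaceService<INodeTypeProviderFactory, CustomMethodInfoBasedNodeTypeRegistryFactory>();
        base.OnConfiguring(optionsBuilder);
    }
}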

And that’s it. The extension method now works and applies the query annotation to any query it’s called on. The next step is to use that annotation to customize the SQL generation.

Step 2: Consume the Query Annotation

The next step is to extend the SelectExpression class. A SelectExpression is passed into the method that creates the SQL query, and its constructors receive the RelationalQueryCompilationContext, which exposes the list of IQueryAnnotations that apply to the query. I extended SelectExpression and added a boolean property that says whether to apply the tweaks during SQL generation. A private method called from both constructors searches for the query annotation; if it’s there, it sets the UseSqlTweaks property to true.

internal class CustomSelectExpression : SelectExpression
{
    public bool UseSqlTweaks { get; set; }

    public CustomSelectExpression(
        SelectExpressionDependencies dependencies,
        RelationalQueryCompilationContext queryCompilationContext)
        : base(dependencies, queryCompilationContext)
    {
        SetCustomSelectExpressionProperties(queryCompilationContext);
    }

    public CustomSelectExpression(
        SelectExpressionDependencies dependencies,
        RelationalQueryCompilationContext queryCompilationContext,
        string alias)
        : base(dependencies, queryCompilationContext, alias)
    {
        SetCustomSelectExpressionProperties(queryCompilationContext);
    }

    private void SetCustomSelectExpressionProperties(RelationalQueryCompilationContext queryCompilationContext)
    {
        // If the WithSqlTweaksResultOperator query annotation exists then set the property to true.
        if (queryCompilationContext.QueryAnnotations.Any(a => a.GetType() == typeof(WithSqlTweaksResultOperator)))
        {
            UseSqlTweaks = true;
        }
    }
}
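The exact-type comparison in SetCustomSelectExpressionProperties() works because nothing derives from WithSqlTweaksResultOperator. A minimal stdlib-only sketch of the same lookup, using hypothetical stand-in types (`IAnnotation`, `TweaksAnnotation`, `DerivedTweaksAnnotation`, `OtherAnnotation`) rather than EF Core’s, contrasts it with `OfType<T>()`, which also matches subclasses:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-ins for IQueryAnnotation and the result operator.
public interface IAnnotation { }
public class TweaksAnnotation : IAnnotation { }
public class DerivedTweaksAnnotation : TweaksAnnotation { }
public class OtherAnnotation : IAnnotation { }

public static class AnnotationLookup
{
    // Exact-type check, as in CustomSelectExpression: subclasses do not match.
    public static bool HasExact(IEnumerable<IAnnotation> annotations) =>
        annotations.Any(a => a.GetType() == typeof(TweaksAnnotation));

    // OfType<T>() matches TweaksAnnotation and anything derived from it.
    public static bool HasAssignable(IEnumerable<IAnnotation> annotations) =>
        annotations.OfType<TweaksAnnotation>().Any();
}

public static class Program
{
    public static void Main()
    {
        var direct = new IAnnotation[] { new OtherAnnotation(), new TweaksAnnotation() };
        var derived = new IAnnotation[] { new DerivedTweaksAnnotation() };

        Console.WriteLine(AnnotationLookup.HasExact(direct));       // True
        Console.WriteLine(AnnotationLookup.HasAssignable(direct));  // True
        Console.WriteLine(AnnotationLookup.HasExact(derived));      // False
        Console.WriteLine(AnnotationLookup.HasAssignable(derived)); // True
    }
}
```

For an internal, unsubclassed operator the two checks behave the same; `OfType<T>()` is simply the more forgiving choice if the operator ever gains a subclass.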

To get EF Core to use CustomSelectExpression instead of the default SelectExpression, I extended SelectExpressionFactory and replaced the ISelectExpressionFactory service on the DbContext.

internal class CustomSelectExpressionFactory : SelectExpressionFactory
{
    public CustomSelectExpressionFactory(SelectExpressionDependencies dependencies)
        : base(dependencies)
    {
    }

    // Both overloads return the custom subclass so the SQL generator receives it later.
    public override SelectExpression Create(RelationalQueryCompilationContext queryCompilationContext)
        => new CustomSelectExpression(Dependencies, queryCompilationContext);

    public override SelectExpression Create(RelationalQueryCompilationContext queryCompilationContext, string alias)
        => new CustomSelectExpression(Dependencies, queryCompilationContext, alias);
}

public class MyDbContext : DbContext
{
    public DbSet<Person> People { get; set; }

    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        optionsBuilder
            .UseSqlServer(connectionString)
            .ReplaceService<INodeTypeProviderFactory, CustomMethodInfoBasedNodeTypeRegistryFactory>()
            .ReplaceService<ISelectExpressionFactory, CustomSelectExpressionFactory>();
        base.OnConfiguring(optionsBuilder);
    }
}

Now the annotation is there and it’s being passed down to the SQL generator through a CustomSelectExpression. The next and final step is to override the SQL generator.

Step 3: Customize SQL Generation using the CustomSelectExpression

The last step is to extend SqlServerQuerySqlGenerator and override the VisitSelect() method. Just check the property on the CustomSelectExpression and customize the SQL as needed. I believe this is the most fragile part: it’s what broke my changes between the latest release and the latest development code.

internal class CustomSqlServerQuerySqlGenerator : SqlServerQuerySqlGenerator
{
    public CustomSqlServerQuerySqlGenerator(
        QuerySqlGeneratorDependencies dependencies,
        SelectExpression selectExpression,
        bool rowNumberPagingEnabled)
        : base(dependencies, selectExpression, rowNumberPagingEnabled)
    {
    }

    public override Expression VisitSelect(SelectExpression selectExpression)
    {
        // other code left out for simplicity
        if (selectExpression is CustomSelectExpression customSelectExpression
            && customSelectExpression.UseSqlTweaks)
        {
            // Do SQL tweaks here!
        }
        // other code left out for simplicity

        // Calling the base method keeps this abridged override compilable; a full override
        // that reproduces the base logic above would return its own result instead.
        return base.VisitSelect(selectExpression);
    }
}
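The VisitSelect() override boils down to this: the factory hands the generator a derived select expression, and at generation time the generator downcasts it and checks the flag. A toy, stdlib-only version with hypothetical types (`MiniSelect`, `CustomMiniSelect`, `MiniSqlGenerator`, `CustomMiniSqlGenerator`) and a hypothetical tweak (appending SQL Server’s `OPTION (RECOMPILE)` query hint) makes that control flow concrete; the real SqlServerQuerySqlGenerator is far more involved.

```csharp
using System;
using System.Text;

// Hypothetical, heavily simplified stand-ins for SelectExpression and its custom subclass.
public class MiniSelect
{
    public MiniSelect(string table) => Table = table;
    public string Table { get; }
}

public class CustomMiniSelect : MiniSelect
{
    public CustomMiniSelect(string table, bool useSqlTweaks) : base(table) => UseSqlTweaks = useSqlTweaks;
    public bool UseSqlTweaks { get; }
}

// Hypothetical base generator with a virtual VisitSelect, like the real one.
public class MiniSqlGenerator
{
    protected StringBuilder Sql { get; } = new StringBuilder();

    public string Generate(MiniSelect select)
    {
        Sql.Clear();
        VisitSelect(select);
        return Sql.ToString();
    }

    public virtual void VisitSelect(MiniSelect select) =>
        Sql.Append("SELECT * FROM [").Append(select.Table).Append(']');
}

public class CustomMiniSqlGenerator : MiniSqlGenerator
{
    public override void VisitSelect(MiniSelect select)
    {
        base.VisitSelect(select);

        // Only select expressions created by the custom factory carry the flag.
        if (select is CustomMiniSelect custom && custom.UseSqlTweaks)
        {
            // Hypothetical tweak: append a SQL Server query hint.
            Sql.Append(" OPTION (RECOMPILE)");
        }
    }
}

public static class Program
{
    public static void Main()
    {
        var generator = new CustomMiniSqlGenerator();

        // SELECT * FROM [People]
        Console.WriteLine(generator.Generate(new MiniSelect("People")));

        // SELECT * FROM [People] OPTION (RECOMPILE)
        Console.WriteLine(generator.Generate(new CustomMiniSelect("People", useSqlTweaks: true)));
    }
}
```

Because a plain MiniSelect simply fails the type check, the custom generator stays safe to use for every query, not just the annotated ones.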

The generator’s factory also needs to be replaced on the DbContext so that EF Core creates the custom generator.

internal class CustomSqlServerQuerySqlGeneratorFactory : QuerySqlGeneratorFactoryBase
{
    private readonly ISqlServerOptions _sqlServerOptions;

    public CustomSqlServerQuerySqlGeneratorFactory(
        QuerySqlGeneratorDependencies dependencies,
        ISqlServerOptions sqlServerOptions) : base(dependencies)
    {
        _sqlServerOptions = sqlServerOptions;
    }

    // Creates the custom generator in place of the default SqlServerQuerySqlGenerator.
    public override IQuerySqlGenerator CreateDefault(SelectExpression selectExpression)
        => new CustomSqlServerQuerySqlGenerator(
            Dependencies,
            selectExpression,
            _sqlServerOptions.RowNumberPagingEnabled);
}

public class MyDbContext : DbContext
{
    public DbSet<Person> People { get; set; }

    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        optionsBuilder
            .UseSqlServer(connectionString)
            .ReplaceService<INodeTypeProviderFactory, CustomMethodInfoBasedNodeTypeRegistryFactory>()
            .ReplaceService<ISelectExpressionFactory, CustomSelectExpressionFactory>()
            .ReplaceService<IQuerySqlGeneratorFactory, CustomSqlServerQuerySqlGeneratorFactory>();
        base.OnConfiguring(optionsBuilder);
    }
}

And that’s it! Hopefully you stopped at step 0. But if not, be prepared to maintain this long term, and don’t get mad at me when it breaks. If you did follow this guide for some reason, I’d love to hear why in the comments!


Jesse Barocio

Software developer, DevOps engineer, and productivity tool nut. Continuously improving. Have a question or problem you need solved? Email me!