Compilation of math functions into Linq.Expression
Hello. In this article, I want to demonstrate how I implemented compilation of mathematical (both numerical and logical) expressions into a delegate using Linq Expression.
What do we want?
We want to compile an expression into a function of an arbitrary number of arguments of an arbitrary type, not only numeric, but also boolean. For instance,
var func = "x + sin(y) + 2ch(0)".Compile<Complex, double, Complex>("x", "y");
Console.WriteLine(func(new(3, 4), 1.2d));
>>> (5.932039085967226, 4)
var func = "x > 3 and (a implies b)".Compile<int, bool, bool, bool>("x", "a", "b");
Console.WriteLine(func(4, false, true));
>>> True
What do we have already?
Since I am doing this within an existing symbolic algebra library, we will immediately proceed to compilation, already having a parser and an expression tree.
We have the base class Entity and its type hierarchy, which looks like this:
Entity
|
+--Operators
|  |
|  +--Sumf
|  |
|  +--Minusf
|  |
|  ...
+--Trigonometry
|  |
|  +--Sinf
|  |
|  +--Cosf
|  |
|  ...
+--Discrete
   |
   +--Andf
   |
   +--Lessf
   |
An expression tree is just a graph where the children of a node are the operands of an operator / function.
Each type is either abstract (only used to generalize other types) or sealed. The latter can be a real operator / function / constant / other entity that occurs in an expression (be it plus, sine, conjunction, number, set, etc.).
For example, this is how the plus operator is defined.
Compilation protocol
This is how I named the interface / data class that defines how Entity's subtypes are mapped to Linq.Expression's structures, operators, or functions. Since we do not know which types the user will request as input and output, the user has to provide this information.
This is what it will look like:
public sealed record CompilationProtocol
{
public Func<Entity, Expression> ConstantConverter { get; init; }
public Func<Expression, Expression, Entity, Expression> BinaryNodeConverter { get; init; }
public Func<Expression, Entity, Expression> UnaryNodeConverter { get; init; }
public Func<IEnumerable<Expression>, Entity, Expression> AnyArgumentConverter { get; init; }
}
Note: We will implement ConstantConverter, BinaryNodeConverter and UnaryNodeConverter at the end of the article.
Our compiler will call these lambdas internally when converting constant nodes, unary nodes, binary nodes, and nodes with any other number of arguments into Linq.Expression.
That is, now we know by what rules we will transform each node. Now it's time to write the "compiler" itself, or rather the algorithm that will build a tree.
Compiler
The prototype for the method we want to write looks like this:
internal static TDelegate Compile<TDelegate>(
Entity expr,
Type? returnType,
CompilationProtocol protocol,
IEnumerable<(Type type, Variable variable)> typesAndNames
) where TDelegate : Delegate
Entity expr is the expression that we compile.
Type? returnType is the return type. We do not extract it from the delegate type, so we pass it as a separate argument.
CompilationProtocol protocol is the protocol of rules by which we will transform each node of the expression.
IEnumerable<(Type type, Variable variable)> typesAndNames is a set of type-variable tuples describing the arguments that the user will pass to the resulting delegate. For example, if we want to substitute an integer for x and a complex number for y, we will write new[] { (typeof(int), "x"), (typeof(Complex), "y") }
And here is the implementation of the method:
internal static TDelegate Compile<TDelegate>(Entity expr, Type? returnType, CompilationProtocol protocol, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate
{
// We keep local variables for every subtree here
var subexpressionsCache = typesAndNames.ToDictionary(c => (Entity)c.variable, c => Expression.Parameter(c.type));
// Unlike local variables, these parameters are function's arguments
var functionArguments = subexpressionsCache.Select(c => c.Value).ToArray(); // copying
// We will save local variables to this list
var localVars = new List<ParameterExpression>();
// That is a list of assignments of subtrees to local variables
var variableAssignments = new List<Expression>();
// Building a tree with the provided data
var tree = BuildTree(expr, subexpressionsCache, variableAssignments, localVars, protocol);
// Then, we create Expression.Block passing the local variables, referenced in the tree
var treeWithLocals = Expression.Block(localVars, variableAssignments.Append(tree));
// If the user passed returnType, then we try to cast the expression into it
Expression entireExpression = returnType is not null ? Expression.Convert(treeWithLocals, returnType) : treeWithLocals;
// Create a lambda with the given expression and the passed function's arguments
var finalLambda = Expression.Lambda<TDelegate>(entireExpression, functionArguments);
// finally, compile into a delegate
return finalLambda.Compile();
}
Its main purpose is to create the necessary containers for the cache of subtrees, local variables, and some other things. The most interesting function here is BuildTree. It will build a Linq expression tree from an Entity. This is what its prototype looks like:
internal static Expression BuildTree(
Entity expr,
Dictionary<Entity, ParameterExpression> cachedSubexpressions,
List<Expression> variableAssignments,
List<ParameterExpression> newLocalVars,
CompilationProtocol protocol)
More about the BuildTree's arguments
Entity expr - an expression or subexpression to build a tree from.
Dictionary<Entity, ParameterExpression> cachedSubexpressions - a dictionary of cached subtrees (that is, those that are already written to existing local variables).
List<Expression> variableAssignments - a list of assignments of unique subtrees to local variables.
List<ParameterExpression> newLocalVars - local variables created by BuildTree (for storing results of subtrees).
CompilationProtocol protocol - the rules by which we transform an Entity node into a Linq.Expression node. It remains unchanged and is simply passed on to all calls of BuildTree.
And here is the most important function, BuildTree:
internal static Expression BuildTree(Entity expr, ...)
{
// if this subtree was already processed, we return a local variable to which
// we assigned the subtree's value
if (cachedSubexpressions.TryGetValue(expr, out var readyVar))
return readyVar;
Expression subTree = expr switch
{
...
// A constant is processed by ConstantConverter
Entity.Boolean or Number => protocol.ConstantConverter(expr),
// Same mechanism with unary, binary and n-ary nodes
IUnaryNode oneArg
=> protocol.UnaryNodeConverter(BuildTree(oneArg.NodeChild, ...), expr),
IBinaryNode twoArg
=> protocol.BinaryNodeConverter(
BuildTree(twoArg.NodeFirstChild, ...),
BuildTree(twoArg.NodeSecondChild, ...),
expr),
var other => protocol.AnyArgumentConverter(
other.DirectChildren.Select(c => BuildTree(c, ...)),
expr)
};
// we create a local variable for this subtree
var newVar = Expression.Variable(subTree.Type);
// add an instruction like var5 = subTree
variableAssignments.Add(Expression.Assign(newVar, subTree));
// match the subtree to the variable
cachedSubexpressions[expr] = newVar;
// track the newly created variable
newLocalVars.Add(newVar);
return newVar;
}
I have omitted large chunks of code for the sake of readability. For each subtree, we either immediately return the local variable corresponding to this expression, or we build a new Linq.Expression tree, store it in a new local variable, and return it.
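The caching idea can be tried in isolation. The following is a minimal sketch, not the library's code: it hand-builds f(x) = sin(x) * sin(x) so that the shared subtree sin(x) is assigned to a local variable once and then read twice. The names sinLocal and assignSin are made up for this illustration.

```csharp
using System;
using System.Linq.Expressions;

// Function argument
var x = Expression.Parameter(typeof(double), "x");
var sinMethod = typeof(Math).GetMethod(nameof(Math.Sin), new[] { typeof(double) })
    ?? throw new InvalidOperationException("Math.Sin(double) not found");

// The shared subtree sin(x) is assigned to a local variable once...
var sinLocal = Expression.Variable(typeof(double), "sinX");
var assignSin = Expression.Assign(sinLocal, Expression.Call(sinMethod, x));

// ...and every later occurrence reads the variable instead of calling Sin again
var product = Expression.Multiply(sinLocal, sinLocal);

// A block declares its locals, runs the expressions in order,
// and evaluates to the last one
var body = Expression.Block(
    new[] { sinLocal },
    new Expression[] { assignSin, product });

var f = Expression.Lambda<Func<double, double>>(body, x).Compile();
Console.WriteLine(f(0.5));
```

This mirrors what BuildTree does for every node: the assignment goes into the list of assignments, the variable goes into the list of locals, and the variable itself stands in for the subtree.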
Actually, that's about it, the compiler is implemented. But we have not implemented any rules for converting our expressions to Linq.Expression, because we expect these rules to be provided by the user. But why not provide some default rules for built-in types?
The rest of the article will be about creating a default protocol.
Assumptions
The method Compile<TDelegate>(Entity, Type?, CompilationProtocol, IEnumerable<(Type, Variable)>) itself will be available to the user, but it is clearly long and clumsy to use: the user would have to write a large amount of code describing the transformation of each node and constant, and the method declaration itself is long and unclear.
So we can provide a default compilation protocol, which will work with some built-in primitives (bool, int, long, float, double, Complex, BigInteger).
ConstantConverter:
This rule converts a constant from Entity to Linq.Constant and looks like this:
public static Expression ConverterConstant(Entity e)
=> e switch
{
Number n => Expression.Constant(DownCast(n)),
Entity.Boolean b => Expression.Constant((bool)b),
_ => throw new AngouriBugException("Undefined constant type")
};
The Entity.Number is cast to a number depending on its type, and a boolean constant is unconditionally converted into bool.
More about DownCast
This function converts Entity.Number to some of the built-in types and is implemented as follows:
private static object DownCast(Number num)
{
if (num is Integer)
return (long)num;
if (num is Real)
return (double)num;
if (num is Number.Complex)
return (System.Numerics.Complex)num;
throw new InvalidProtocolProvided("Undefined type, provide valid compilation protocol");
}
DownCast returns object because this is exactly what Expression.Constant expects as an argument. This suits us: whatever type we cast the number to, the result is still a constant.
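A small sketch of why returning object is enough: the single-argument Expression.Constant overload takes the node's type from the runtime type of the boxed value. The variable names here are illustrative only.

```csharp
using System;
using System.Linq.Expressions;
using System.Numerics;

// Values boxed as object, the way DownCast returns them
object asLong = 42L;
object asDouble = 1.5;
object asComplex = new Complex(3, 4);

foreach (var boxed in new[] { asLong, asDouble, asComplex })
{
    // The constant node's Type is the runtime type of the boxed value
    ConstantExpression constant = Expression.Constant(boxed);
    Console.WriteLine(constant.Type.Name);
}
```

So the type chosen in DownCast directly determines the type of the resulting constant node, with no extra type argument needed.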
UnaryNodeConverter:
This protocol's rule is a delegate that converts a node with one argument to a Linq.Expression.
public static Expression OneArgumentEntity(Expression e, Entity typeHolder)
=> typeHolder switch
{
Sinf => Expression.Call(GetDef("Sin", 1, e.Type), e),
...
Cosecantf => Expression.Call(GetDef("Csc", 1, e.Type), e),
Arcsinf => Expression.Call(GetDef("Asin", 1, e.Type), e),
...
Arccosecantf => Expression.Call(GetDef("Acsc", 1, e.Type), e),
Absf => Expression.Call(GetDef("Abs", 1, e.Type), e),
Signumf => Expression.Call(GetDef("Sgn", 1, e.Type), e),
Notf => Expression.Not(e),
_ => throw new AngouriBugException("A node seems to be not added")
};
I've omitted some big blocks (all code here). Here we go through the possible types of the node and select the desired overload of a function for each one. GetDef finds the function we want by name.
About GetDef
At first I thought of calling all the necessary functions from Math and Complex directly. That required a lot of conditional statements everywhere to decide when to use Math, Complex, or BigInteger. Another issue is that Math lacks some overloads, for instance, int Pow(int, int).
Therefore, I created the MathAllMethods class (generated with T4), which contains all the necessary overloads for all the necessary functions.
GetDef searches this class for the required method with the given name, number of arguments, and argument type. This allowed us to get rid of the spaghetti code and write all the calls to the necessary functions concisely.
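For illustration, a lookup of this kind can be written with plain reflection. This is only a sketch under assumptions: FindOverload is a hypothetical stand-in for GetDef, and it searches System.Math rather than the generated MathAllMethods class, which is not shown in this article.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical stand-in for GetDef: finds a public static method of `holder`
// with the given name whose parameters all have exactly the type `argType`
static MethodInfo FindOverload(Type holder, string name, int argCount, Type argType)
    => holder.GetMethods(BindingFlags.Public | BindingFlags.Static)
             .FirstOrDefault(m => m.Name == name
                 && m.GetParameters().Length == argCount
                 && m.GetParameters().All(p => p.ParameterType == argType))
       ?? throw new MissingMethodException(holder.Name, name);

// Build and compile x => Math.Sin(x) using the overload found by name and type
var arg = Expression.Parameter(typeof(double), "x");
var call = Expression.Call(FindOverload(typeof(Math), "Sin", 1, arg.Type), arg);
var sin = Expression.Lambda<Func<double, double>>(call, arg).Compile();
Console.WriteLine(sin(0.0)); // 0
```

Comparing parameter types exactly matters here, because Expression.Call rejects an argument whose type does not match the parameter type of the chosen method.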
BinaryNodeConverter:
This rule converts a two-argument node to a Linq.Expression.
public static Expression TwoArgumentEntity(Expression left, Expression right, Entity typeHolder)
{
var typeToCastTo = MaxType(left.Type, right.Type);
if (left.Type != typeToCastTo)
left = Expression.Convert(left, typeToCastTo);
if (right.Type != typeToCastTo)
right = Expression.Convert(right, typeToCastTo);
return typeHolder switch
{
Sumf => Expression.Add(left, right),
...
Andf => Expression.And(left, right),
...
Lessf => Expression.LessThan(left, right),
...
_ => throw new AngouriBugException("A node seems to be not added")
};
}
There is an upcast here. Since the two operands may have different types, we want to find the simplest type to which both operands can be cast. To do this, I assigned a level to every type:
Complex: 10
double: 9
float: 8
long: 8
BigInteger: 8
int: 7
If the types are the same, MaxType returns that type. For example, MaxType(int, int) -> int.
If the level of operand A's type is higher than that of operand B's type, then B is cast to A's type. For example, MaxType(long, double) -> double.
If the levels are equal but the types are not, then the closest common vertex is found, that is, a type whose level is higher by 1. For example, MaxType(long, float) -> double.
Operands, if necessary, are cast to the selected type, and then we simply find the required overload or operator. For example, for Sumf we choose Expression.Add, and for conjunction, Andf will be turned into Expression.And.
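The three rules can be sketched as follows. This is my reading of the description, not the library's implementation: the levels dictionary copies the table above, and the "one level up" lookup simply takes the first type registered at the next level. bool is left out because the table does not list it.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Numerics;

// Levels copied from the table above (hypothetical container name)
var levels = new Dictionary<Type, int>
{
    [typeof(Complex)] = 10,
    [typeof(double)] = 9,
    [typeof(float)] = 8,
    [typeof(long)] = 8,
    [typeof(BigInteger)] = 8,
    [typeof(int)] = 7,
};

Type MaxType(Type a, Type b)
{
    if (a == b) return a;                      // same type: nothing to do
    int la = levels[a], lb = levels[b];
    if (la != lb) return la > lb ? a : b;      // different levels: the higher wins
    // same level, different types: move one level up
    return levels.First(kv => kv.Value == la + 1).Key;
}

Console.WriteLine(MaxType(typeof(int), typeof(int)).Name);     // Int32
Console.WriteLine(MaxType(typeof(long), typeof(double)).Name); // Double
Console.WriteLine(MaxType(typeof(long), typeof(float)).Name);  // Double

// Both operands are converted to the common type before Expression.Add
var left = Expression.Parameter(typeof(long), "a");
var right = Expression.Parameter(typeof(float), "b");
var target = MaxType(left.Type, right.Type);
var sum = Expression.Add(Expression.Convert(left, target), Expression.Convert(right, target));
var add = Expression.Lambda<Func<long, float, double>>(sum, left, right).Compile();
Console.WriteLine(add(2, 0.5f));
```

Without the two Expression.Convert calls, Expression.Add would throw, since no addition operator is defined for a long operand and a float operand in an expression tree.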
What happened?
Great, we have defined all the necessary rules for our protocol. Now, during creation, we can pass these rules to the required protocol properties.
Fancy API
This is the low-level version of what we want to see in the final API:
public TDelegate Compile<TDelegate>(CompilationProtocol protocol, Type returnType, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate
It is inconvenient and requires a lot of work before we can call it. But we can pass our default protocol AND overload this method for delegates with one argument, two, three, and so on. The first part is easy: since I already assign our rules to the default protocol's properties, passing the protocol just means creating an instance of it. The second part is a little more complicated, and I solved it by generating the code with a T4 Text Template. Here's an example of the generated code:
// The type parameters specify the input and output types;
// the arguments are the variables corresponding to the input types
public Func<TIn1, TIn2, TIn3, TOut> Compile<TIn1, TIn2, TIn3, TOut>(Variable var1, Variable var2, Variable var3)
// Func<...> is the delegate we want to get, typeof(TOut) is the known output type,
// and new() creates a protocol with the default rules
=> IntoLinqCompiler.Compile<Func<TIn1, TIn2, TIn3, TOut>>(this, typeof(TOut), new(),
new[] { (typeof(TIn1), var1), (typeof(TIn2), var2), (typeof(TIn3), var3) });
It's in the source.
T4-template text to generate
<# for (var i = 1; i <= 8; i++) { #>
public Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut> Compile<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>(Variable var1<# for(var t=2; t<=i; t++){ #>, Variable var<#= t #><# } #>)
=> IntoLinqCompiler.Compile<Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>>(this, typeof(TOut), new(),
new[] { (typeof(TIn1), var1)<# for(var t=2;t<=i;t++){ #>, (typeof(TIn<#= t #>), var<#= t #>) <# } #> });
<# } #>
We create extension methods in the same way. Here is an example of generated code:
public static Func<TIn1, TIn2, TOut> Compile<TIn1, TIn2, TOut>(this string @this, Variable var1, Variable var2)
=> IntoLinqCompiler.Compile<Func<TIn1, TIn2, TOut>>(@this, typeof(TOut), new(),
new[] { (typeof(TIn1), var1), (typeof(TIn2), var2) });
Now we need to measure the performance.
Performance
BenchNormalSimple is a simple lambda, declared right in the code.
BenchMySimple is the same lambda, but compiled by me.
BenchNormalComplicated is a big fat lambda with a bunch of identical subtrees, declared right in the code.
BenchMyComplicated is the same lambda, but compiled by me.
| Method | Mean | Error | StdDev |
|----------------------- |-----------:|---------:|---------:|
| BenchNormalSimple | 189.1 ns | 3.75 ns | 5.83 ns |
| BenchMySimple | 195.7 ns | 3.92 ns | 5.50 ns |
| BenchNormalComplicated | 1,383.0 ns | 26.82 ns | 35.80 ns |
| BenchMyComplicated | 293.6 ns | 5.74 ns | 8.77 ns |
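Simple expressions run about equally fast, and where there are identical subtrees, my compilation beats the hand-written lambda: about 4.7 times faster in this benchmark (293.6 ns versus 1,383.0 ns), because each repeated subtree is evaluated only once. In general, the result is predictable, and there is nothing extraordinary here.
The benchmark is here.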
Examples of work
var func = "sin(x)".Compile<double, double>("x");
Console.WriteLine(func(Math.PI / 2));
>>> 1
var func1 = "a > b".Compile<float, int, bool>("a", "b");
Console.WriteLine(func1(5.4f, 4));
Console.WriteLine(func1(4f, 4));
>>> True
>>> False
var cr = new CompilationProtocol()
{
ConstantConverter = ent => Expression.Constant(ent.ToString()),
BinaryNodeConverter = (a, b, t) => t switch
{
Sumf => Expression.Call(typeof(string)
.GetMethod("Concat", new[] { typeof(string), typeof(string) }) ?? throw new Exception(), a, b),
_ => throw new Exception()
}
};
var func2 = "a + b + c + 1234"
.Compile<Func<string, string, string, string>>(
cr, typeof(string),
new[] {
(typeof(string), Var("a")),
(typeof(string), Var("b")),
(typeof(string), Var("c")) }
);
Console.WriteLine(func2("White", "Black", "Goose"));
>>> WhiteBlackGoose1234
(The last example shows how users can declare their own protocol instead of using the default one. The result of this compilation is a lambda that concatenates strings.)
Conclusion
Linq.Expression is truly a brilliant thing. In this short article, we implemented a compilation process for mathematical expressions.
To ensure that the types are arbitrary, we came up with a protocol for translating an expression.
To avoid making the user write the same code over and over again, we offer a number of overloads with a default protocol that works with a number of built-in types. This protocol will automatically upcast types in binary operators and functions to the nearest common type, if necessary.
Such functionality can be used where you would like to get a fast-working mathematical function from a string at runtime. Maybe the community will come up with some other useful application. By the way, I have already used runtime compilation in another project, in which dynamic compilation of nested loops allowed us to avoid recursion (and save precious nanoseconds).
Thank you for your attention! The next article will probably be about symbolic limits or parsing from a string.
References
GitHub of the AngouriMath project, within which I developed the compilation
Compilation's code here
Compilation's tests can be found here