
Compilation of math functions into Linq.Expression


Hello. In this article, I want to demonstrate how I implemented compilation of mathematical (both numerical and logical) expressions into a delegate using Linq Expression.

Navigation: Problem · Compilation rules · Compiler · Default rules · Fancy API · Performance · Examples of compilation · Conclusion · References

What do we want?

We want to compile an expression into a function of an arbitrary number of arguments of an arbitrary type, not only numeric, but also boolean. For instance,

var func = "x + sin(y) + 2ch(0)".Compile<Complex, double, Complex>("x", "y");
Console.WriteLine(func(new(3, 4), 1.2d));
>>> (5.932039085967226, 4)
  
var func = "x > 3 and (a implies b)".Compile<int, bool, bool, bool>("x", "a", "b");
Console.WriteLine(func(4, false, true));
>>> True

What do we have already?

Since I am doing this within an existing symbolic algebra library (AngouriMath), which already has a parser and an expression tree, we can proceed straight to compilation.

We have the base class Entity and its type hierarchy, which looks like this:

Entity
|
+--Operators
|  |
|  +--Sumf
|  |
|  +--Minusf
|  |
|  ...
+--Trigonometry
|  |
|  +--Sinf
|  |
|  +--Cosf
|  |
|  ...
+--Discrete
   |
   +--Andf
   |
   +--Lessf
   |
   ...

An expression tree is just a graph where the children of a node are the operands of an operator / function.

Each type is either abstract (used only to generalize other types) or sealed. A sealed type is an actual operator / function / constant / other entity that occurs in an expression (be it plus, sine, conjunction, a number, a set, etc.).

For example, this is how the plus operator is defined.

Compilation protocol

This is what I called the data class that defines how Entity's subtypes are mapped to Linq.Expression's structures, operators, or functions. Since we do not know in advance which input and output types the user will request, the user has to provide this information.

This is what it will look like:

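using System;
using System.Collections.Generic;
using System.Linq.Expressions;

public sealed record CompilationProtocol
{
    // Turns a constant node (number, boolean) into a Linq.Expression
    public Func<Entity, Expression> ConstantConverter { get; init; }

    // Combines two already-built children; the Entity tells which operator it is
    public Func<Expression, Expression, Entity, Expression> BinaryNodeConverter { get; init; }

    // Wraps one already-built child; the Entity tells which function it is
    public Func<Expression, Entity, Expression> UnaryNodeConverter { get; init; }

    // Handles nodes with any number of children
    public Func<IEnumerable<Expression>, Entity, Expression> AnyArgumentConverter { get; init; }
}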

Note: We will implement ConstantConverter, BinaryNodeConverter and UnaryNodeConverter at the end of the article.

Our compiler calls these lambdas internally when converting constant nodes, unary nodes, binary nodes, and nodes with any number of arguments into Linq.Expression.
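To make the protocol idea concrete, here is a minimal sketch of three converters written as plain delegates and applied by hand to the tree of sin(x) + 2. It is a simplification, not AngouriMath code: operator names are plain strings standing in for Entity subtypes such as Sinf and Sumf, and the delegate names (constantConverter and so on) are hypothetical.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical, simplified converters: an operator name replaces the Entity argument
Func<double, Expression> constantConverter = value => Expression.Constant(value);

Func<Expression, string, Expression> unaryNodeConverter = (arg, op) => op switch
{
    "sin" => Expression.Call(typeof(Math).GetMethod("Sin", new[] { typeof(double) })!, arg),
    "-" => Expression.Negate(arg),
    _ => throw new NotSupportedException($"Unknown unary operator {op}")
};

Func<Expression, Expression, string, Expression> binaryNodeConverter = (left, right, op) => op switch
{
    "+" => Expression.Add(left, right),
    "*" => Expression.Multiply(left, right),
    _ => throw new NotSupportedException($"Unknown binary operator {op}")
};

// Children are converted first, then the node itself:
// sin(x) + 2  ->  Add(Call(Math.Sin, x), Constant(2.0))
var x = Expression.Parameter(typeof(double), "x");
var body = binaryNodeConverter(unaryNodeConverter(x, "sin"), constantConverter(2.0), "+");

var func = Expression.Lambda<Func<double, double>>(body, x).Compile();
Console.WriteLine(func(0)); // prints 2
```

In the real compiler the walk over the tree is automatic; the point here is only that every decision about which Linq.Expression node to emit lives in the user-supplied lambdas.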

So now we know the rules by which each node is transformed. It's time to write the "compiler" itself, or rather the algorithm that builds the tree.

Compiler

The prototype for the method we want to write looks like this:

internal static TDelegate Compile<TDelegate>(
            Entity expr, 
            Type? returnType,
            CompilationProtocol protocol,
            IEnumerable<(Type type, Variable variable)> typesAndNames
            ) where TDelegate : Delegate
  1. Entity expr is the expression that we compile.

  2. Type? returnType is the return type. Rather than extracting it from the delegate type, we pass it as a separate argument; if it is null, no final conversion is applied.

  3. CompilationProtocol protocol is the set of rules by which we transform each node of the expression.

  4. IEnumerable<(Type type, Variable variable)> typesAndNames is a set of type-variable tuples describing the arguments the user will pass to the resulting delegate. For example, if we want x to be an integer and y a complex number, we will write new[] { (typeof(int), "x"), (typeof(Complex), "y") }
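A minimal sketch of how (type, name) pairs like the ones above can become the parameters of a typed delegate. Names are plain strings here instead of AngouriMath's Variable, and BuildParameters is a hypothetical helper. Building the parameter array directly from the list keeps the delegate's argument order the same as the order in which the user wrote the pairs.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical helper: one ParameterExpression per (type, name) pair,
// in exactly the order the caller listed them
static ParameterExpression[] BuildParameters((Type type, string name)[] typesAndNames)
    => typesAndNames.Select(p => Expression.Parameter(p.type, p.name)).ToArray();

// x is an int, y is a double, the result is a double: Func<int, double, double>
var parameters = BuildParameters(new[] { (typeof(int), "x"), (typeof(double), "y") });
var byName = parameters.ToDictionary(p => p.Name!, p => p);

// Body of "x * y"; x is converted to double before multiplying
Expression body = Expression.Multiply(
    Expression.Convert(byName["x"], typeof(double)),
    byName["y"]);

// returnType: convert the body to the requested output type (a no-op here)
body = Expression.Convert(body, typeof(double));

var func = Expression.Lambda<Func<int, double, double>>(body, parameters).Compile();
Console.WriteLine(func(3, 1.5)); // prints 4.5
```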

And here is the implementation of the method:

internal static TDelegate Compile<TDelegate>(Entity expr, Type? returnType, CompilationProtocol protocol, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate
{
    // One parameter per (type, variable) pair, kept in the order the user passed them
    var parameters = typesAndNames.Select(c => (c.variable, parameter: Expression.Parameter(c.type))).ToArray();
    // Unlike local variables, these parameters are the function's arguments
    var functionArguments = parameters.Select(c => c.parameter).ToArray();
    // We keep a variable for every already-built subtree here; it starts with the arguments
    var subexpressionsCache = parameters.ToDictionary(c => (Entity)c.variable, c => c.parameter);
    // We will save local variables to this list
    var localVars = new List<ParameterExpression>();
    // That is a list of assignments of subtrees to local variables
    var variableAssignments = new List<Expression>();

    // Building a tree with the provided data
    var tree = BuildTree(expr, subexpressionsCache, variableAssignments, localVars, protocol);
    // Then, we create Expression.Block passing the local variables referenced in the tree
    var treeWithLocals = Expression.Block(localVars, variableAssignments.Append(tree));
    // If the user passed returnType, then we try to cast the expression into it
    Expression entireExpression = returnType is not null ? Expression.Convert(treeWithLocals, returnType) : treeWithLocals;
    // Create a lambda with the given expression and the function's arguments
    var finalLambda = Expression.Lambda<TDelegate>(entireExpression, functionArguments);

    // Finally, compile into a delegate
    return finalLambda.Compile();
}

Its main purpose is to set up the necessary containers: the cache of subtrees, the local variables, and the list of assignments. The most interesting function here is BuildTree, which builds a Linq.Expression tree from an Entity. This is what its prototype looks like:

internal static Expression BuildTree(
    Entity expr, 
    Dictionary<Entity, ParameterExpression> cachedSubexpressions, 
    List<Expression> variableAssignments, 
    List<ParameterExpression> newLocalVars,
    CompilationProtocol protocol)
More about BuildTree's arguments
  1. Entity expr - an expression or subexpression to build a tree from.

  2. Dictionary<Entity, ParameterExpression> cachedSubexpressions - a dictionary of cached subtrees (that is, those already written to existing local variables; initially it contains the function's parameters).

  3. List<Expression> variableAssignments - a list of assignments of unique subtrees to local variables.

  4. List<ParameterExpression> newLocalVars - local variables created by BuildTree (for storing the results of subtrees).

  5. CompilationProtocol protocol - the rules by which we transform an Entity node into a Linq.Expression node. It remains unchanged and is simply passed on to every call of BuildTree.

And here is the most important function - BuildTree:

internal static Expression BuildTree(Entity expr, ...)
{
    // if this subtree was already processed, we return a local variable to which
    // we assigned the subtree's value
    if (cachedSubexpressions.TryGetValue(expr, out var readyVar))
        return readyVar;

    Expression subTree = expr switch
    {
      ...
        
        // A constant is processed by ConstantConverter
        Entity.Boolean or Number => protocol.ConstantConverter(expr),

        // The same mechanism applies to unary, binary and n-ary nodes
        IUnaryNode oneArg
            => protocol.UnaryNodeConverter(BuildTree(oneArg.NodeChild, ...), expr),

        IBinaryNode twoArg
            => protocol.BinaryNodeConverter(
                BuildTree(twoArg.NodeFirstChild, ...), 
                BuildTree(twoArg.NodeSecondChild, ...), 
                expr),

        var other => protocol.AnyArgumentConverter(
                other.DirectChildren.Select(c => BuildTree(c, ...)), 
            expr)
    };

    // we create a local variable for this subtree
    var newVar = Expression.Variable(subTree.Type);
    
    // add an instruction like var5 = subTree
    variableAssignments.Add(Expression.Assign(newVar, subTree));
    
    // match the subtree to the variable
    cachedSubexpressions[expr] = newVar;
    
    // track the newly created variable
    newLocalVars.Add(newVar);
    
    return newVar;
}

I have omitted large chunks of code for the sake of readability. For each subtree, we either immediately return the local variable corresponding to this expression, or we build a new Linq.Expression tree, store it in a new local variable, and return it.
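The caching in BuildTree can be reproduced in a small standalone sketch. This is a simplified model, not the library's code: nodes are C# value tuples (which, like entities, compare by value), only sin, + and * are supported, and the hypothetical Build function plays the role of BuildTree with the protocol inlined. For sin(x) * sin(x) + sin(x), the repeated sin(x) gets a single local variable.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

var x = Expression.Parameter(typeof(double), "x");
// The cache starts with the function's arguments, as in Compile
var cache = new Dictionary<object, ParameterExpression> { ["x"] = x };
var localVars = new List<ParameterExpression>();
var assignments = new List<Expression>();

// Hypothetical tree encoding: sin(x) * sin(x) + sin(x)
object sinX = ("sin", (object)"x");
object tree = ("+", (object)("*", sinX, sinX), sinX);

// Hypothetical analogue of BuildTree
Expression Build(object node)
{
    // Already computed (or a function argument): reuse its variable
    if (cache.TryGetValue(node, out var ready))
        return ready;

    Expression sub = node switch
    {
        double d => Expression.Constant(d),
        ValueTuple<string, object> u when u.Item1 == "sin" =>
            Expression.Call(typeof(Math).GetMethod("Sin", new[] { typeof(double) })!, Build(u.Item2)),
        ValueTuple<string, object, object> b when b.Item1 == "+" => Expression.Add(Build(b.Item2), Build(b.Item3)),
        ValueTuple<string, object, object> b when b.Item1 == "*" => Expression.Multiply(Build(b.Item2), Build(b.Item3)),
        _ => throw new NotSupportedException(node.ToString())
    };

    // Store the subtree in a fresh local and remember it
    var v = Expression.Variable(sub.Type);
    assignments.Add(Expression.Assign(v, sub));
    localVars.Add(v);
    cache[node] = v;
    return v;
}

var result = Build(tree);
var body = Expression.Block(localVars, assignments.Append(result));
var func = Expression.Lambda<Func<double, double>>(body, x).Compile();

Console.WriteLine(assignments.Count); // prints 3: sin(x), the product, the sum
Console.WriteLine(func(1.0));         // sin(1)^2 + sin(1)
```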

Actually, that's it: the compiler is implemented. But we have not implemented any rules for converting our expressions to Linq.Expression, because we expect the user to provide them. But why not provide some default rules for built-in types?

The rest of the article will be about creating a default protocol.

Assumptions

The method itself, Compile<TDelegate>(Entity, Type?, CompilationProtocol, IEnumerable<(Type, Variable)>), will be provided to the user, but it is clearly long and clumsy: the user would have to write a huge amount of code describing how to transform each node and constant, and the method's signature itself is long and unclear.

So we can provide a default compilation protocol that works with several built-in primitives (bool, int, long, float, double, Complex, BigInteger).

ConstantConverter:

This rule converts a constant from Entity to Linq.Constant and looks like this:

public static Expression ConverterConstant(Entity e)
    => e switch
    {
        Number n => Expression.Constant(DownCast(n)),
        Entity.Boolean b => Expression.Constant((bool)b),
        _ => throw new AngouriBugException("Undefined constant type")
    };

An Entity.Number is cast to a built-in numeric type depending on its kind, and a boolean constant is always converted to bool.

More about DownCast

This function converts Entity.Number to some of the built-in types and is implemented as follows:

private static object DownCast(Number num)
{
    if (num is Integer)
        return (long)num;
    if (num is Real)
        return (double)num;
    if (num is Number.Complex)
        return (System.Numerics.Complex)num;
    throw new InvalidProtocolProvided("Undefined type, provide valid compilation protocol");
}

It returns object because that is exactly what Expression.Constant expects as an argument. This is just what we want: the number can be cast to any type, and the result is still a constant whose node type matches the runtime type of the value.
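A quick way to check this: the sketch below boxes a long, a double and a Complex (standing in for values DownCast might return) and looks at the Type of the resulting constant nodes. Only standard System.Linq.Expressions calls are used; the array name is illustrative.

```csharp
using System;
using System.Linq.Expressions;
using System.Numerics;

// Boxed values, as an object-returning DownCast would produce them
object[] downCastResults = { 42L, 0.5, new Complex(1, 2) };

foreach (var value in downCastResults)
{
    // Expression.Constant(object) takes the node's Type from the boxed value
    var constant = Expression.Constant(value);
    Console.WriteLine($"{constant.Type.Name}: {constant.Value}");
}

// Because the node's Type is long, it can be the body of a long-returning lambda
var five = Expression.Lambda<Func<long>>(Expression.Constant((object)5L)).Compile();
Console.WriteLine(five()); // prints 5
```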

UnaryNodeConverter:

This protocol's rule is a delegate that converts a node with one argument to a Linq.Expression.

public static Expression OneArgumentEntity(Expression e, Entity typeHolder)
  => typeHolder switch
    {
        Sinf =>         Expression.Call(GetDef("Sin", 1, e.Type), e),
        ...
        Cosecantf =>    Expression.Call(GetDef("Csc", 1, e.Type), e),

        Arcsinf =>      Expression.Call(GetDef("Asin", 1, e.Type), e),
        ...
        Arccosecantf => Expression.Call(GetDef("Acsc", 1, e.Type), e),

        Absf =>         Expression.Call(GetDef("Abs", 1, e.Type), e),
        Signumf =>      Expression.Call(GetDef("Sgn", 1, e.Type), e),

        Notf =>         Expression.Not(e),

        _ => throw new AngouriBugException("A node seems to be not added")
    };

I've omitted some big blocks (all the code is here). Here we consider the possible types of our node and, for each, select the desired function overload. GetDef finds the required method by name, number of arguments, and argument type.

About GetDef

At first I thought of calling all the necessary functions from the Math and Complex classes. That meant writing a lot of conditional statements everywhere and handling the cases when I should use Math, Complex, or BigInteger. Another issue is that Math lacks some overloads, for instance, int Pow(int, int).

Therefore, I created the MathAllMethods class (generated with T4), which contains all the necessary overloads for all the necessary functions.

GetDef searches this class for the method with the given name, number of arguments, and argument type. This allowed me to get rid of the spaghetti code and write all calls to the necessary functions for these types concisely.
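MathAllMethods is generated code that is not reproduced here. The sketch below only illustrates the lookup idea with reflection, searching Math and Complex directly; FindMethod is a hypothetical name, not the library's GetDef. It also shows the gap mentioned above: there is no Pow(int, int) to find.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Numerics;
using System.Reflection;

// Hypothetical analogue of GetDef: a public static method with the given name
// whose parameters all have the given type
static MethodInfo? FindMethod(string name, int argCount, Type type)
    => new[] { typeof(Math), typeof(Complex) }
        .SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.Static))
        .FirstOrDefault(m => m.Name == name
            && m.GetParameters().Length == argCount
            && m.GetParameters().All(p => p.ParameterType == type));

Console.WriteLine(FindMethod("Sin", 1, typeof(double)));      // finds Math.Sin(double)
Console.WriteLine(FindMethod("Sin", 1, typeof(Complex)));     // finds Complex.Sin(Complex)
Console.WriteLine(FindMethod("Pow", 2, typeof(int)) is null); // prints True: no int Pow

// Using the found overload to build a call node, as OneArgumentEntity does with GetDef
var x = Expression.Parameter(typeof(double), "x");
var sin = Expression.Lambda<Func<double, double>>(
    Expression.Call(FindMethod("Sin", 1, typeof(double))!, x), x).Compile();
Console.WriteLine(sin(0)); // prints 0
```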

BinaryNodeConverter:

This rule converts a two-argument node to a Linq.Expression.

public static Expression TwoArgumentEntity(Expression left, Expression right, Entity typeHolder)
{
    var typeToCastTo = MaxType(left.Type, right.Type);
    if (left.Type != typeToCastTo)
        left = Expression.Convert(left, typeToCastTo);
    if (right.Type != typeToCastTo)
        right = Expression.Convert(right, typeToCastTo);
    return typeHolder switch
    {
        Sumf => Expression.Add(left, right),
        ...
        Andf => Expression.And(left, right),
        ...
        Lessf => Expression.LessThan(left, right),
        ...
        _ => throw new AngouriBugException("A node seems to be not added")
    };
}

First, there is an upcast. Since the two operands may have different types, we need to find the simplest type to which both of them can be cast. To do this, I assigned a level to every type:

Complex:   10
double:     9
float:      8
long:       8
BigInteger: 8
int:        7

If the types are the same, MaxType returns that type. For example, MaxType(int, int) -> int.

If the level of operand A's type is higher than that of operand B's type, then B is cast to A's type. For example, MaxType(long, double) -> double.

If the levels are equal but the types are not, we take the nearest common "ancestor", that is, a type whose level is higher by 1. For example, MaxType(long, float) -> double.

Operands are cast to the selected type if necessary, and then we simply find the required overload or operator. For example, for Sumf we choose Expression.Add, and conjunction, Andf, is turned into Expression.And.
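The three rules above translate almost directly into code. Here is a sketch of MaxType reimplemented from the description (the actual AngouriMath implementation may differ), followed by the operand casting done before Expression.Add.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Numerics;

// Levels from the table above
var levels = new Dictionary<Type, int>
{
    [typeof(Complex)] = 10,
    [typeof(double)] = 9,
    [typeof(float)] = 8,
    [typeof(long)] = 8,
    [typeof(BigInteger)] = 8,
    [typeof(int)] = 7,
};

// A sketch of the MaxType rules as described in the text
Type MaxType(Type a, Type b)
{
    if (a == b)
        return a;                              // same type: keep it
    if (levels[a] != levels[b])
        return levels[a] > levels[b] ? a : b;  // different levels: take the higher one
    // same level, different types: take a type one level up
    return levels.First(kv => kv.Value == levels[a] + 1).Key;
}

Console.WriteLine(MaxType(typeof(int), typeof(int)).Name);     // prints Int32
Console.WriteLine(MaxType(typeof(long), typeof(double)).Name); // prints Double
Console.WriteLine(MaxType(typeof(long), typeof(float)).Name);  // prints Double

// Casting both operands before Expression.Add, as TwoArgumentEntity does
var l = Expression.Parameter(typeof(long), "l");
var f = Expression.Parameter(typeof(float), "f");
var target = MaxType(l.Type, f.Type);
var sum = Expression.Add(Expression.Convert(l, target), Expression.Convert(f, target));
var add = Expression.Lambda<Func<long, float, double>>(sum, l, f).Compile();
Console.WriteLine(add(2, 0.5f)); // prints 2.5
```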

What happened?

Great, we have defined all the necessary rules for our protocol. Now we can assign these rules to the corresponding properties of the default protocol.

Fancy API

This is the low-level version of what we want to see in the final API:

public TDelegate Compile<TDelegate>(CompilationProtocol protocol, Type returnType, IEnumerable<(Type type, Variable variable)> typesAndNames) where TDelegate : Delegate

It is inconvenient and requires a lot of preparation before it can be called. But we can do two things: pass our default protocol, AND overload this method for delegates of one argument, two, three, and so on. The first is easy: since the default rules are already assigned to the default protocol's properties, we simply create an instance of it. The second is a little more complicated; I solved it by generating the code with a T4 Text Template. Here's an example of the generated code:

// TIn1..TIn3 and TOut specify the input and output types;
// var1..var3 are the variables corresponding to those input types
public Func<TIn1, TIn2, TIn3, TOut> Compile<TIn1, TIn2, TIn3, TOut>(Variable var1, Variable var2, Variable var3)
    // Func<...> is the delegate we want to get, typeof(TOut) is the known output type,
    // and new() creates the protocol with the default rules
    => IntoLinqCompiler.Compile<Func<TIn1, TIn2, TIn3, TOut>>(this, typeof(TOut), new(),
        new[] { (typeof(TIn1), var1), (typeof(TIn2), var2), (typeof(TIn3), var3) });

It's in the source.

T4-template text to generate
<# for (var i = 1; i <= 8; i++) { #>
        public Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut> Compile<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>(Variable var1<# for(var t=2; t<=i; t++){ #>, Variable var<#= t #><# } #>)
            => IntoLinqCompiler.Compile<Func<<# for(var t=1;t<=i;t++){ #>TIn<#= t #>, <# } #>TOut>>(this, typeof(TOut), new(), 
                new[] { (typeof(TIn1), var1)<# for(var t=2;t<=i;t++){ #>, (typeof(TIn<#= t #>), var<#= t #>) <# } #> });
<# } #>

We create extension methods in the same way. Here is an example of generated code:

public static Func<TIn1, TIn2, TOut> Compile<TIn1, TIn2, TOut>(this string @this, Variable var1, Variable var2)
    => IntoLinqCompiler.Compile<Func<TIn1, TIn2, TOut>>(@this, typeof(TOut), new(), 
        new[] { (typeof(TIn1), var1), (typeof(TIn2), var2)  });
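Stripped of T4 and of the parser, the overload pattern can be sketched as follows: a general method that takes the delegate type as a generic argument, and a thin arity-specific wrapper that derives the delegate type, the return type and the parameter types from its own type parameters. CompileGeneral and Compile2 are hypothetical names, and the body is supplied as a lambda because there is no parser in this sketch.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// General form (hypothetical): the caller spells out the delegate type and parameters
static TDelegate CompileGeneral<TDelegate>(
    Type returnType,
    (Type type, string name)[] typesAndNames,
    Func<ParameterExpression[], Expression> buildBody) where TDelegate : Delegate
{
    var parameters = typesAndNames.Select(p => Expression.Parameter(p.type, p.name)).ToArray();
    var body = Expression.Convert(buildBody(parameters), returnType);
    return Expression.Lambda<TDelegate>(body, parameters).Compile();
}

// Arity-specific wrapper (hypothetical), the kind of code a template would emit for
// two inputs: everything else is derived from TIn1, TIn2 and TOut
static Func<TIn1, TIn2, TOut> Compile2<TIn1, TIn2, TOut>(
    string var1, string var2, Func<ParameterExpression[], Expression> buildBody)
    => CompileGeneral<Func<TIn1, TIn2, TOut>>(
        typeof(TOut),
        new[] { (typeof(TIn1), var1), (typeof(TIn2), var2) },
        buildBody);

// x + y with int and long inputs and a double output
var func = Compile2<int, long, double>("x", "y",
    ps => Expression.Add(Expression.Convert(ps[0], typeof(long)), ps[1]));
Console.WriteLine(func(2, 40)); // prints 42
```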

Now we need to measure the performance.

Performance

BenchNormalSimple is a simple lambda, declared right in the code.

BenchMySimple is the same lambda, but compiled by me.

BenchNormalComplicated is a big fat lambda with a bunch of identical subtrees, declared right in the code.

BenchMyComplicated - the same lambda, but compiled by me.

|                 Method |       Mean |    Error |   StdDev |
|----------------------- |-----------:|---------:|---------:|
|      BenchNormalSimple |   189.1 ns |  3.75 ns |  5.83 ns |
|          BenchMySimple |   195.7 ns |  3.92 ns |  5.50 ns |
| BenchNormalComplicated | 1,383.0 ns | 26.82 ns | 35.80 ns |
|     BenchMyComplicated |   293.6 ns |  5.74 ns |  8.77 ns |

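For simple expressions both versions are about equally fast (189.1 ns vs 195.7 ns). Where there are identical subtrees, my compilation is roughly 4.7 times faster (293.6 ns vs 1,383.0 ns), because each repeated subtree is computed once and stored in a local variable. In general, the result is predictable, and there is nothing extraordinary here.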

The benchmark is here.
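A small experiment illustrates where the gain for repeated subtrees comes from. The sketch counts calls to an instrumented sine when sin(x) occurs three times: a naively built tree calls it once per occurrence, while the cached form (a local variable, as in the compiler above) calls it once. countedSin is a hypothetical stand-in for Math.Sin, and the experiment says nothing about the exact benchmark numbers.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical instrumented sine that counts its own calls
int calls = 0;
Func<double, double> countedSin = v => { calls++; return Math.Sin(v); };

var x = Expression.Parameter(typeof(double), "x");
Expression SinX() => Expression.Invoke(Expression.Constant(countedSin), x);

// Naive tree: sin(x) * sin(x) + sin(x), every occurrence evaluated separately
var naive = Expression.Lambda<Func<double, double>>(
    Expression.Add(Expression.Multiply(SinX(), SinX()), SinX()), x).Compile();

// Cached tree: t = sin(x); t * t + t
var t = Expression.Variable(typeof(double), "t");
var cached = Expression.Lambda<Func<double, double>>(
    Expression.Block(new[] { t },
        Expression.Assign(t, SinX()),
        Expression.Add(Expression.Multiply(t, t), t)), x).Compile();

naive(1.0);
Console.WriteLine(calls); // prints 3
calls = 0;
cached(1.0);
Console.WriteLine(calls); // prints 1
```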

Examples of compilation

var func = "sin(x)".Compile<double, double>("x");
Console.WriteLine(func(Math.PI / 2));
>>> 1

var func1 = "a > b".Compile<float, int, bool>("a", "b");
Console.WriteLine(func1(5.4f, 4));
Console.WriteLine(func1(4f, 4));
>>> True
>>> False

var cr = new CompilationProtocol()
{ 
    ConstantConverter = ent => Expression.Constant(ent.ToString()),
    BinaryNodeConverter = (a, b, t) => t switch
    {
        Sumf => Expression.Call(typeof(string)
            .GetMethod("Concat", new[] { typeof(string), typeof(string) }) ?? throw new Exception(), a, b),
        _ => throw new Exception()
    }
};
var func2 = "a + b + c + 1234"
    .Compile<Func<string, string, string, string>>(
        cr, typeof(string), 
        
        new[] { 
            (typeof(string), Var("a")), 
            (typeof(string), Var("b")), 
            (typeof(string), Var("c")) }

        );
Console.WriteLine(func2("White", "Black", "Goose"));
>>> WhiteBlackGoose1234

(The last example shows how users can declare their own protocol instead of using the default one. The result of this compilation is a lambda that concatenates strings.)

Conclusion

  1. Linq.Expression is truly a brilliant thing.

  2. In this short article, we implemented a compilation process for mathematical expressions.

  3. To ensure that the types are arbitrary, we came up with a protocol for translating an expression.

  4. To spare the user from writing the same code over and over again, we offer a number of overloads with a default protocol that works with several built-in types. This protocol automatically upcasts the operands of binary operators and functions to their nearest common type when necessary.

Such functionality can be used wherever you need a fast mathematical function built from a string at runtime. Maybe the community will come up with other useful applications. By the way, I have already used runtime compilation in another project, in which dynamic compilation of nested loops allowed us to avoid recursion (and save precious nanoseconds).

Thank you for your attention! The next article will probably be about symbolic limits or parsing from a string.

References

  1. GitHub of the AngouriMath project, within which I developed the compilation

  2. Compilation's code here

  3. Compilation's tests can be found here
