Pull to refresh

Вычисление производных с помощью шаблонов на С++

Reading time3 min
Views50K
Навеяно постом. Попутно получилось что-то похожее на собственную реализацию лямбда-выражений :) С возможностью вычисления производной еще на этапе компиляции. Для задания функции можно использовать операторы +, -, *, /, а также ряд стандартных математических функций.
Sqr — возведение в квадрат
Sqrt — квадратный корень
Pow — возведение в действительную степень
Exp — показательная функция
Log — логарифм
Sin, Cos, Tg, Ctg, Asin, Acos, Atg, Actg — тригонометрия

Производная вычисляется с помощью функции derivative. На входе у нее — функтор, на выходе — тоже. Для того, чтобы производная была вычислена точно, на вход должен подаваться функтор, заданный с помощью особого синтаксиса. Синтаксис — интуитивно понятный (по крайней мере, я на это надеюсь). Если на вход derivative подать любой другой функтор или лямбду с подходящей сигнатурой (double -> double), то производная будет вычислена приближенно.
Пример:
#include <iostream>
#include "CrazyMath.h"

using namespace std;
using namespace CrazyMath;

auto global = Tg(X) + Ctg(X) + Asin(X) * Acos(X) - Atg(X) / Actg(X);
auto d_global = derivative(global);

int main()
{
	auto f1 = (Pow(X, 3) + 2 * Sqr(X) - 4 * X + 1 / Sqrt(1 - Sqr(X))) * (Sin(X) + Cos(X) * (Log(5, X) - Exp(2, X)));
	auto f2 = derivative(f1) * Sqrt(X - Tg(X / 4));
	auto f3 = [](double x) -> double { return sin(x); };
	auto df1 = derivative(f1);
	auto df2 = derivative(f2);
	auto df3 = derivative(f3);
	
	cout << "f(x)\t\tf'(x)" << endl;
	cout << f1(0.5) << " \t" << df1(0.5) << endl;
	cout << f2(0.5) << " \t" << df2(0.5) << endl;
	cout << f3(0) << " \t" << df3(0) << endl;
	cout << global(0.5) << " \t" << d_global(0.5) << endl;
	
	char temp[4];
	cout << "\nPress ENTER to exit..." << endl;
	cin.getline(temp, 3);
	return 0;
}

Работает это так:
//  CrazyMath.h, отрывок

//---------------------------------------------------
// base functions

class Const {
public:
	typedef Const Type;
	Const(double x)	: m_const(x) {}
	double operator()(double) {}
private:
	double m_const;
};

class Simple {
public:
	typedef Simple Type;
	double operator()(double x)
	{
		return x;
	}
};

template <class F1, class F2>
class Add {
public:
	typedef typename Add<F1, F2> Type;
	Add(const F1& f1, const F2& f2)	: m_f1(f1), m_f2(f2) {}
	double operator()(double x)
	{
		return m_f1(x) + m_f2(x);
	}
	F1 m_f1;
	F2 m_f2;
};

//---------------------------------------------------
// helpers

template <class F1, class F2>
Add<F1, F2> operator+(const F1& f1, const F2& f2)
{
	return Add<F1, F2>(f1, f2);
}

template <class F>
Add<F, Const> operator+(double value, const F& f)
{
	return Add<F, Const>(f, Const(value));
}

template <class F>
Add<F, Const> operator+(const F& f, double value)
{
	return Add<F, Const>(f, Const(value));
}

// other helpers ...

//---------------------------------------------------
// derivatives

template <class F>
class Derivative {
public:
	Derivative(const F& f, double dx = 1e-3) : m_f(f), m_dx(dx) {}
	double operator()(double x)
	{
		return (m_f(x + m_dx) - m_f(x)) / m_dx;
	}
	F m_f;
	double m_dx;
	typedef std::function<double (double)> Type;
	Type expression()
	{
		return [this](double x) -> double
		{
			return (m_f(x + m_dx) - m_f(x)) / m_dx;
		};
	}
};

template<>
class Derivative<Const> {
public:
	typedef Const Type;
	Derivative<Const> (Const) {}
	double operator()(double)
	{
		return 0;
	}
	Type expression()
	{
		return Const(0);
	}
};

template<>
class Derivative<Simple> {
public:
	typedef Const Type;
	Derivative<Simple> (Simple) {}
	double operator()(double)
	{
		return 1;
	}
	Type expression()
	{
		return Const(1);
	}
};

template <class F1, class F2>
class Derivative< Add<F1, F2> > {
public:
	Derivative< Add<F1, F2> > (const Add<F1, F2>& f)
		: m_df1(f.m_f1), m_df2(f.m_f2)
	{
	}
	double operator()(double x)
	{
		return m_df1(x) + m_df2(x);
	}
	Derivative<F1> m_df1;
	Derivative<F2> m_df2;
	typedef typename Add<typename Derivative<F1>::Type, typename Derivative<F2>::Type> Type;
	Type expression()
	{
		return m_df1.expression() + m_df2.expression();
	}
};

// other derivatives ...

template <class F>
typename Derivative<F>::Type derivative(F f)
{
	return Derivative<F>(f).expression();
}

extern Simple X;



Файл CrazyMath.h получился достаточно большим, поэтому полностью включать его в статью смысла нет. Те, кому интересно, могут скачать исходники с Github'а

UPD Добавил в класс Derivative метод expression и typedef соответствующий типу, который возвращает метод expression. Метод expression возвращает функтор, пригодный для дальнейшего дифференцирования. Однако при вычислении производных 2-го, 3-го и более высоких порядков размер выражения быстро растет, поэтому компиляция может затянуться.
Tags:
Hubs:
Total votes 35: ↑34 and ↓1+33
Comments22

Articles