Pull to refresh

Учебный фреймворк на Java по глубокому обучению

Reading time 2 min
Views 6.8K

Недавно мы выпустили первую версию нового фреймворка по глубокому обучению DeepJava (DJ) 0.01.


Основная цель фреймворка, по крайней мере, на текущий момент, чисто учебная. Мы строим шаг за шагом фреймворк, у которого:


  • будет понятная кодовая база
  • будет набор бранчей, по которым можно шаг за шагом проследить процесс создания и понять, почему были сделаны те или иные изменения

Вместе с нашим первым релизом мы так же выпустили первую главу открытой книги по глубокому обучению. Книга пишется для Java инженеров, которые ранее не занимались нейронными сетями. При этом процесс обучения строится вокруг создания своего фреймворка с нуля:



При создании нашего учебного фреймворка мы будим вводить новые понятия и сущности только там и тогда, где это действительно необходимо. Например, в первом релизе представление сети сделано так, как инстинктивно большинство инженеров захочет его сделать, в виде графа (а не набора тензоров). Это позволяет создавать более гибкие сети. Поскольку у нас уже есть код, который тренирует модель MNIst, мы можем увидеть, насколько медленно работает подобное представление сети. Теперь, уткнувшись в эту проблему, в дальнейших главах мы познакомим читателя с основами линейной алгебры в том объеме который необходим, чтобы решить ровно эту проблему. И т.д. мы планируем вводить сущности там, где это необходимо, по мере появления проблем до тех пор, пока мы не посмотрим фреймворк.


Несколько небольших плюшек:



PS


Если кто видел наше видео "нейронные сети за 30 минут", то вот небольшой пример кода как воссоздать сеть из видео на DJ:


var context = new Context(
         /* learningRate */ 0.2, 
         /* debug mode */ false);

var inputFriend = new InputNeuron("friend");
var inputVodka = new InputNeuron("vodka");
var inputSunny = new InputNeuron("sunny");

var outputNeuron
        = new ConnectedNeuron.Builder()
            .bias(0.1)
            .activationFunction(new Sigmoid())
            .context(context)
            .build();

inputFriend.connect(outputNeuron, wFriend);
inputVodka.connect(outputNeuron, wVodka);
inputSunny.connect(outputNeuron, wSunny);

// Посылаем входные сигналы:
inputFriend.forwardSignalReceived(null, 1.);
inputVodka.forwardSignalReceived(null, 1.);
inputSunny.forwardSignalReceived(null, 1.);

// Получаем итоговый результат и считаем ошибку:
double result = outputNeuron.getForwardResult();
double expectedResult = 1.;
double errorDy = 2. * (expectedResult - result);

// Посылаем обратно ошибку:
outputNeuron.backwardSignalReceived(errorDy);
Tags:
Hubs:
+8
Comments 3
Comments Comments 3

Articles