Представим следующую достаточно банальную ситуацию: вы студент регионального университета в аграрном регионе, который учится на DS/ML и перед вами поставили достаточно банальную задачу — отличать больное растение от здорового.

Казалось бы, все достаточно просто — с помощью CV обрабатываем фотку + нейросеть для классификации. Если задача обработки изображения достаточно тривиальная (есть много готовых решений и тд), то есть она типичная, то всего сложнее для нас будет подобрать гиперпараметры чтобы нейросеть выдавала нужные нам значения.

Существует достаточно большое количество методов для автоматического подбора гиперпараметров, но у меня своя история с эволюционными методами, так что я сегодня выбрал PBT или Population Based Training или подбор на основе популяции

Что это вообще такое и для чего это надо?

Как я уже описал выше перед нами стоит задача подбора гиперпараметров, так что в этом блоке разберем что такое гиперпараметры и чем они отличаются от обычных параметров.
В качестве достаточно наглядного примера возьмем уравнение линейной регрессии

y = ax + b

Параметры это то, что модель меняет сама, в процессе обучения, в случае регрессии а и b это параметры.

import numpy as np
from sklearn.linear_model import LinearRegression

X = np.array([[1], [2], [3], [4], [5]])  
y = np.array([2, 4, 5, 4, 5])            

model = LinearRegression()
model.fit(X, y)

print(f"Коэффициент наклона: {model.coef_[0]}")
print(f"Свободный член: {model.intercept_}")
print(f"Предсказание для X=6: {model.predict([[6]])[0]}")

Выше представлен код обычной линейной регрессии, модель сама подберет параметры a и b.
Ну с параметрами разобрались, теперь перейдем к гиперпараметрам.

import numpy as np
import tensorflow as tf

X = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y = np.array([[2], [4], [5], [4], [5]], dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, activation='linear', input_shape=(1,))
])

model.compile(optimizer='sgd', loss='mse')

model.fit(X, y, epochs=500, verbose=0)  # verbose=0 скрывает вывод

weights = model.layers[0].get_weights()
print(f"Коэффициент наклона: {weights[0][0][0]:.4f}")
print(f"Свободный член: {weights[1][0]:.4f}")
print(f"Предсказание для X=6: {model.predict(np.array([[6]]), verbose=0)[0][0]:.4f}")

Гиперпараметры это то что мы настраиваем сами для получения желаемого результата, в случае нейронной сети мы как минимум сами подбираем количество слоев, эпохи, функцию активации и тд. Как вы понимаете для подбирать для каждый раз все гиперпараметры достаточно сложно, муторно и долго по времени, так что умные люди придумали методы чтобы делать это автоматически.

что есть PBT ?

Population Based Training - это асинхронный метод подбора гиперпараметров, работающий по расписанию.
Для понимания метод достаточно не сложный, чуть отойдем от нашей реальной задачи скажем так в абстракцию.

Представим что у нас есть задача научить 10 обезьянок собирать кокосы. Каждой из них мы дали свои параметры по типу "бегай быстро", "смотри вверх" и тд, после чего на 5 минут отправили их в джунгли. По истечению этого времени мы посмотрели кто из обезьян собрал больше всего и их параметры скопировали на тех кто собрал меньше всего и добавили им свои еще какие-то параметры. Таким образом мы каждый раз будем производить селекцию пока не дойдем до нужного нам результата в сборке кокосов.
Ну на обезьянках разобрали, а теперь формализуем метод с помощью математики.

Математика PBT

Небольшие обозначения
P={1,2,…,N}- множество индексов популяции;
Θ⊆R^d- пространство параметров модели (веса нейросети);
Λ⊆R^m- пространство гиперпараметров;
T∈N- количество поколений;
X⊆R^d - пространство признаков;
Y={0,1,…,C - 1}- множество классов (для многоклассовой классификации)

Состояние популяции в момент времени t S_{t} = {(\theta_{t}^{i}, \lambda_{t}^{i})}{i=1}^{N}, \quad \theta{t}^{i} \in \Theta, \quad \lambda_{t}^{i} \in \Lambda

Инициация

(t = 0) \theta_{0}^{i} \sim p_{\Theta}(\theta), \quad \lambda_{0}^{i} \sim p_{\Lambda}(\lambda), \quad \forall i \in \mathcal{P}, где где pΘ и pΛ​ - распределения для инициализации.

Оператор обучения

Для задачи классификации функция потерь - кросс-энтропия:

Для бинарной классификации:


L_{\text{CE}}(\theta) = -\frac{1}{n} \sum_{j=1}^{n} \left[ y_j \log(\sigma(\theta^T x_j)) + (1 - y_j) \log(1 - \sigma(\theta^T x_j)) \right]


Для многоклассовой классификации:

L_{\text{CE}}(\theta) = -\frac{1}{n} \sum_{j=1}^{n} \sum_{c=1}^{C} y_{jc} \log\left( \frac{e^{\theta_c^T x_j}}{\sum_{k=1}^{C} e^{\theta_k^T x_j}} \right)

Оператор одного шага градиентного спуска:

A_t: \Theta \times \Lambda \to \Theta

\tilde{\theta}{t}^{i} = A_t(\theta{t-1}^{i}, \lambda_{t-1}^{i}) = \theta_{t-1}^{i} - \eta_{t}^{i} \nabla L_{\text{CE}}(\theta_{t-1}^{i}, \mathcal{D}_{\text{train}}), где \eta_{t}^{i}- скорость обучения (компонента вектора \lambda_{t-1}^{i}), а D_{\text{train}} - обучающая выборка

Функция пригодности (fitness) для классификации

M: \Theta \to [0,1]

M(\theta) = \frac{1}{|\mathcal{D}_{\text{val}}|} \sum_{(x_j, y_j) \in \mathcal{D}_{\text{val}}} \mathbb{I}[f_{\theta}(x_j) = y_j]


\mathcal{D}_{\text{val}}- валидационная выборка
f_{\theta}(x) = \arg\max_{c \in \mathcal{Y}} p_{\theta}(c \mid x)- предсказанный класс
I[⋅]- индикаторная функция (1 если условие истинно, 0 иначе)

Для бинарной классификации с порогом 0.5:

f_{\theta}(x) = \begin{cases} 1, & \text{если } \sigma(\theta^T x) > 0.5 \\ 0, & \text{иначе} \end{cases}

Сортировка (ранжирование по качеству)

Для классификации сортируем по убыванию accuracy (больше - лучше):
Определим биекцию \sigma_t: \{1, \ldots, N\} \to \{1, \ldots, N\}- перестановку индексов, такую что:
M(\tilde{\theta}{t}^{\sigma_t(1)}) \geq M(\tilde{\theta}{t}^{\sigma_t(2)}) \geq \cdots \geq M(\tilde{\theta}_{t}^{\sigma_t(N)})
В случае равенства может использоваться дополнительный критерий (например, порядковый номер).
Обозначим:

\mathcal{E}_t = \{\sigma_t(1), \ldots, \sigma_t(K)\}- элита
W_t = \{\sigma_t(K+1), \ldots, \sigma_t(N)\}- худшие , где K=⌊αN⌋, α∈(0,1)α∈(0,1) - доля элиты

Элитизм

Для всех i \in \mathcal{E}_t: \theta_{t}^{i} = \tilde{\theta}_{t}^{i}
\lambda_{t}^{i} = \lambda_{t-1}^{i}
Элите переходит в следующее поколение без изменений.

Оператор селекции (Exploit)

Для каждого j \in W_tвыбираем донора d(j)  случайным образом из элиты: d(j) \sim \text{Uniform}(\mathcal{E}_t)
Копируем параметры (веса) от донора: \hat{\theta}{t}^{j} = \tilde{\theta}{t}^{d(j)}

Оператор мутации (Explore)

Гиперпараметры мутируют, чтобы внести разнообразие

Вариант 1: Адаптивное возмущение

\lambda_{t}^{j} = \lambda_{t-1}^{d(j)} + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \Sigma)

Вариант 2: Мультипликативное возмущение
\lambda_{t}^{j} = \lambda_{t-1}^{d(j)} \cdot \exp(\varepsilon), \quad \varepsilon \sim N(0, \sigma^2)

Вариант 3: Перезагрузка
\lambda_{t}^{j} \sim p_{\Lambda}(\lambda)

Итоговое правило обновления
Для каждого i∈P:

(\theta_{t}^{i}, \lambda_{t}^{i}) = \begin{cases} (\tilde{\theta}{t}^{i}, \lambda{t-1}^{i}), & \text{если } i \in \mathcal{E}t \text{ (элитизм)} \ (\tilde{\theta}{t}^{d(i)}, \\lambda_{t-1}^{d(i)} + \varepsilon_i), & \text{если } i \in \mathcal{W}_t \text{ (эксплуатация + мутация)} \end{cases}

где d(i) \sim \text{Uniform}(\mathcal{E}_t), \quad \varepsilon_i - независимые реализации шума (может быть аддитивным, мультипликативным или заменой)

Скрытый текст

Для понимания того как это все работает иногда хочется потыкаться в это и посмотреть на числах, так что если вам тоже такое интересно, то прошу

Исходные данные

Θ⊆R^2- веса модели (2 параметра)
Λ⊆R  - скорость обучения
N=3 - размер популяции
K=1- размер элиты
T = 2- поколения
Задача - бинарная классификация (0/1)

Поколение 0 (инициализация)

\theta_{0}^{1} = \begin{bmatrix} 1.0 \\ 0.5 \end{bmatrix}, \quad \lambda_{0}^{1} = 0.1\theta_{0}^{2} = \begin{bmatrix} 0.5 \\ 1.0 \end{bmatrix}, \quad \lambda_{0}^{2} = 0.01\theta_{0}^{3} = \begin{bmatrix} 0.0 \\ 0.0 \end{bmatrix}, \quad \lambda_{0}^{3} = 0.5

Поколение 1 (t =1)

тренировка \tilde{\theta}{1}^{i} = A{1}(\theta_{0}^{i}, \lambda_{0}^{i})Для классификации используем кросс-энтропию:
L(\theta) = -\frac{1}{n} \sum_{j=1}^{n} \left[ y_j \log(\hat{y}_j) + (1 - y_j) \log(1 - \hat{y}_j) \right], где

 \hat{y}_j = \sigma(\theta^T x_j), \quad \sigma(z) = \frac{1}{1 + e^{-z}}

Пусть градиент для каждой модели:

\nabla L(\theta_{0}^{1}) = \begin{bmatrix} 0.3 \\ -0.2 \end{bmatrix}, \quad \nabla L(\theta_{0}^{2}) = \begin{bmatrix} -0.4 \\ 0.2 \end{bmatrix}, \quad \nabla L(\theta_{0}^{3}) = \begin{bmatrix} 0.8 \\ 0.5 \end{bmatrix}

Тогда:

\tilde{\theta}_{1}^{1} = \begin{bmatrix} 1.0 \\ 0.5 \end{bmatrix} - 0.1 \cdot \begin{bmatrix} 0.3 \\ -0.2 \end{bmatrix} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix}\tilde{\theta}_{1}^{2} = \begin{bmatrix} 0.5 \\ 1.0 \end{bmatrix} - 0.01 \cdot \begin{bmatrix} -0.4 \\ 0.2 \end{bmatrix} = \begin{bmatrix} 0.504 \\ 0.998 \end{bmatrix}\tilde{\theta}_{1}^{3} = \begin{bmatrix} 0.0 \\ 0.0 \end{bmatrix} - 0.5 \cdot \begin{bmatrix} 0.8 \\ 0.5 \end{bmatrix} = \begin{bmatrix} -0.4 \\ -0.25 \end{bmatrix}

Оценка качества m_i = M(\tilde{\theta}_{1}^{i})

Для классификации M - accuracy на валидационной выборке из 5 примеров:
Валидационные данные:

X_{\text{val}} =  \begin{bmatrix} 1.0 & 0.2 \\ 0.5 & 1.5 \\ -0.5 & -0.3 \\ 0.8 & 0.8 \\ 0.1 & 1.0 \end{bmatrix}, \quad y_{\text{val}} =  \begin{bmatrix} 1 \\ 0 \\ 0 \\ 1 \\ 0 \end{bmatrix}

Считаем предсказания: \hat{y} = \sigma(\theta^T x), порог 0.5
Для модели 1 (θ=[0.97,0.52]θ=[0.97,0.52]):

\theta^T x_1 = 0.97 \cdot 1.0 + 0.52 \cdot 0.2 = 1.074, \quad \sigma = 0.745 \rightarrow \hat{y} = 1 \theta^T x_2 = 0.97 \cdot 0.5 + 0.52 \cdot 1.5 = 1.265, \quad \sigma = 0.780 \rightarrow \hat{y} = 1\theta^T x_3 = 0.97 \cdot (-0.5) + 0.52 \cdot (-0.3) = -0.641, \quad \sigma = 0.345 \rightarrow \hat{y} = 0\theta^T x_4 = 0.97 \cdot 0.8 + 0.52 \cdot 0.8 = 1.192, \quad \sigma = 0.767 \rightarrow \hat{y} = 1\theta^T x_5 = 0.97 \cdot 0.1 + 0.52 \cdot 1.0 = 0.617, \quad \sigma = 0.649 \rightarrow \hat{y} = 1

Совпало 3 из 5: m_1​=0.6

Для модели 2 (θ=[0.504,0.998]θ=[0.504,0.998]):

\theta^T x_1 = 0.504 \cdot 1.0 + 0.998 \cdot 0.2 = 0.7036, \quad \sigma = 0.669 \rightarrow \hat{y} = 1\theta^T x_2 = 0.504 \cdot 0.5 + 0.998 \cdot 1.5 = 1.749, \quad \sigma = 0.852 \rightarrow \hat{y} = 1\theta^T x_3 = 0.504 \cdot (-0.5) + 0.998 \cdot (-0.3) = -0.5514, \quad \sigma = 0.366 \rightarrow \hat{y} = 0\theta^T x_4 = 0.504 \cdot 0.8 + 0.998 \cdot 0.8 = 1.2016, \quad \sigma = 0.769 \rightarrow \hat{y} = 1\theta^T x_5 = 0.504 \cdot 0.1 + 0.998 \cdot 1.0 = 1.0484, \quad \sigma = 0.740 \rightarrow \hat{y} = 1

Совпало 3 из 5: m_2 ​=0.6

Для модели 3 (θ=[−0.4,−0.25]θ=[−0.4,−0.25]):

\theta^T x_1 = -0.4 \cdot 1.0 + (-0.25) \cdot 0.2 = -0.45, \quad \sigma = 0.389 \rightarrow \hat{y} = 0\theta^T x_2 = -0.4 \cdot 0.5 + (-0.25) \cdot 1.5 = -0.575, \quad \sigma = 0.360 \rightarrow \hat{y} = 0\theta^T x_3 = -0.4 \cdot (-0.5) + (-0.25) \cdot (-0.3) = 0.275, \quad \sigma = 0.568 \rightarrow \hat{y} = 1\theta^T x_4 = -0.4 \cdot 0.8 + (-0.25) \cdot 0.8 = -0.52, \quad \sigma = 0.373 \rightarrow \hat{y} = 0\theta^T x_5 = -0.4 \cdot 0.1 + (-0.25) \cdot 1.0 = -0.29, \quad \sigma = 0.428 \rightarrow \hat{y} = 0

Совпало 2 из 5 m_3 = 0.4

Сортировка σ1​

Для классификации

m_1 = 0.6, \quad m_2 = 0.6, \quad m_3 = 0.4

При равенстве берем по индексу:

\sigma_1(1) = 1 \text{ (лучший)}, \quad \sigma_1(2) = 2, \quad \sigma_1(3) = 3

Элита: \mathcal{E}_1 = {1}
Худшие: \mathcal{W}_1 = {2, 3}

Элитизм для i = 1

\theta_{1}^{1} = \tilde{\theta}_{1}^{1} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix}\lambda_{1}^{1} = \lambda_{0}^{1} = 0.1

Эксплуатация и мутация для худших

для i = 2:
донор d = 1
Копирование параметров:

\theta_{1}^{2} = \tilde{\theta}_{1}^{1} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix}

Мутация гиперпараметра (мультипликативная):
Пусть \varepsilon_2 = -0.693

λ  1 2 ​  =λ  0 1 ​  ⋅exp(ε  2 ​  )=0.1⋅0.5=0.05

Для i = 3
Донор d = 1

\theta_{1}^{3} = \tilde{\theta}_{1}^{1} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix}

\varepsilon_3 = 0.693

\lambda_{1}^{3} = 0.1 \cdot 2 = 0.2

Состояние после поколения 1

S_1 = \{ (\theta_{1}^{1}, 0.1), (\theta_{1}^{2}, 0.05), (\theta_{1}^{3}, 0.2) \}\theta_{1}^{1} = \theta_{1}^{2} = \theta_{1}^{3} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix}

Поколение 2 (t=2)
\tilde{\theta}{2}^{i} = A{2}(\theta_{1}^{i}, \lambda_{1}^{i})
Пусть на этом шаге градиент для всех одинаков (так как веса одинаковые):

\nabla L(\theta_{1}^{i}) = \begin{bmatrix} 0.2 \\ 0.1 \end{bmatrix}, \quad \forall i

Тогда:

\tilde{\theta}_{2}^{1} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix} - 0.1 \cdot \begin{bmatrix} 0.2 \\ 0.1 \end{bmatrix} = \begin{bmatrix} 0.95 \\ 0.51 \end{bmatrix}\tilde{\theta}_{2}^{2} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix} - 0.05 \cdot \begin{bmatrix} 0.2 \\ 0.1 \end{bmatrix} = \begin{bmatrix} 0.96 \\ 0.515 \end{bmatrix}\tilde{\theta}_{2}^{3} = \begin{bmatrix} 0.97 \\ 0.52 \end{bmatrix} - 0.2 \cdot \begin{bmatrix} 0.2 \\ 0.1 \end{bmatrix} = \begin{bmatrix} 0.93 \\ 0.50 \end{bmatrix}

Оценка accuracy на тех же валидационных данных

Для модели 1 (θ=[0.95,0.51]θ=[0.95,0.51]):

\theta^T x_1 = 0.95 \cdot 1.0 + 0.51 \cdot 0.2 = 1.052, \quad \sigma = 0.741 \rightarrow 1\theta^T x_2 = 0.95 \cdot 0.5 + 0.51 \cdot 1.5 = 1.24, \quad \sigma = 0.776 \rightarrow 1\theta^T x_3 = 0.95 \cdot (-0.5) + 0.51 \cdot (-0.3) = -0.628, \quad \sigma = 0.348 \rightarrow 0\theta^T x_4 = 0.95 \cdot 0.8 + 0.51 \cdot 0.8 = 1.168, \quad \sigma = 0.763 \rightarrow 1\theta^T x_5 = 0.95 \cdot 0.1 + 0.51 \cdot 1.0 = 0.605, \quad \sigma = 0.647 \rightarrow 1

Совпало 3 из 5: m_1 = 0.6

Для модели 2 (θ=[0.96,0.515]θ=[0.96,0.515]):

\theta^T x_1 = 0.96 \cdot 1.0 + 0.515 \cdot 0.2 = 1.063, \quad \sigma = 0.743 \rightarrow 1\theta^T x_2 = 0.96 \cdot 0.5 + 0.515 \cdot 1.5 = 1.2525, \quad \sigma = 0.778 \rightarrow 1\theta^T x_3 = 0.96 \cdot (-0.5) + 0.515 \cdot (-0.3) = -0.6345, \quad \sigma = 0.346 \rightarrow 0\theta^T x_4 = 0.96 \cdot 0.8 + 0.515 \cdot 0.8 = 1.18, \quad \sigma = 0.765 \rightarrow 1\theta^T x_5 = 0.96 \cdot 0.1 + 0.515 \cdot 1.0 = 0.611, \quad \sigma = 0.648 \rightarrow 1

Совпало 3 из 5: m_2 = 0.6

Для модели 3 (θ=[0.93,0.50]θ=[0.93,0.50]):

\theta^T x_1 = 0.93 \cdot 1.0 + 0.50 \cdot 0.2 = 1.03, \quad \sigma = 0.737 \rightarrow 1\theta^T x_2 = 0.93 \cdot 0.5 + 0.50 \cdot 1.5 = 1.215, \quad \sigma = 0.771 \rightarrow 1\theta^T x_3 = 0.93 \cdot (-0.5) + 0.50 \cdot (-0.3) = -0.615, \quad \sigma = 0.351 \rightarrow 0\theta^T x_4 = 0.93 \cdot 0.8 + 0.50 \cdot 0.8 = 1.144, \quad \sigma = 0.758 \rightarrow 1\theta^T x_5 = 0.93 \cdot 0.1 + 0.50 \cdot 1.0 = 0.593, \quad \sigma = 0.644 \rightarrow 1

Совпало 3 и 5: m_3 = 0.6

у всех accuracy = 0,6 При равных значениях оставляем порядок без изменений:
\sigma_2(1) = 1, \quad \sigma_2(2) = 2, \quad \sigma_2(3) = 3
Элита: \mathcal{E}_2 = {1}
Худшие: \mathcal{W}_2 = {2, 3}

Элитизм для i = 1

\theta_{2}^{1} = \tilde{\theta}_{2}^{1} = \begin{bmatrix} 0.95 \\ 0.51 \end{bmatrix}\lambda_{2}^{1} = \lambda_{1}^{1} = 0.1

Мутация: \varepsilon_2 = -0.357 \lambda_{2}^{2} = 0.1 \cdot 0.7 = 0.07
для i = 3
донор d = 1

\theta_{2}^{3} = \tilde{\theta}_{2}^{1} = \begin{bmatrix} 0.95 \\ 0.51 \end{bmatrix}

Мутация \varepsilon_3 = 0.262

\lambda_{2}^{3} = 0.1 \cdot 1.3 = 0.13

Финальное состояние

S_2 = \{ (\theta_{2}^{1}, 0.1), (\theta_{2}^{2}, 0.07), (\theta_{2}^{3}, 0.13) \}\theta_{2}^{1} = \theta_{2}^{2} = \theta_{2}^{3} = \begin{bmatrix} 0.95 \\ 0.51 \end{bmatrix}

С математикой мы разобрались, так что перейдем к реализации метода на коде

P.S. так как задача по растениям была абстрактной, так что данных у меня нет - взял вместо MNIST


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import random
from collections import defaultdict

# ========== 1. КОНФИГУРАЦИЯ ==========
class Config:
    # Популяция
    population_size = 4  # 4 воркера
    
    # Данные - MNIST!
    dataset = "MNIST"
    batch_size = 64
    
    # Обучение
    num_epochs = 5  # 5 эпох
    ready_interval = 100  # проверка PBT каждые 100 шагов
    
    # Гиперпараметры
    lr_min = 1e-4
    lr_max = 1e-2
    momentum_min = 0.5
    momentum_max = 0.99
    
    # Мутация
    mutate_prob = 0.5
    mutate_factor_low = 0.7
    mutate_factor_high = 1.5
    
    # Технические
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    seed = 42
    checkpoint_dir = "./pbt_mnist_demo"
    
    # Для печати
    verbose = True


config = Config()
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Устанавливаем seed
torch.manual_seed(config.seed)
np.random.seed(config.seed)
random.seed(config.seed)

print("="*60)
print("PBT ДЕМОНСТРАЦИЯ НА MNIST")
print("="*60)
print(f"Устройство: {config.device}")
print(f"Размер популяции: {config.population_size}")
print(f"Эпох: {config.num_epochs}")
print(f"Данные: {config.dataset}")
print("="*60)


# ========== 2. ЗАГРУЗКА MNIST ==========
def load_mnist(config):
    """Загружает MNIST"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Загружаем
    train_dataset = datasets.MNIST(
        './data', train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        './data', train=False, download=True, transform=transform
    )
    
    # Берем подмножество для скорости
    train_dataset = torch.utils.data.Subset(train_dataset, range(5000))
    test_dataset = torch.utils.data.Subset(test_dataset, range(1000))
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.batch_size, 
        shuffle=True,
        num_workers=0
    )
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=config.batch_size, 
        shuffle=False,
        num_workers=0
    )
    
    print(f"MNIST загружен:")
    print(f"  Train: {len(train_dataset)} изображений")
    print(f"  Test: {len(test_dataset)} изображений")
    print(f"  Размер: 28x28, классов: 10")
    
    return train_loader, test_loader


# ========== 3. МОДЕЛЬ ДЛЯ MNIST (ИСПРАВЛЕННАЯ) ==========
class MNISTNet(nn.Module):
    """
    Исправленная модель для MNIST
    Правильно считаем размеры после сверток
    """
    def __init__(self, dropout_rate=0.3):
        super().__init__()
        
        # Сверточная часть
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 28x28 -> 28x28
        self.pool = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14
        
        # После двух сверток и пулинга: 64 канала * 7 * 7 = 3136
        # Потому что: 28 -> pool -> 14 -> conv -> 14 -> pool -> 7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        
        self.dropout1 = nn.Dropout2d(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        # Первый блок
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.pool(x)  # 28 -> 14
        
        # Второй блок
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.pool(x)  # 14 -> 7
        
        x = self.dropout1(x)
        
        # Переход к полносвязным слоям
        x = torch.flatten(x, 1)  # [batch, 64*7*7]
        
        # Полносвязные слои
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        
        return x


# ========== 4. ОЦЕНКА МОДЕЛИ ==========
@torch.no_grad()
def evaluate(model, test_loader, device, max_batches=10):
    """Оценка модели"""
    model.eval()
    correct = 0
    total = 0
    
    for i, (data, target) in enumerate(test_loader):
        if i >= max_batches:
            break
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += target.size(0)
    
    return correct / total if total > 0 else 0.0


# ========== 5. ХРАНИЛИЩЕ ==========
class Store:
    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        self.stats = defaultdict(list)
        
    def save(self, worker_id, model, optimizer, step, lr, momentum, accuracy):
        path = f"{self.checkpoint_dir}/worker_{worker_id}.pth"
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'step': step,
            'lr': lr,
            'momentum': momentum,
            'accuracy': accuracy
        }, path)
        
        self.stats[worker_id].append({
            'step': step,
            'accuracy': accuracy,
            'lr': lr,
            'momentum': momentum
        })
    
    def load(self, worker_id, device):
        path = f"{self.checkpoint_dir}/worker_{worker_id}.pth"
        if os.path.exists(path):
            return torch.load(path, map_location=device)
        return None
    
    def get_all_accuracies(self):
        accs = {}
        for worker_id in self.stats:
            if self.stats[worker_id]:
                accs[worker_id] = self.stats[worker_id][-1]['accuracy']
        return accs


# ========== 6. EXPLOIT (ОТБОР УСЕЧЕНИЕМ) ==========
def truncation_selection(worker_id, current_acc, all_accs, top_ratio=0.25):
    """Отбор усечением"""
    if len(all_accs) < config.population_size:
        return None
    
    sorted_workers = sorted(all_accs.items(), key=lambda x: x[1])
    n_top = max(1, int(len(sorted_workers) * top_ratio))
    
    top_workers = [w for w, acc in sorted_workers[-n_top:]]
    bottom_workers = [w for w, acc in sorted_workers[:n_top]]
    
    if worker_id in bottom_workers:
        donor_id = random.choice(top_workers)
        return donor_id
    
    return None


# ========== 7. EXPLORE (МУТАЦИЯ) ==========
def mutate(lr, momentum):
    """Мутация гиперпараметров"""
    new_lr = lr
    new_momentum = momentum
    
    if random.random() < config.mutate_prob:
        factor = random.choice([config.mutate_factor_low, config.mutate_factor_high])
        new_lr = lr * factor
        new_lr = max(config.lr_min, min(config.lr_max, new_lr))
    
    if random.random() < config.mutate_prob:
        factor = random.choice([config.mutate_factor_low, config.mutate_factor_high])
        new_momentum = momentum * factor
        new_momentum = max(config.momentum_min, min(config.momentum_max, new_momentum))
    
    return new_lr, new_momentum


# ========== 8. ОБУЧЕНИЕ ВОРКЕРА ==========
def train_worker(worker_id, train_loader, test_loader, store, config):
    """Обучение одного воркера с PBT"""
    device = config.device
    
    # Инициализация гиперпараметров
    lr = 10 ** np.random.uniform(np.log10(config.lr_min), np.log10(config.lr_max))
    momentum = np.random.uniform(config.momentum_min, config.momentum_max)
    
    print(f"\n--- Worker {worker_id} ---")
    print(f"  Начальные: LR={lr:.6f}, Momentum={momentum:.3f}")
    
    # Модель и оптимизатор
    model = MNISTNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    
    step = 0
    best_acc = 0.0
    
    for epoch in range(config.num_epochs):
        model.train()
        epoch_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            # Обучение
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            step += 1
            
            # PBT - проверка
            if step % config.ready_interval == 0:
                # Оценка
                acc = evaluate(model, test_loader, device)
                
                if acc > best_acc:
                    best_acc = acc
                
                print(f"  Worker {worker_id}, эпоха {epoch+1}, шаг {step}, acc={acc:.3f}")
                
                # Сохраняем
                store.save(worker_id, model, optimizer, step, lr, momentum, acc)
                
                # Получаем все accuracy
                all_accs = store.get_all_accuracies()
                
                # EXPLOIT
                if len(all_accs) >= config.population_size:
                    donor_id = truncation_selection(worker_id, acc, all_accs)
                    
                    if donor_id is not None and donor_id != worker_id:
                        donor_acc = all_accs.get(donor_id, 0)
                        print(f"  >> Worker {worker_id} (acc={acc:.3f}) копирует Worker {donor_id} (acc={donor_acc:.3f})")
                        
                        # Загружаем донора
                        donor_checkpoint = store.load(donor_id, device)
                        if donor_checkpoint:
                            # Копируем веса
                            model.load_state_dict(donor_checkpoint['model'])
                            optimizer.load_state_dict(donor_checkpoint['optimizer'])
                            
                            # EXPLORE - мутация
                            old_lr, old_momentum = lr, momentum
                            lr, momentum = mutate(
                                donor_checkpoint['lr'],
                                donor_checkpoint['momentum']
                            )
                            
                            # Обновляем оптимизатор
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr
                                param_group['momentum'] = momentum
                            
                            print(f"     Мутация: LR {old_lr:.6f}->{lr:.6f}, Momentum {old_momentum:.3f}->{momentum:.3f}")
                            
                            # Оценка после копирования
                            new_acc = evaluate(model, test_loader, device)
                            print(f"     Accuracy после копирования: {new_acc:.3f}")
                            
                            # Сохраняем обновленного
                            store.save(worker_id, model, optimizer, step, lr, momentum, new_acc)
        
        # Конец эпохи
        avg_loss = epoch_loss / len(train_loader)
        if config.verbose:
            print(f"  Worker {worker_id}, эпоха {epoch+1} завершена, loss={avg_loss:.4f}")
    
    print(f"--- Worker {worker_id} финал, лучшая acc={best_acc:.3f} ---")
    return {'worker_id': worker_id, 'best_acc': best_acc}


# ========== 9. ВИЗУАЛИЗАЦИЯ ==========
def visualize_results(results, store, config):
    """Визуализация результатов"""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # 1. Динамика accuracy
    ax = axes[0, 0]
    for worker_id in range(config.population_size):
        if worker_id in store.stats:
            steps = [s['step'] for s in store.stats[worker_id]]
            accs = [s['accuracy'] for s in store.stats[worker_id]]
            ax.plot(steps, accs, 'o-', label=f'Worker {worker_id}', linewidth=2, markersize=4)
    
    ax.set_xlabel('Шаг')
    ax.set_ylabel('Accuracy')
    ax.set_title('Динамика accuracy')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1])
    
    # 2. Финальные accuracy
    ax = axes[0, 1]
    workers = []
    final_accs = []
    colors = []
    
    sorted_results = sorted(results, key=lambda x: x['best_acc'], reverse=True)
    
    for i, r in enumerate(sorted_results):
        workers.append(f'W{r["worker_id"]}')
        final_accs.append(r['best_acc'])
        colors.append('green' if i == 0 else 'skyblue')
    
    bars = ax.bar(workers, final_accs, color=colors)
    ax.set_xlabel('Воркер')
    ax.set_ylabel('Финальная accuracy')
    ax.set_title('Итоговое качество')
    ax.set_ylim([0, 1])
    
    for bar, acc in zip(bars, final_accs):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{acc:.3f}', ha='center', fontsize=11)
    
    # 3. Эволюция learning rate
    ax = axes[1, 0]
    for worker_id in range(config.population_size):
        if worker_id in store.stats:
            steps = [s['step'] for s in store.stats[worker_id]]
            lrs = [s['lr'] for s in store.stats[worker_id]]
            ax.plot(steps, lrs, 'o-', label=f'Worker {worker_id}', linewidth=2, markersize=4)
    
    ax.set_xlabel('Шаг')
    ax.set_ylabel('Learning Rate')
    ax.set_title('Эволюция learning rate')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')
    
    # 4. Эволюция momentum
    ax = axes[1, 1]
    for worker_id in range(config.population_size):
        if worker_id in store.stats:
            steps = [s['step'] for s in store.stats[worker_id]]
            momentums = [s['momentum'] for s in store.stats[worker_id]]
            ax.plot(steps, momentums, 'o-', label=f'Worker {worker_id}', linewidth=2, markersize=4)
    
    ax.set_xlabel('Шаг')
    ax.set_ylabel('Momentum')
    ax.set_title('Эволюция momentum')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('pbt_mnist_results.png', dpi=150)
    plt.show()
    
    print("\nГрафики сохранены в 'pbt_mnist_results.png'")


# ========== 10. ОСНОВНАЯ ФУНКЦИЯ ==========
def run_pbt_demo():
    """Запуск демонстрации PBT"""
    print("\n[1] Загрузка данных MNIST...")
    train_loader, test_loader = load_mnist(config)
    
    print("\n[2] Инициализация хранилища...")
    store = Store(config.checkpoint_dir)
    
    print("\n[3] Запуск популяции воркеров...")
    results = []
    start_time = time.time()
    
    for worker_id in range(config.population_size):
        result = train_worker(worker_id, train_loader, test_loader, store, config)
        results.append(result)
    
    elapsed = time.time() - start_time
    
    print("\n" + "="*60)
    print("РЕЗУЛЬТАТЫ")
    print("="*60)
    
    sorted_results = sorted(results, key=lambda x: x['best_acc'], reverse=True)
    
    for i, r in enumerate(sorted_results):
        print(f"{i+1}. Worker {r['worker_id']}: accuracy = {r['best_acc']:.3f}")
    
    best = sorted_results[0]
    print("-"*40)
    print(f"ПОБЕДИТЕЛЬ: Worker {best['worker_id']} с accuracy {best['best_acc']:.3f}")
    print(f"Время выполнения: {elapsed:.1f} секунд ({elapsed/60:.1f} минут)")
    
    print("\n[4] Визуализация...")
    visualize_results(results, store, config)
    
    print("\n" + "="*60)
    print("ДЕМОНСТРАЦИЯ PBT ЗАВЕРШЕНА")
    print("="*60)


# ========== ЗАПУСК ==========
if __name__ == "__main__":
    run_pbt_demo()

Если честно с кодом я сильно не заморачивался, но благо в 2026 году с ним прекрасно помогает всезнающий, задача была разобраться с методом. Ниже результаты кода

Скрытый текст

============================================================ PBT ДЕМОНСТРАЦИЯ НА MNIST ============================================================ Устройство: cpu Размер популяции: 4 Эпох: 5 Данные: MNIST ============================================================ [1] Загрузка данных MNIST... MNIST загружен: Train: 5000 изображений Test: 1000 изображений Размер: 28x28, классов: 10 [2] Инициализация хранилища... [3] Запуск популяции воркеров... --- Worker 0 --- Начальные: LR=0.000561, Momentum=0.966 Worker 0, эпоха 1 завершена, loss=2.2183 Worker 0, эпоха 2, шаг 100, acc=0.591 Worker 0, эпоха 2 завершена, loss=1.3404 Worker 0, эпоха 3, шаг 200, acc=0.805 Worker 0, эпоха 3 завершена, loss=0.6215 Worker 0, эпоха 4, шаг 300, acc=0.853 Worker 0, эпоха 4 завершена, loss=0.4694 Worker 0, эпоха 5 завершена, loss=0.4072 --- Worker 0 финал, лучшая acc=0.853 --- --- Worker 1 --- Начальные: LR=0.002911, Momentum=0.793 Worker 1, эпоха 1 завершена, loss=2.0550 Worker 1, эпоха 2, шаг 100, acc=0.752 Worker 1, эпоха 2 завершена, loss=0.8283 Worker 1, эпоха 3, шаг 200, acc=0.850 Worker 1, эпоха 3 завершена, loss=0.5120 Worker 1, эпоха 4, шаг 300, acc=0.878 Worker 1, эпоха 4 завершена, loss=0.4440 Worker 1, эпоха 5 завершена, loss=0.3914 --- Worker 1 финал, лучшая acc=0.878 --- --- Worker 2 --- Начальные: LR=0.000205, Momentum=0.576 Worker 2, эпоха 1 завершена, loss=2.2981 Worker 2, эпоха 2, шаг 100, acc=0.170 Worker 2, эпоха 2 завершена, loss=2.2754 Worker 2, эпоха 3, шаг 200, acc=0.258 Worker 2, эпоха 3 завершена, loss=2.2600 Worker 2, эпоха 4, шаг 300, acc=0.384 Worker 2, эпоха 4 завершена, loss=2.2427 Worker 2, эпоха 5 завершена, loss=2.2252 --- Worker 2 финал, лучшая acc=0.384 --- --- Worker 3 --- Начальные: LR=0.000131, Momentum=0.924 Worker 3, эпоха 1 завершена, loss=2.3044 Worker 3, эпоха 2, шаг 100, acc=0.336 >> Worker 3 (acc=0.336) копирует Worker 1 (acc=0.878) Мутация: LR 0.000131->0.004366, Momentum 0.924->0.555 Accuracy после копирования: 0.878 Worker 3, эпоха 2 завершена, loss=0.8226 Worker 3, эпоха 3, шаг 200, acc=0.877 Worker 3, эпоха 3 завершена, loss=0.3259 Worker 3, эпоха 4, шаг 300, acc=0.902 Worker 3, эпоха 4 завершена, loss=0.3378 Worker 3, эпоха 5 завершена, loss=0.3315 --- Worker 3 финал, лучшая acc=0.902 --- ============================================================ РЕЗУЛЬТАТЫ ============================================================ 1. Worker 3: accuracy = 0.902 2. Worker 1: accuracy = 0.878 3. Worker 0: accuracy = 0.853 4. Worker 2: accuracy = 0.384 ---------------------------------------- ПОБЕДИТЕЛЬ: Worker 3 с accuracy 0.902 Время выполнения: 80.3 секунд (1.3 минут)