Training a Neural Network to Recognize Geometric Shapes
The task at hand: train a neural network to recognize the simplest geometric shapes, a square and a circle.
1) Preparing the dataset
A neural network needs data to train on, so the simplest option is to generate it yourself:
Generating squares
import random
from PIL import Image, ImageDraw

def draw_box_random_centered(w, h, fc, mfs):
    # draw a centered square of random size; fc is a list of fill colors, mfs is the minimum figure size
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)
    # random half-size, clamped from below so the square is at least mfs pixels across
    d0 = random.randint(0, min(w, h) // 2 - 1)
    d0 = max(d0, mfs // 2)
    x1 = w // 2 - d0
    y1 = h // 2 - d0
    x2 = w // 2 + d0
    y2 = h // 2 + d0
    draw.rectangle((x1, y1, x2, y2), fill=random.choice(fc), outline=(0, 0, 0))
    return img
Generating circles
def draw_circle_random_centered(w, h, fc, mfs):
    # draw a centered circle of random size; fc is a list of fill colors, mfs is the minimum figure size
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)
    # random radius, clamped from below the same way as for the square
    r0 = random.randint(0, min(w, h) // 2 - 1)
    r0 = max(r0, mfs // 2)
    x1 = w // 2 - r0
    y1 = h // 2 - r0
    x2 = w // 2 + r0
    y2 = h // 2 + r0
    draw.ellipse((x1, y1, x2, y2), fill=random.choice(fc), outline=(0, 0, 0))
    return img
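Both helpers share the same centered-bounding-box computation. A minimal pure-Python sketch of just that step (the helper name `centered_bbox` is hypothetical, not from the original code):

```python
import random

def centered_bbox(w, h, min_half):
    # pick a random half-size, clamped so the figure is at least 2*min_half pixels across
    d = random.randint(0, min(w, h) // 2 - 1)
    d = max(d, min_half)
    # bounding box symmetric about the image center
    return (w // 2 - d, h // 2 - d, w // 2 + d, h // 2 + d)

x1, y1, x2, y2 = centered_bbox(64, 64, 4)
```

Whatever half-size is drawn, the box stays square and centered: `x2 - x1 == y2 - y1` and `x1 + x2 == w`.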
Saving the dataset into directories
import os

_path = 'training_dataset'
os.makedirs(os.path.join(_path, "boxes"), exist_ok=True)
os.makedirs(os.path.join(_path, "circles"), exist_ok=True)
for i in range(num_samples):
    img = draw_box_random_centered(w=w, h=h, fc=fill_colors, mfs=min_fig_size)
    img.save(os.path.join(_path, "boxes", f"box-{i}.png"))
    img = draw_circle_random_centered(w=w, h=h, fc=fill_colors, mfs=min_fig_size)
    img.save(os.path.join(_path, "circles", f"circle-{i}.png"))
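The loop above relies on a few parameters defined elsewhere in the script. Plausible values for them might look like this (all values are hypothetical, except `num_samples`, which is inferred from the 200 files reported in the training log further below):

```python
num_samples = 100   # images per class -> 200 files total
w, h = 64, 64       # image width and height (hypothetical)
min_fig_size = 8    # minimum figure size in pixels (hypothetical)
# fill colors for random.choice (hypothetical)
fill_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
```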
2) Training the neural network
Defining the training parameters
cur_run_folder = os.path.abspath(os.getcwd())  # current working directory
data_dir = os.path.join(cur_run_folder, "training_dataset")  # dataset directory
num_classes = 2   # number of classes
epochs = 40       # number of training epochs
batch_size = 10   # mini-batch size
img_height, img_width = w, h  # image size
input_shape = (img_height, img_width, 3)  # input image shape
_path = '_models'
os.makedirs(_path, exist_ok=True)
Augmenting the training data
To effectively increase the number of training examples, two augmentation layers are placed at the front of the model: a random rotation and a random contrast change:
keras.layers.RandomRotation(factor=(-0.2, 0.3)),
keras.layers.RandomContrast(factor=0.2),
Creating the model
import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Rescaling(1. / 255, input_shape=input_shape),
    keras.layers.RandomRotation(factor=(-0.2, 0.3)),
    keras.layers.RandomContrast(factor=0.2),
    keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(num_classes)  # raw logits, no activation
])
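Each of the three `MaxPooling2D` layers halves both spatial dimensions, so `Flatten` sees (w/8) × (h/8) × 64 features. For example, assuming 64×64 input images (the actual `w` and `h` are set during dataset generation):

```python
# assuming 64x64 input images (hypothetical size)
h = w = 64
for _ in range(3):  # three MaxPooling2D layers, each halving height and width
    h, w = h // 2, w // 2
flatten_size = h * w * 64  # 64 filters in the last Conv2D
print(flatten_size)  # 4096
```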
Compiling the model
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
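Since the model's final `Dense` layer has no activation, it outputs raw logits, and `from_logits=True` tells the loss to apply softmax internally. A pure-Python sketch of that softmax, applied to example logits of the kind the model produces:

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

# example logits for a confident "square" prediction (class 0)
probs = softmax([3.79, -3.24])
```

The probabilities always sum to 1, and a logit gap of about 7 already maps to a probability above 0.99 for the leading class.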
Creating the training and validation sets
# training set
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    # shuffle=False,
    image_size=(img_height, img_width),
    batch_size=batch_size)
# validation set
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    # shuffle=False,
    image_size=(img_height, img_width),
    batch_size=batch_size)
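With `validation_split=0.2` over the 200 generated files, Keras reserves one fifth of the images for validation, which matches the 160/40 split reported in the training log below:

```python
total = 200             # 100 squares + 100 circles
val = int(total * 0.2)  # fraction held out by validation_split=0.2
train = total - val
print(train, val)  # 160 40
```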
Saving the trained models into the _models directory
from tensorflow.keras.callbacks import ModelCheckpoint

# save a model checkpoint after every epoch
callbacks = [ModelCheckpoint(os.path.join("_models", 'cnn_Open{epoch:1d}.hdf5')),
             # keras.callbacks.EarlyStopping(monitor='loss', patience=10),
             ]
Starting the training
# start the training process
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks
)
Training output
Num GPUs Available: 1
Found 200 files belonging to 2 classes.
Using 160 files for training.
Found 200 files belonging to 2 classes.
Using 40 files for validation.
Epoch 1/40
16/16 [==============================] - 6s 52ms/step - loss: 0.7398 - accuracy: 0.5938 - val_loss: 0.6609 - val_accuracy: 0.6250
Epoch 2/40
16/16 [==============================] - 1s 46ms/step - loss: 0.5615 - accuracy: 0.7000 - val_loss: 0.6362 - val_accuracy: 0.5750
Epoch 3/40
16/16 [==============================] - 1s 46ms/step - loss: 0.5037 - accuracy: 0.7812 - val_loss: 0.5680 - val_accuracy: 0.7500
Epoch 4/40
16/16 [==============================] - 1s 52ms/step - loss: 0.4575 - accuracy: 0.8313 - val_loss: 0.4621 - val_accuracy: 0.7750
Epoch 5/40
16/16 [==============================] - 1s 50ms/step - loss: 0.3441 - accuracy: 0.8562 - val_loss: 0.2830 - val_accuracy: 0.9000
Epoch 6/40
16/16 [==============================] - 1s 51ms/step - loss: 0.2102 - accuracy: 0.9250 - val_loss: 0.1800 - val_accuracy: 0.9000
Epoch 7/40
16/16 [==============================] - 1s 50ms/step - loss: 0.1450 - accuracy: 0.9312 - val_loss: 0.1132 - val_accuracy: 0.9500
Epoch 8/40
16/16 [==============================] - 1s 50ms/step - loss: 0.1089 - accuracy: 0.9688 - val_loss: 0.1758 - val_accuracy: 0.9500
Epoch 9/40
16/16 [==============================] - 1s 50ms/step - loss: 0.0933 - accuracy: 0.9625 - val_loss: 0.0992 - val_accuracy: 0.9500
Epoch 10/40
16/16 [==============================] - 1s 51ms/step - loss: 0.0818 - accuracy: 0.9812 - val_loss: 0.0829 - val_accuracy: 0.9750
Epoch 11/40
16/16 [==============================] - 1s 50ms/step - loss: 0.0664 - accuracy: 0.9812 - val_loss: 0.0631 - val_accuracy: 0.9750
Epoch 12/40
16/16 [==============================] - 1s 51ms/step - loss: 0.0478 - accuracy: 0.9875 - val_loss: 0.0209 - val_accuracy: 1.0000
Epoch 13/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0182 - accuracy: 1.0000 - val_loss: 0.0345 - val_accuracy: 1.0000
Epoch 14/40
16/16 [==============================] - 1s 47ms/step - loss: 0.0274 - accuracy: 0.9875 - val_loss: 0.0622 - val_accuracy: 0.9750
Epoch 15/40
16/16 [==============================] - 1s 43ms/step - loss: 0.0923 - accuracy: 0.9812 - val_loss: 0.0316 - val_accuracy: 0.9750
Epoch 16/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0970 - accuracy: 0.9625 - val_loss: 0.0400 - val_accuracy: 0.9750
Epoch 17/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0214 - accuracy: 1.0000 - val_loss: 0.0403 - val_accuracy: 1.0000
Epoch 18/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0315 - accuracy: 0.9875 - val_loss: 0.0470 - val_accuracy: 0.9750
Epoch 19/40
16/16 [==============================] - 1s 45ms/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.0282 - val_accuracy: 0.9750
Epoch 20/40
16/16 [==============================] - 1s 52ms/step - loss: 0.0061 - accuracy: 1.0000 - val_loss: 0.0269 - val_accuracy: 0.9750
Epoch 21/40
16/16 [==============================] - 1s 54ms/step - loss: 0.0059 - accuracy: 1.0000 - val_loss: 0.0232 - val_accuracy: 1.0000
Epoch 22/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0029 - accuracy: 1.0000 - val_loss: 0.0286 - val_accuracy: 0.9750
Epoch 23/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0260 - val_accuracy: 0.9750
Epoch 24/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.0239 - val_accuracy: 0.9750
Epoch 25/40
16/16 [==============================] - 1s 54ms/step - loss: 9.7747e-04 - accuracy: 1.0000 - val_loss: 0.0279 - val_accuracy: 0.9750
Epoch 26/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.0237 - val_accuracy: 0.9750
Epoch 27/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0239 - val_accuracy: 0.9750
Epoch 28/40
16/16 [==============================] - 1s 55ms/step - loss: 4.9770e-04 - accuracy: 1.0000 - val_loss: 0.0244 - val_accuracy: 0.9750
Epoch 29/40
16/16 [==============================] - 1s 53ms/step - loss: 8.6882e-04 - accuracy: 1.0000 - val_loss: 0.0192 - val_accuracy: 0.9750
Epoch 30/40
16/16 [==============================] - 1s 54ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.0157 - val_accuracy: 1.0000
Epoch 31/40
16/16 [==============================] - 1s 54ms/step - loss: 8.1033e-04 - accuracy: 1.0000 - val_loss: 0.0166 - val_accuracy: 1.0000
Epoch 32/40
16/16 [==============================] - 1s 55ms/step - loss: 3.2232e-04 - accuracy: 1.0000 - val_loss: 0.0179 - val_accuracy: 1.0000
Epoch 33/40
16/16 [==============================] - 1s 55ms/step - loss: 2.4855e-04 - accuracy: 1.0000 - val_loss: 0.0187 - val_accuracy: 0.9750
Epoch 34/40
16/16 [==============================] - 1s 56ms/step - loss: 2.1718e-04 - accuracy: 1.0000 - val_loss: 0.0191 - val_accuracy: 0.9750
Epoch 35/40
16/16 [==============================] - 1s 55ms/step - loss: 2.4915e-04 - accuracy: 1.0000 - val_loss: 0.0183 - val_accuracy: 0.9750
Epoch 36/40
16/16 [==============================] - 1s 54ms/step - loss: 2.4807e-04 - accuracy: 1.0000 - val_loss: 0.0197 - val_accuracy: 0.9750
Epoch 37/40
16/16 [==============================] - 1s 56ms/step - loss: 1.6868e-04 - accuracy: 1.0000 - val_loss: 0.0199 - val_accuracy: 0.9750
Epoch 38/40
16/16 [==============================] - 1s 54ms/step - loss: 3.7411e-04 - accuracy: 1.0000 - val_loss: 0.0211 - val_accuracy: 0.9750
Epoch 39/40
16/16 [==============================] - 1s 54ms/step - loss: 2.0273e-04 - accuracy: 1.0000 - val_loss: 0.0228 - val_accuracy: 0.9750
Epoch 40/40
16/16 [==============================] - 1s 54ms/step - loss: 2.7570e-04 - accuracy: 1.0000 - val_loss: 0.0249 - val_accuracy: 0.9750
3) Choosing a trained model
Loss and accuracy curves for the training and validation sets
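Since ModelCheckpoint saved a model after every epoch, one simple selection rule is to pick the checkpoint with the lowest validation loss. A sketch over an abridged `history.history['val_loss']` list (values taken from epochs 1, 12, 30, and 40 of the log above):

```python
# abridged val_loss values from the training log (epochs 1, 12, 30, 40)
val_loss = [0.6609, 0.0209, 0.0157, 0.0249]
best = min(range(len(val_loss)), key=val_loss.__getitem__)
print(best)  # 2 -> the third entry (epoch 30) has the lowest validation loss
```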

4) Checking the network's recognition accuracy
For image: ..\training_dataset\boxes\box-0.png Predicted: [[ 0.9033287 -0.50563365]] => class=0
For image: ..\training_dataset\boxes\box-1.png Predicted: [[ 3.7889903 -3.240302 ]] => class=0
For image: ..\training_dataset\boxes\box-10.png Predicted: [[ 1.3527468 -0.91766155]] => class=0
For image: ..\training_dataset\boxes\box-11.png Predicted: [[ 8.017973 -5.261616]] => class=0
For image: ..\training_dataset\boxes\box-12.png Predicted: [[ 3.5476727 -2.392489 ]] => class=0
For image: ..\training_dataset\boxes\box-13.png Predicted: [[ 2.9426796 -2.0551767]] => class=0
For image: ..\training_dataset\boxes\box-14.png Predicted: [[ 6.810717 -5.7654223]] => class=0
For image: ..\training_dataset\boxes\box-15.png Predicted: [[ 2.3655586 -1.4439728]] => class=0
For image: ..\training_dataset\boxes\box-16.png Predicted: [[ 3.3409414 -2.8813415]] => class=0
For image: ..\training_dataset\boxes\box-17.png Predicted: [[ 2.3670661 -1.098555 ]] => class=0
For image: ..\training_dataset\circles\circle-0.png Predicted: [[-0.5515442 3.1467297]] => class=1
For image: ..\training_dataset\circles\circle-1.png Predicted: [[-1.5983998 3.7142878]] => class=1
For image: ..\training_dataset\circles\circle-10.png Predicted: [[0.01349133 1.4574348 ]] => class=1
For image: ..\training_dataset\circles\circle-11.png Predicted: [[-3.249526 5.4390597]] => class=1
For image: ..\training_dataset\circles\circle-12.png Predicted: [[-0.86239564 3.5299587 ]] => class=1
For image: ..\training_dataset\circles\circle-13.png Predicted: [[-0.4071996 0.6257615]] => class=1
For image: ..\training_dataset\circles\circle-14.png Predicted: [[-2.268906 3.950712]] => class=1
For image: ..\training_dataset\circles\circle-15.png Predicted: [[-2.268906 3.950712]] => class=1
For image: ..\training_dataset\circles\circle-16.png Predicted: [[-1.2579566 6.733812 ]] => class=1
For image: ..\training_dataset\circles\circle-17.png Predicted: [[-3.249526 5.4390597]] => class=1
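The `class=` value printed for each image is simply the index of the larger of the two logits (argmax), as this pure-Python sketch shows:

```python
def predict_class(logits):
    # class 0 = square, class 1 = circle: index of the largest logit
    return max(range(len(logits)), key=logits.__getitem__)

print(predict_class([0.9033287, -0.50563365]))  # 0 (square, from the log above)
print(predict_class([-0.5515442, 3.1467297]))   # 1 (circle, from the log above)
```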
As we can see, both squares and circles are recognized with good accuracy.