Training a neural network to recognize geometric shapes

Difficulty level: Easy

The task was to train a neural network to recognize the simplest geometric shapes: a square and a circle.

1) Preparing the dataset

A neural network needs data to train on, and the simplest option is to generate that data yourself:

Generating squares

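import random
from PIL import Image, ImageDraw

def draw_box_random_centered(w, h, fc, mfs):
    """Draw a centered square of random size on a black w x h image."""
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)

    # random half-side, clamped so the figure is not smaller than roughly mfs pixels
    d0 = random.randint(0, min(w, h) // 2 - 1)
    d0 = max(d0, mfs // 2)
    x1 = w // 2 - d0
    y1 = h // 2 - d0
    x2 = w // 2 + d0
    y2 = h // 2 + d0

    # fc is a list of candidate fill colors; one is picked at random
    draw.rectangle((x1, y1, x2, y2), fill=random.choice(fc), outline=(0, 0, 0))
    return img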

Generating circles

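import random
from PIL import Image, ImageDraw

def draw_circle_random_centered(w, h, fc, mfs):
    """Draw a centered circle of random radius on a black w x h image."""
    img = Image.new("RGB", (w, h))
    draw = ImageDraw.Draw(img)

    # random radius, clamped the same way as the half-side of the square
    r0 = random.randint(0, min(w, h) // 2 - 1)
    r0 = max(r0, mfs // 2)
    x1 = w // 2 - r0
    y1 = h // 2 - r0
    x2 = w // 2 + r0
    y2 = h // 2 + r0

    # the ellipse inscribed in a square bounding box is a circle
    draw.ellipse((x1, y1, x2, y2), fill=random.choice(fc), outline=(0, 0, 0))
    return img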
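Both generators come down to the same arithmetic: pick a random half-size and build a bounding box around the image center. The sketch below isolates that step in pure Python so it can be checked without Pillow. The helper name `centered_bbox`, the 64x64 canvas and the minimum size of 10 are illustrative assumptions, and clamping to `mfs // 2` is an assumption about the intended lower bound.

```python
import random

def centered_bbox(w, h, mfs, rng=random):
    """Return (x1, y1, x2, y2) of a square box centered in a w x h image.

    centered_bbox is a hypothetical helper, not part of the article's code.
    """
    half = rng.randint(0, min(w, h) // 2 - 1)  # random half-size
    half = max(half, mfs // 2)                 # assumed lower bound on the size
    cx, cy = w // 2, h // 2
    return (cx - half, cy - half, cx + half, cy + half)

rng = random.Random(0)  # seeded generator, so runs are repeatable
boxes = [centered_bbox(64, 64, 10, rng) for _ in range(5)]
for x1, y1, x2, y2 in boxes:
    print((x1, y1, x2, y2), "side:", x2 - x1)
```

Because the box is symmetric around the center, the same coordinates work for `draw.rectangle` and `draw.ellipse`.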

Saving the dataset to directories

import os

# w, h, fill_colors, min_fig_size and num_samples must be defined beforehand
_path = 'training_dataset'
for _sub in ("boxes", "circles"):
    # one subdirectory per class; exist_ok avoids an error on repeated runs
    os.makedirs(os.path.join(_path, _sub), exist_ok=True)

for i in range(num_samples):
    img = draw_box_random_centered(w=w, h=h, fc=fill_colors, mfs=min_fig_size)
    img.save(os.path.join(_path, "boxes", f"box-{i}.png"))
    img = draw_circle_random_centered(w=w, h=h, fc=fill_colors, mfs=min_fig_size)
    img.save(os.path.join(_path, "circles", f"circle-{i}.png"))

2) Training the neural network

Defining the training parameters

cur_run_folder = os.path.abspath(os.getcwd())  # current directory
data_dir = os.path.join(cur_run_folder, "training_dataset")  # dataset directory
num_classes = 2  # number of classes
epochs = 40  # number of epochs
batch_size = 10  # mini-batch size
img_height, img_width = h, w  # image size (h is the height, w is the width)
input_shape = (img_height, img_width, 3)  # image shape: height, width, RGB channels

_path = '_models'
os.makedirs(_path, exist_ok=True)

Augmenting the training examples with random rotation and contrast

keras.layers.RandomRotation(factor=(-0.2, 0.3)),
keras.layers.RandomContrast(factor=0.2),
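To make the two `factor` arguments concrete: Keras documents the `RandomRotation` factor as a fraction of a full turn (2π) and the `RandomContrast` factor as a deviation around 1.0. The sketch below only does that arithmetic in plain Python; the helper names `rotation_range_degrees` and `contrast_range` are hypothetical and are not Keras functions.

```python
def rotation_range_degrees(factor):
    """Convert a RandomRotation factor into a (min, max) angle range in degrees."""
    if isinstance(factor, (int, float)):
        lower, upper = -factor, factor  # a scalar is used symmetrically
    else:
        lower, upper = factor
    return lower * 360.0, upper * 360.0  # factor is a fraction of a full turn

def contrast_range(factor):
    """Range from which the contrast multiplier is drawn for a scalar factor."""
    return 1.0 - factor, 1.0 + factor

rot_min, rot_max = rotation_range_degrees((-0.2, 0.3))
con_min, con_max = contrast_range(0.2)
print(f"rotation: {rot_min:.0f} to {rot_max:.0f} degrees")
print(f"contrast multiplier: {con_min:.1f} to {con_max:.1f}")
```

So each training image is rotated by a random angle from a 180-degree-wide window, which is a fairly strong augmentation for a dataset of centered figures.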

Creating the model

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Rescaling(1. / 255, input_shape=input_shape),
    keras.layers.RandomRotation(factor=(-0.2, 0.3)),
    keras.layers.RandomContrast(factor=0.2),
    keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(num_classes)
])
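The size of this network can be worked out by hand, which is a useful sanity check against `model.summary()`. The sketch below assumes a hypothetical 64x64 RGB input (the article does not state `w` and `h`) and relies on two facts: a `Conv2D` with `padding='same'` keeps the spatial size, and a default `MaxPooling2D()` halves it. The rescaling and augmentation layers have no trainable weights.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out

def count_params(height, width, channels=3, num_classes=2):
    """Return (flattened feature size, total trainable parameters)."""
    total = 0
    c_in = channels
    for filters in (16, 32, 64):
        total += conv_params(3, c_in, filters)    # padding='same' keeps H x W
        height, width = height // 2, width // 2   # MaxPooling2D() halves H and W
        c_in = filters
    flat = height * width * c_in                  # output size of Flatten
    total += dense_params(flat, 128)
    total += dense_params(128, num_classes)
    return flat, total

flat, total = count_params(64, 64)  # 64x64 is an assumed image size
print("flatten size:", flat)
print("trainable parameters:", total)
```

Under that assumption almost all parameters sit in the first `Dense` layer, so the image size has a much larger effect on model size than the number of convolution filters.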

Compiling the model

model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
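The last `Dense` layer has no activation, so the model outputs raw scores (logits), and `from_logits=True` tells the loss to apply the softmax itself. A minimal pure-Python sketch of what sparse categorical cross-entropy computes for one sample is shown below; `sparse_ce_from_logits` is a hypothetical helper written for illustration, not the Keras implementation.

```python
import math

def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label]).

    Uses the log-sum-exp trick (subtracting the max) for numerical stability.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Equal logits mean a 50/50 guess, so the loss is ln(2)
print(round(sparse_ce_from_logits([0.0, 0.0], 0), 4))

# A confident, correct prediction gives a loss close to zero
print(sparse_ce_from_logits([8.0, -5.0], 0) < 0.001)
```

The label is an integer class index rather than a one-hot vector, which is what "sparse" refers to in the loss name.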

Creating the training and validation sets

import tensorflow as tf

# training set
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    # shuffle=False,
    image_size=(img_height, img_width),
    batch_size=batch_size)

# validation set
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    # shuffle=False,
    image_size=(img_height, img_width),
    batch_size=batch_size)
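The split sizes and the number of steps per epoch follow directly from these parameters. The sketch below reproduces the arithmetic for the 200 files reported in the training log; `split_sizes` is a hypothetical helper, and truncating the validation count to an integer is an assumption about how the split is rounded.

```python
import math

def split_sizes(n_files, validation_split):
    """Return (training count, validation count) for a fractional split."""
    n_val = int(validation_split * n_files)  # assumed: the fraction is truncated
    return n_files - n_val, n_val

n_train, n_val = split_sizes(200, 0.2)
steps_per_epoch = math.ceil(n_train / 10)  # batch_size = 10

print("training files:", n_train)
print("validation files:", n_val)
print("steps per epoch:", steps_per_epoch)
```

Passing the same `seed` to both calls matters: it keeps the two subsets disjoint, since both calls shuffle the file list identically before splitting.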

Saving the trained models to the _models directory

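from tensorflow.keras.callbacks import ModelCheckpoint

# save the model after every epoch; {epoch} in the file name is the epoch number
callbacks = [
    ModelCheckpoint(os.path.join("_models", 'cnn_Open{epoch:1d}.hdf5')),
    # keras.callbacks.EarlyStopping(monitor='loss', patience=10),
]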

Starting the training

# start the training process
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks
)

Training results

Num GPUs Available:  1
Found 200 files belonging to 2 classes.
Using 160 files for training.
Found 200 files belonging to 2 classes.
Using 40 files for validation.
Epoch 1/40
16/16 [==============================] - 6s 52ms/step - loss: 0.7398 - accuracy: 0.5938 - val_loss: 0.6609 - val_accuracy: 0.6250
Epoch 2/40
16/16 [==============================] - 1s 46ms/step - loss: 0.5615 - accuracy: 0.7000 - val_loss: 0.6362 - val_accuracy: 0.5750
Epoch 3/40
16/16 [==============================] - 1s 46ms/step - loss: 0.5037 - accuracy: 0.7812 - val_loss: 0.5680 - val_accuracy: 0.7500
Epoch 4/40
16/16 [==============================] - 1s 52ms/step - loss: 0.4575 - accuracy: 0.8313 - val_loss: 0.4621 - val_accuracy: 0.7750
Epoch 5/40
16/16 [==============================] - 1s 50ms/step - loss: 0.3441 - accuracy: 0.8562 - val_loss: 0.2830 - val_accuracy: 0.9000
Epoch 6/40
16/16 [==============================] - 1s 51ms/step - loss: 0.2102 - accuracy: 0.9250 - val_loss: 0.1800 - val_accuracy: 0.9000
Epoch 7/40
16/16 [==============================] - 1s 50ms/step - loss: 0.1450 - accuracy: 0.9312 - val_loss: 0.1132 - val_accuracy: 0.9500
Epoch 8/40
16/16 [==============================] - 1s 50ms/step - loss: 0.1089 - accuracy: 0.9688 - val_loss: 0.1758 - val_accuracy: 0.9500
Epoch 9/40
16/16 [==============================] - 1s 50ms/step - loss: 0.0933 - accuracy: 0.9625 - val_loss: 0.0992 - val_accuracy: 0.9500
Epoch 10/40
16/16 [==============================] - 1s 51ms/step - loss: 0.0818 - accuracy: 0.9812 - val_loss: 0.0829 - val_accuracy: 0.9750
Epoch 11/40
16/16 [==============================] - 1s 50ms/step - loss: 0.0664 - accuracy: 0.9812 - val_loss: 0.0631 - val_accuracy: 0.9750
Epoch 12/40
16/16 [==============================] - 1s 51ms/step - loss: 0.0478 - accuracy: 0.9875 - val_loss: 0.0209 - val_accuracy: 1.0000
Epoch 13/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0182 - accuracy: 1.0000 - val_loss: 0.0345 - val_accuracy: 1.0000
Epoch 14/40
16/16 [==============================] - 1s 47ms/step - loss: 0.0274 - accuracy: 0.9875 - val_loss: 0.0622 - val_accuracy: 0.9750
Epoch 15/40
16/16 [==============================] - 1s 43ms/step - loss: 0.0923 - accuracy: 0.9812 - val_loss: 0.0316 - val_accuracy: 0.9750
Epoch 16/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0970 - accuracy: 0.9625 - val_loss: 0.0400 - val_accuracy: 0.9750
Epoch 17/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0214 - accuracy: 1.0000 - val_loss: 0.0403 - val_accuracy: 1.0000
Epoch 18/40
16/16 [==============================] - 1s 46ms/step - loss: 0.0315 - accuracy: 0.9875 - val_loss: 0.0470 - val_accuracy: 0.9750
Epoch 19/40
16/16 [==============================] - 1s 45ms/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.0282 - val_accuracy: 0.9750
Epoch 20/40
16/16 [==============================] - 1s 52ms/step - loss: 0.0061 - accuracy: 1.0000 - val_loss: 0.0269 - val_accuracy: 0.9750
Epoch 21/40
16/16 [==============================] - 1s 54ms/step - loss: 0.0059 - accuracy: 1.0000 - val_loss: 0.0232 - val_accuracy: 1.0000
Epoch 22/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0029 - accuracy: 1.0000 - val_loss: 0.0286 - val_accuracy: 0.9750
Epoch 23/40
16/16 [==============================] - 1s 53ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0260 - val_accuracy: 0.9750
Epoch 24/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 0.0239 - val_accuracy: 0.9750
Epoch 25/40
16/16 [==============================] - 1s 54ms/step - loss: 9.7747e-04 - accuracy: 1.0000 - val_loss: 0.0279 - val_accuracy: 0.9750
Epoch 26/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.0237 - val_accuracy: 0.9750
Epoch 27/40
16/16 [==============================] - 1s 55ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0239 - val_accuracy: 0.9750
Epoch 28/40
16/16 [==============================] - 1s 55ms/step - loss: 4.9770e-04 - accuracy: 1.0000 - val_loss: 0.0244 - val_accuracy: 0.9750
Epoch 29/40
16/16 [==============================] - 1s 53ms/step - loss: 8.6882e-04 - accuracy: 1.0000 - val_loss: 0.0192 - val_accuracy: 0.9750
Epoch 30/40
16/16 [==============================] - 1s 54ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 0.0157 - val_accuracy: 1.0000
Epoch 31/40
16/16 [==============================] - 1s 54ms/step - loss: 8.1033e-04 - accuracy: 1.0000 - val_loss: 0.0166 - val_accuracy: 1.0000
Epoch 32/40
16/16 [==============================] - 1s 55ms/step - loss: 3.2232e-04 - accuracy: 1.0000 - val_loss: 0.0179 - val_accuracy: 1.0000
Epoch 33/40
16/16 [==============================] - 1s 55ms/step - loss: 2.4855e-04 - accuracy: 1.0000 - val_loss: 0.0187 - val_accuracy: 0.9750
Epoch 34/40
16/16 [==============================] - 1s 56ms/step - loss: 2.1718e-04 - accuracy: 1.0000 - val_loss: 0.0191 - val_accuracy: 0.9750
Epoch 35/40
16/16 [==============================] - 1s 55ms/step - loss: 2.4915e-04 - accuracy: 1.0000 - val_loss: 0.0183 - val_accuracy: 0.9750
Epoch 36/40
16/16 [==============================] - 1s 54ms/step - loss: 2.4807e-04 - accuracy: 1.0000 - val_loss: 0.0197 - val_accuracy: 0.9750
Epoch 37/40
16/16 [==============================] - 1s 56ms/step - loss: 1.6868e-04 - accuracy: 1.0000 - val_loss: 0.0199 - val_accuracy: 0.9750
Epoch 38/40
16/16 [==============================] - 1s 54ms/step - loss: 3.7411e-04 - accuracy: 1.0000 - val_loss: 0.0211 - val_accuracy: 0.9750
Epoch 39/40
16/16 [==============================] - 1s 54ms/step - loss: 2.0273e-04 - accuracy: 1.0000 - val_loss: 0.0228 - val_accuracy: 0.9750
Epoch 40/40
16/16 [==============================] - 1s 54ms/step - loss: 2.7570e-04 - accuracy: 1.0000 - val_loss: 0.0249 - val_accuracy: 0.9750

3) Choosing the trained model

Loss and accuracy plots for the training and validation sets
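The article does not say which criterion was used to pick a checkpoint. One common choice is the epoch with the lowest validation loss, and the sketch below applies it to the `val_loss` values copied from the training log above. The selection rule itself is an assumption; the file name pattern is the one passed to `ModelCheckpoint`.

```python
# val_loss for epochs 1..40, copied from the training log
val_loss = [
    0.6609, 0.6362, 0.5680, 0.4621, 0.2830, 0.1800, 0.1132, 0.1758,
    0.0992, 0.0829, 0.0631, 0.0209, 0.0345, 0.0622, 0.0316, 0.0400,
    0.0403, 0.0470, 0.0282, 0.0269, 0.0232, 0.0286, 0.0260, 0.0239,
    0.0279, 0.0237, 0.0239, 0.0244, 0.0192, 0.0157, 0.0166, 0.0179,
    0.0187, 0.0191, 0.0183, 0.0197, 0.0199, 0.0211, 0.0228, 0.0249,
]

# Epochs are numbered from 1, list indices from 0
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
checkpoint_name = 'cnn_Open{epoch:1d}.hdf5'.format(epoch=best_epoch)

print("best epoch:", best_epoch, "val_loss:", val_loss[best_epoch - 1])
print("checkpoint:", checkpoint_name)
```

With only 40 validation images, one misclassified image changes `val_accuracy` by 0.025, so validation loss is the finer-grained signal of the two.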

4) Checking that the network recognizes the shapes correctly

For image: ..\training_dataset\boxes\box-0.png Predicted: [[ 0.9033287  -0.50563365]] => class=0
For image: ..\training_dataset\boxes\box-1.png Predicted: [[ 3.7889903 -3.240302 ]] => class=0
For image: ..\training_dataset\boxes\box-10.png Predicted: [[ 1.3527468  -0.91766155]] => class=0
For image: ..\training_dataset\boxes\box-11.png Predicted: [[ 8.017973 -5.261616]] => class=0
For image: ..\training_dataset\boxes\box-12.png Predicted: [[ 3.5476727 -2.392489 ]] => class=0
For image: ..\training_dataset\boxes\box-13.png Predicted: [[ 2.9426796 -2.0551767]] => class=0
For image: ..\training_dataset\boxes\box-14.png Predicted: [[ 6.810717  -5.7654223]] => class=0
For image: ..\training_dataset\boxes\box-15.png Predicted: [[ 2.3655586 -1.4439728]] => class=0
For image: ..\training_dataset\boxes\box-16.png Predicted: [[ 3.3409414 -2.8813415]] => class=0
For image: ..\training_dataset\boxes\box-17.png Predicted: [[ 2.3670661 -1.098555 ]] => class=0

For image: ..\training_dataset\circles\circle-0.png Predicted: [[-0.5515442  3.1467297]] => class=1
For image: ..\training_dataset\circles\circle-1.png Predicted: [[-1.5983998  3.7142878]] => class=1
For image: ..\training_dataset\circles\circle-10.png Predicted: [[0.01349133 1.4574348 ]] => class=1
For image: ..\training_dataset\circles\circle-11.png Predicted: [[-3.249526   5.4390597]] => class=1
For image: ..\training_dataset\circles\circle-12.png Predicted: [[-0.86239564  3.5299587 ]] => class=1
For image: ..\training_dataset\circles\circle-13.png Predicted: [[-0.4071996  0.6257615]] => class=1
For image: ..\training_dataset\circles\circle-14.png Predicted: [[-2.268906  3.950712]] => class=1
For image: ..\training_dataset\circles\circle-15.png Predicted: [[-2.268906  3.950712]] => class=1
For image: ..\training_dataset\circles\circle-16.png Predicted: [[-1.2579566  6.733812 ]] => class=1
For image: ..\training_dataset\circles\circle-17.png Predicted: [[-3.249526   5.4390597]] => class=1

As the output shows, all 20 sample images from training_dataset (10 squares and 10 circles) are assigned the correct class.
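The `Predicted` values above are raw logits, and the class is the index of the larger one. Class 0 is `boxes` and class 1 is `circles` because `image_dataset_from_directory` orders class names alphanumerically. The sketch below turns two of the logged predictions into a class and a softmax probability; `predict_class` and `softmax` are hypothetical helpers written in plain Python.

```python
import math

CLASS_NAMES = ["boxes", "circles"]  # alphanumeric order of the subdirectories

def predict_class(logits):
    """Index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])

def softmax(logits):
    """Turn logits into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Two predictions copied from the output above
samples = {
    "box-0.png": [0.9033287, -0.50563365],
    "circle-13.png": [-0.4071996, 0.6257615],
}

for name, logits in samples.items():
    cls = predict_class(logits)
    prob = softmax(logits)[cls]
    print(f"{name}: class={cls} ({CLASS_NAMES[cls]}), probability={prob:.2f}")
```

Looking at probabilities rather than only the class index shows how confident each decision is: these two samples have the smallest logit margins in the listing.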
