Как стать автором
Обновить

Как сделать проект по распознаванию рукописных цифр с дообучением онлайн. Гайд для не совсем начинающих

Время на прочтение 57 мин
Количество просмотров 34K
Всего голосов 27: ↑26 и ↓1 +25
Комментарии 9

Комментарии 9

Спасибо за статью!
А как вы забирали с Amazon S3 все то множество файлов (около 800)? Их же надо было забрать себе на локальную машину для проведения обучения. Возможно решение довольно простое (и для вас очевидное), но на сайте Amazon S3 скачивается лишь по одному файлу.
Хороший вопрос, наверное, стоило его осветить. Это делается в 2 этапа с помощью библиотеки boto3: вначале с помощью функции list_objects получаем список объектов, потом в цикле их скачиваем. Важно, что Амазон ограничивает «размеры» запросов, так что взять больше 1000 объектов за раз не получился. Есть 2 варианта для скачивания больше 1000 объектов: либо указывать параметры запроса и с помощью этого выбирать объекты (не пробовал), либо после каждого скачивания перемещать/удалять объекты в корзине.

Мой код для скачивания картинок выглядит так:
s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
for obj in s3.list_objects(Bucket=BUCKET)['Contents']:
    filename = obj['Key']
    if 'digit' in filename:
        # The local directory must exist.
        localfilename = os.path.join('my_images/', filename)
        s3.download_file(BUCKET, filename, localfilename)
    else:
        pass
Спасибо, статья более чем достойна.
Вечер добрый, в процессе реализации описанного вами проекта возник вопрос.
После разработки первой версии CNN требуется первый раз обучить сеть и для этого

… модель предполагает другие измерения у данных


и далее следует код, в котором проводится reshape X_train, X_val, y_train, y_val

trX = X_train.reshape(-1, 28, 28, 1) # 28x28x1
teX = X_val.reshape(-1, 28, 28, 1)
enc = OneHotEncoder()
enc.fit(y.reshape(-1, 1), 10).toarray() # 10x1
trY = enc.fit_transform(y_train.reshape(-1, 1)).toarray()
teY = enc.fit_transform(y_val.reshape(-1, 1)).toarray()


Непонятна одна переменная «y». Что она собой представляет? Какие в ней данные?
Доброе утро.

Действительно, упустил этот момент. В данном случае y — все лейблы для исходных данных. Вообще говоря, это нужно только для того, чтобы OneHotEncoder превращал вектор с 10 классами в матрицу с 10 столбцами. Можно использовать любой вектор из имеющихся (y_train, y_val или какой-то другой), главное, чтобы в нём были все 10 классов.
Добрый вечер. У меня еще вопрос или предложение. В пункте дообучения есть обновленный код CNN, и в нем в методе 'train' я вижу как сохраняются файлы с новыми весами в папку tmp

# Save updated weights
all_saver = tf.train.Saver()
all_saver.save(sess, './tmp/data-all_2_updated.chkp')

Но не вижу где эти файлы заливаются на Amazon, при этом в след методе происходит скачивание этих обновленных весов с сервера Amazon. Но мы же туда их не заливали, чтобы скачивать.
Возможно я чего то не заметил. Спасибо.
Добрый вечер.

Для этого надо смотреть в сам код (строки 182, 183, 187, 188, 189): github.com/Erlemar/digit-draw-recognize/blob/master/functions.py#L182

cnn = CNN()
cnn.train(X, y)
		
response = self.save_weights_amazon('data-all_2_updated.chkp.meta', './tmp/data-
                                    all_2_updated.chkp')
response = self.save_weights_amazon('data-all_2_updated.chkp.index', './tmp/data-
                                    all_2_updated.chkp')
response = self.save_weights_amazon('data-all_2_updated.chkp.data-00000-of-00001', 
                                    './tmp/data-all_2_updated.chkp')


Что здесь происходит:
  1. Модель инициализируется и тренируется;
  2. В результате тренировки обновлённые веса сохраняются локально в папке tmp (на Heroku), это 3 отдельных файла (так работает tensorflow);
  3. А затем используется метод save_weights_amazon для заливки обновлённых файлов на Amazon;


Возможно есть более элегантные способы делать это, но у меня получилось вот так.
Понял, спасибо. Сожалею, что сразу не заметил.
Зарегистрируйтесь на Хабре , чтобы оставить комментарий