Dropout, Batch Normalization - в поисках серебряной пули. Как я научился искать размытые фотографии

Aug 01, 2021 17:40

Понадобилось мне сделать набор резких картинок (т.е. из большого набора выкинуть все нерезкие фотографии, как на примере ниже)



Сначала я попробовал старые классические методы (сразу же в поисковике выскакивает Laplacian - простенький сверточный фильтр). Для отсеивания явного брака может сгодится, но много ложных срабатываний. К примеру размытие в движении (motion blur), шумные картинки и т.п. ему не по зубам. Попадались и более сложные реализации (с пирамидальным масштабированием и вычислением частот). Но все встреченные реализации часто давали ошибки и (зачастую) медленно работали.

Пойдем путем ИИ. Дальше будет описание проблем построения и обучения нейросети. (Забегая вперед - аккуратность получилась лучше 96%.)


Беглый поиск существующих решений быстро выдал работу "DeepFocus: Detection of out-of-focus regions in whole slide digital images using deep learning" (https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6201886/).

Схематически архитектура DeepFocus:


Вдохновившись архитектурами U-Net и ResNet я решил добавить "пропущенное" соединение. В интернете я не нашел устоявшегося русского перевода и считаю, что такой тип соединения лучше называть проброшенным соединением. Суть заключается в том, что честь фильтров предыдущего слоя (до сверточных фильтров) мы без преобразования добавляем к текущему слою, после чего применяем свертки.

На схеме ниже эти проброшенные связи обозначены зелеными стрелками.


Сверточные слои имеют на входе 3 канала, далее идут свертки с числом каналов 8, 10, 28 (после каждого слоя MaxPool2d), 24, 22, AdaptiveMaxPool2d, после чего полносвязные слои 108, 64, 16 нейронов в слое и 2 выходных значения.

Можно видеть, что структура сети достаточно простая, поэтому эти дополнительные проброшенные связи, призванные ускорить обучение, почти не играют роли.

Отличительной чертой является использование AdaptiveMaxPool2d (на схеме оранжевый слой), который позволяет тренировать подобные сети на картинках разного размера (в пределах одного минибатча размер должен быть одинаковый!). Это позволяет, к примеру, начать тренировку с маленьких картинок, а позже плавно перейти к бОльшим без полного обучения с нуля. Я первоначально обучал на размере 256*256, финальное обучение проводил с img_tile_size = 328. Такое увеличение позволило улучшить точность на несколько процентов. Так же я протестировал модель, обученную на 256*256 картинках получить предсказание для картинки 512 - точность ожидаемо упала, но не очень сильно (примерно на 5%).

По AdaptiveMaxPool2d в сети мало информации (порядка 1 тыс. результатов, значительная часть из которых являются копиями других страниц).

Первоначально у меня была намного сложнее сеть и обучение не шло. Возможно из-за маленького обучающего набора (возможно из-за какой-то мелкой ошибки). В этот момент я решил поэкспериментировать с Dropout и Batch Normalization.

Batch Normalization несколько ускорял обучение, но выяснилась неприятная особенность реализации в PyTorch - нормализация обучается (обновление статистики) не только во время обучения, но и во время тестирования. На эту тему много сообщений в сети [1] [2] и др. Суть предлагаемых решений: использовать momentum=0.9999 или делать такую структуру сети, которая в режиме train ставит нормальный momentum, а в режиме eval ставит его равным 1 (или 0). Я просто отказался (не заметил выигрыша).

Dropout у меня повел себя вообще странно. Даже относительно небольшие значения (0.2) приводили к тому, что буквально первые эпохи давали аккуратность на проверочном наборе порядка 0.75, но потом аккуратность падала и держалась в районе 0,63-0,74 (при этом на тренировочном наборе точность порядка 0.9). Примерно как на графике ниже (надписи игнорируйте, картинка из интернета):


У меня даже закралось сомнение, а действительно ли в режиме eval сеть отключает Dropout - да, отключает. Причем если не переводить сеть в режим eval, то точность на проверочном наборе возрастает до 0.85. Судя по всему имеет место банальное переобучение.

Совместное использование Dropout и Batch Normalization

Т.к. простой Dropout у меня не завелся, то и совместное использование не дало результатов. В интернете сходятся к мнению, что совместно использовать нужно с осторожностью и важна последовательность применения.

Наиболее часто рекомендуют использовать Conv - BatchNorm - Activation - DropOut - (Pool or Conv).

Но встречается и вариант Conv - DropOut - BatchNorm - Activation - Pool [https://stackoverflow.com/questions/39691902/ordering-of-batch-normalization-and-dropout]

Скрипт обучения

Реализация основывается на разработке взятой из https://github.com/zhangrong1722/CheXNet-Pytorch. Это простой скрипт для бинарной классификации. Модуль нейросети был выкинут (densenet121 явно избыточен для этой задачи) и вставлен мой.

Из интересных моментов был подсмотрен focal loss (функция потерь для обучения на несбалансированных данных) и data sampler - https://github.com/ufoym/imbalanced-dataset-sampler предназначенный для автоматического выравнивания несбалансированности (чаще берет угнетенный класс). Я предпочел imbalanced-dataset-sampler.

Был исследован preprocessing.py - скрипт для аугментации и ничего особого интересного не почерпнул. Вместо неё следует брать что-то более мощное, например imgaug (https://imgaug.readthedocs.io/en/latest/). Для моей задачи были задействованы GaussianBlur, AverageBlur, MedianBlur, Resize, AdditiveGaussianNoise, MotionBlur, Sharpen, ElasticTransformation).

Эксперименты

Было проведено обучение с несколько отличающимися настройками сети.
Первый столбец - по сравнению с обычной сеткой число сверточных столбцов увеличено вдвое (full_conv_block=True).
Второй столбец - full_conv_block=True + extra_cat_lines=6 (число пробрасываемых каналов увеличено)
Третий столбец - компактная сеть

Таблица точности в зависимости от эпохи на проверочном наборе картинок:


Можно видеть, что более сложная сеть обучается чуть быстрее (красные клетки появляются уже на 12-й эпохе), но потом обучение замедляется.
Сеть, с дополнительными пробросами обучается ещё быстрее (красные клетки уже на 8-й эпохе!).
Маленькой сети обучатся сложнее, но процесс обучения более равномерный и финальный результат лучше.

Так же в графическом виде:


Видно, что процесс обучения (в данном случае) довольно сильно скачет. И это при том, что скорость обучения была выставлена небольшой (learning rate 5e-4; Adamax). Видно, что финальную точность на тестовой выборке удалось получить порядка 96.5%. При этом на обучающей выборке точность была почти 100% -- такая разница получилась, видимо, из-за того, что эти два набора создавались несколько по разному. В обучающем наборе широко применялась аугментация: для имитации blur были использованы цифровые фильтры), для получения резких картинок было произведено уменьшение в несколько раз итак резких фотографий, снятых на современные цифровые фотоаппараты. При составлении тестовой выборки цифровые фильтры я старался использовать как можно меньше.

С кодом можно ознакомится тут: https://github.com/Imageman/blur_detect_small

P.S. За бортом описания осталось довольно много мелких деталей, экспериментов. На все эксперименты ушло три с половиной недели (в свободное от работы время, и не каждый день). Эксперименты проводились как на чистом CPU, так и на простенькой видеокарте GTX 960 с 2 ГБ памяти (даже такая простая карта дает многократное ускорение).

python, программирование

Previous post Next post
Up