Реализация модели детектора перефразировок с BERT в pytorch transformers

Mar 13, 2020 17:43

Продолжение темы изучения питорча: сделал модель бинарного классификатора для определения синонимичности двух фраз с использованием multilinual BERT: гист

В этой модельке оба входных предложения обрабатываются одним BERT'ом, затем средние векторы с последнего скрытого слоя объединяются и обрабатываются одним полносвязным слоем. Я пробовал добавлять на вход полносвязного слоя также скалярное произведение и разность BERT-векторов, но значимого улучшения это не дало.

Обучения модели с BERT идет достаточно тяжело, так как ограниченный размер памяти GPU вынуждает
сокращать размер минибатчей, это пропорционально увеличивает время каждой эпохи. Чтобы уложиться в разумное время обучения, я ограничил размер датасета 100,000 сэмплов из полного набора в 800,000. Это следует учитывать при сравнении результатов, так как остальные модели учились на кратно большем количестве примеров.

Результат для модели на базе BERT (pytorch, transformers): f1=0.9384  precision@1=0.9558  mrr=0.794

Тут precision@1 - доля тестовых случаев, когда правильная пара оказывается на первом посте после ранжирования для группы из 100 пар, и mrr - средний обратный ранг (mean reciprocal rank). F1 дает некоторую оценку качества бинарной классификации. Для чатбота в некоторых случаях информативнее mrr, в других - F1.

Другие модели

Исходник моей старой реализации похожей модели на Keras без finetuning'а BERT'а можно посмотреть тут.

Для сравнения, модель bi-siamese LSTM на pytorch дает после обучения на полном датасете f1=0.970 precision@1=0.977 mrr=0.863. Аналогичная модель на Keras, в которой слова представляются сочетанием word2vector-векторов и wordchar2vector-векторов, дает f1=0.738.

Модель бинарного классификатора на Keras, тренируемая с метрикой triple loss ( подробности тут), дает оценку precision@1=0.476 mrr=0.544

Модели на базе XGBRanker и LGBMRanker реализуют подход через ранжирование, в качестве фич используются символьные шинглы. Для них получаются метрики:
XGBRanker precision@1=0.992 mrr=0.995
LGBRanker precision@1=0.990 mrr=0.994
Исходный текст тут
Подробности тут

В чатботе используется модель бинарного классификатор на LightGBM с символьными шинглами в качестве фич.
Она выдает f1=0.984 precision@1=0.988 mrr=0.994. Ее преимущество, кроме достаточно высокого качества, заключается в легкости перехода, например, на работу с фонетической транскрипцией фраз вместо их каноничного текстового представления. Это позволяет обучаться на корпусе из транскрибированных фраз и затем более качественно обрабатывать фразы из ASR.

transformer, bert, перефразировки, синонимы, pytorch

Previous post Next post
Up