Продолжение
темы изучения питорча: сделал модель бинарного классификатора для определения синонимичности двух фраз с использованием multilinual BERT:
гистВ этой модельке оба входных предложения обрабатываются одним BERT'ом, затем средние векторы с последнего скрытого слоя объединяются и обрабатываются одним полносвязным слоем. Я пробовал добавлять на вход полносвязного слоя также скалярное произведение и разность BERT-векторов, но значимого улучшения это не дало.
Обучения модели с BERT идет достаточно тяжело, так как ограниченный размер памяти GPU вынуждает
сокращать размер минибатчей, это пропорционально увеличивает время каждой эпохи. Чтобы уложиться в разумное время обучения, я ограничил размер датасета 100,000 сэмплов из полного набора в 800,000. Это следует учитывать при сравнении результатов, так как остальные модели учились на кратно большем количестве примеров.
Результат для модели на базе BERT (pytorch, transformers): f1=0.9384 precision@1=0.9558 mrr=0.794
Тут precision@1 - доля тестовых случаев, когда правильная пара оказывается на первом посте после ранжирования для группы из 100 пар, и mrr - средний обратный ранг (mean reciprocal rank). F1 дает некоторую оценку качества бинарной классификации. Для чатбота в некоторых случаях информативнее mrr, в других - F1.
Другие модели
Исходник моей старой реализации похожей модели на Keras без finetuning'а BERT'а можно
посмотреть тут.
Для сравнения,
модель bi-siamese LSTM на pytorch дает после обучения на полном датасете f1=0.970 precision@1=0.977 mrr=0.863. Аналогичная
модель на Keras, в которой слова представляются сочетанием word2vector-векторов и wordchar2vector-векторов, дает f1=0.738.
Модель бинарного классификатора на Keras, тренируемая с метрикой triple loss (
подробности тут), дает оценку precision@1=0.476 mrr=0.544
Модели на базе
XGBRanker и
LGBMRanker реализуют подход через ранжирование, в качестве фич используются символьные шинглы. Для них получаются метрики:
XGBRanker precision@1=0.992 mrr=0.995
LGBRanker precision@1=0.990 mrr=0.994
Исходный текст
тутПодробности
тут В
чатботе используется
модель бинарного классификатор на LightGBM с символьными шинглами в качестве фич.
Она выдает f1=0.984 precision@1=0.988 mrr=0.994. Ее преимущество, кроме достаточно высокого качества, заключается в легкости перехода, например, на работу с фонетической транскрипцией фраз вместо их каноничного текстового представления. Это позволяет обучаться на корпусе из транскрибированных фраз и затем более качественно обрабатывать фразы из ASR.