1. Проблема Когда мы обучаем модели машинного обучения, почти всегда возникает один и тот же вопрос:Что именно происходит во время обучения? Обычно мы смотрим н1. Проблема Когда мы обучаем модели машинного обучения, почти всегда возникает один и тот же вопрос:Что именно происходит во время обучения? Обычно мы смотрим н

Мета-модель для диагностики обучения нейросетей

2026/03/15 23:15
4м. чтение
Для обратной связи или замечаний по поводу данного контента, свяжитесь с нами по адресу [email protected]

1. Проблема

Когда мы обучаем модели машинного обучения, почти всегда возникает один и тот же вопрос:

Обычно мы смотрим на графики метрик и пытаемся вручную интерпретировать происходящее:

  • Модель недообучена

  • Модель переобучена

  • Имбаланс датасета.

  • Сильно шумные данные.

Можно посмотреть на learning curves и понять, что происходит:

График с рутиной ML)
График с рутиной ML)

Но этот анализ почти всегда выполняется вручную или с помощью простейших эвристических правил. А ведь сколько времени, сил и нервов можно было бы сэкономить, если обучить до 100 эпохи а не до 500 (см картинка выше) :-(

Но можно задать интересный вопрос:

А можно ли автоматически определить состояние обучения модели?

2. Идея

А что если научить отдельную модель, которая будет автоматически определять состояние обучения?

То есть вместо ручного анализа мы обучаем модель, которая делает это автоматически. Но насколько это эффективно, на проде, интерпретируемы ли эти результаты для разных типов задач и т. д.

3. Генерация датасета

Чтобы обучить такой классификатор, нужен датасет с различными сценариями обучения.

Я решил сгенерировать его программно.

В качестве базового датасета использовался MNIST — классический набор изображений рукописных цифр.

Эксперименты проводились с несколькими типами моделей:

  • logistic regression

  • небольшой MLP

  • большой MLP

  • маленькая CNN

  • большая CNN

Для каждого эксперимента варьировались параметры:

  • размер обучающей выборки

  • случайный seed

  • наличие дисбаланса классов

  • тип сдвига данных

По итогу я обучил 270 моделей и посмотрел их после 1, 5, 6, 11,16,21,26 эпох. По каждой записи были сохранены:

Столбец

Тип

Описание

model

str

Название модели, использованной для обучения (logreg, mlp_small, mlp_large, cnn_small, cnn_large).

train_size

int

Размер выборки для обучения в конкретном эксперименте.

seed

int

Значение random seed для воспроизводимости случайной выборки.

imbalance

bool

Флаг, указывающий, использовался ли искусственный дисбаланс классов (True) или нет (False).

shift_type

str

Тип сдвига данных на тестовой выборке (none, noise, invert).

train_acc

float

Точность модели на тренировочной выборке после текущей эпохи.

val_acc

float

Точность модели на валидационной выборке после текущей эпохи.

test_acc

float

Точность модели на тестовой выборке (с учетом возможного сдвига данных).

gap

float

Разница между тренировочной и валидационной точностью (train_acc - val_acc). Используется для диагностики переобучения.

epochs

int

Количество эпох обучения (для функции train_and_evaluate) — либо номер эпохи в train_with_history.

val_curve

list of list

История точности на валидационной выборке по эпохам до текущей.

epoch

int

Номер текущей эпохи обучения (используется при пошаговом train_with_history).

underfitting

int (0/1)

Диагностический флаг: модель недообучена, если train_acc < 0.7.

overfitting

int (0/1)

Диагностический флаг: модель переобучена, если gap > 0.15.

dataset_shift

int (0/1)

Диагностический флаг: есть смещение тестовых данных, если val_acc - test_acc > 0.15.

С мериками получилось сложно, нельзя точно сказать, что при val_acс 0.9 нет переобучения, однако, в рамках работы я просто тестил всё на test_dataset и ставил метки по нему. правила для меток:

def diagnose(metrics): return { "underfitting": int(metrics["train_acc"] < 0.7), "overfitting": int(metrics["gap"] > 0.15), "dataset_shift": int(metrics["val_acc"] - metrics["test_acc"] > 0.15) }

В итоге в датасете я получил:

Кол-во меток в  датасете
Кол-во меток в датасете

Касаемо качества датасета, меня устаивает, есть как и ужасные модели, так и неплохие, acc достиг 0.9.

dc4719b7e5f95345f2fa687b39b9b032.png

4. Признаки для мета-классификатора

Одним из самых интересных источников информации является форма learning curve. Я вытащил из него много признаков, все признаки на которых я делал метрики (подразумеваются как недоступные я удалил из обучения)

df["curve_start"] = df["val_curve"].apply(lambda x: x[0]) df["curve_mid"] = df["val_curve"].apply(lambda x: x[len(x)//2]) df["curve_end"] = df["val_curve"].apply(lambda x: x[-1]) df["curve_growth"] = df["curve_end"] - df["curve_start"] df["curve_stability"] = df["val_curve"].apply(np.std)

5. Обучение моделей

Для классификации были протестированы несколько алгоритмов:

  • Random Forest

  • XGBoost

  • Logistic Regression

  • ансамбль моделей

Поскольку задача имеет несколько независимых меток, использовался MultiOutputClassifier.

rf = RandomForestClassifier( n_estimators=200, random_state=42 ) model = MultiOutputClassifier(rf) model.fit(X_train, y_train) pred = model.predict(X_test)

Итоги после обучения:

precision recall f1-score support 0 0.94 0.89 0.91 177 1 0.96 0.97 0.96 593 2 0.97 0.88 0.92 233 3 0.75 0.73 0.74 419 micro avg 0.90 0.87 0.89 1422 macro avg 0.90 0.87 0.88 1422 weighted avg 0.90 0.87 0.88 1422 samples avg 0.86 0.84 0.83 1422Важность признаков в случайном лесе.

Важность признаков в случайном лесе.

Лучшие результаты показал Random Forest.

Он хорошо определял:

  • underfitting

  • dataset shift

Логистическая регрессия показала более низкое качество — что ожидаемо, так как она является линейным классификатором. Ансамбль моделей практически не улучшил результат.

6. Результаты

Этот подход можно использовать в ML-pipeline.

Это может позволить:

  • автоматически выявлять переобучение

  • обнаруживать проблемы с данными

  • останавливать обучение раньше

  • экономить вычислительные ресурсы

На этом всё, если кому-то интересно потрогать руками напишите в комментах, на гитхабе надо убраться, буду рад критике, есть что добавить и так, планирую дописать 2 часть. Вот мой гитхаб, в целом там иногда есть что-то интересное.

Спасибо, всем хорошего дня.

Источник

Возможности рынка
Логотип Ucan fix life in1day
Ucan fix life in1day Курс (1)
$0.0003693
$0.0003693$0.0003693
+0.84%
USD
График цены Ucan fix life in1day (1) в реальном времени
Отказ от ответственности: Статьи, размещенные на этом веб-сайте, взяты из общедоступных источников и предоставляются исключительно в информационных целях. Они не обязательно отражают точку зрения MEXC. Все права принадлежат первоисточникам. Если вы считаете, что какой-либо контент нарушает права третьих лиц, пожалуйста, обратитесь по адресу [email protected] для его удаления. MEXC не дает никаких гарантий в отношении точности, полноты или своевременности контента и не несет ответственности за любые действия, предпринятые на основе предоставленной информации. Контент не является финансовой, юридической или иной профессиональной консультацией и не должен рассматриваться как рекомендация или одобрение со стороны MEXC.