Модуль IV·Статья I·~3 мин чтения
Генеративно-состязательные сети (GAN)
Генеративные модели
Превратить статью в подкаст
Выберите голоса, формат и длину — AI запишет аудио
Генеративно-состязательные сети (GAN)
GAN (Goodfellow et al., 2014) — один из наиболее творческих архитектурных паттернов в истории машинного обучения. Ян Лекун назвал их «самой интересной идеей в машинном обучении за последние 20 лет». Они породили целую эпоху синтетических медиа: реалистичные лица, deepfakes, AI-художники.
Принцип GAN: игровая постановка
Два агента: Generator G: z → x̂ (генерирует синтетические данные из шума z). Discriminator D: x → [0,1] (отличает реальные данные от синтетических).
Минимаксная игра:
min_G max_D V(D,G) = E_{xp_data}[log D(x)] + E_{zp_z}[log(1 − D(G(z)))]
Расшифровка: D максимизирует вероятность правильной классификации реальных (D(x)→1) и фейковых (D(G(z))→0). G минимизирует вероятность, что D обнаружит подделку (D(G(z))→1).
Оптимальное D при фиксированном G: D*(x) = p_data(x)/(p_data(x) + p_G(x)). Подставляя: max_D V = −log(4) + 2·JSD(p_data || p_G). JSD — Jensen-Shannon дивергенция. В равновесии: p_G = p_data → D* = 1/2 → V = −log(4).
Теорема Nash-равновесия GAN: Единственная точка Nash-равновесия: G воспроизводит p_data, D не может отличить (D*(x) = 0.5 всюду).
Проблемы обучения GAN
Mode collapse: G «находит» несколько успешных режимов и генерирует только их. Вместо разнообразных лиц — несколько «безопасных» вариантов. Причина: G не штрафуется за отсутствие разнообразия напрямую.
Нестабильность: Тонкий баланс — если D слишком хорош, gradient для G ≈ 0 (D насыщается). Если D слишком плох — G не получает полезного сигнала. Обучение нестабильно: G и D «гоняются» друг за другом.
Vanishing gradients при насыщении D: log(1 − D(G(z))) → 0 при D(G(z)) → 0. Практическое решение: перефразировать как min_G −E[log D(G(z))] (non-saturating loss). Теперь градиент не исчезает даже при сильном D.
Улучшения: DCGAN, WGAN, StyleGAN
DCGAN (Radford et al., 2015): Рекомендации для стабильного обучения: заменить pooling на strided convolutions (в G — transposed, в D — strided). Batch Normalization везде (кроме выходного слоя G и входного D). LeakyReLU в D (α=0.2). tanh в G-выходе. Убрать полносвязные слои.
WGAN (Arjovsky et al., 2017): Заменить JS-дивергенцию на расстояние Вассерштейна (Earth Mover's Distance): W₁(p,q) = inf_{γ∈Π(p,q)} E_{(x,y)~γ}[||x−y||]. Теорема Канторовича-Рубинштейна: W₁(p,q) = sup_{||f||_L≤1} [E_p[f] − E_q[f]]. Критик f (не дискриминатор) должен быть 1-Lipschitz → gradient clipping (||W|| ≤ с) или gradient penalty. Обучение значительно стабильнее, осмысленная функция потерь.
StyleGAN (Karras et al., 2018/2019): «Стиль» вводится в G через Adaptive Instance Normalization (AdaIN) на каждом уровне сети: AdaIN(x, y) = y_s (x − μ(x))/σ(x) + y_b. y_s и y_b — масштаб и сдвиг, вычисленные из latent-кода w (через mapping network z → w → y). Разные уровни контролируют разные масштабы: грубые слои (поза, форма) → средние (черты) → тонкие (текстура, цвет). Thispersondoesnotexist.com — пример.
FID (Fréchet Inception Distance): Стандартная метрика качества GAN. Используем Inception-v3 как feature extractor. Реальные и синтетические изображения → feature vectors → аппроксимируем Гауссианами (μ_r, Σ_r) и (μ_g, Σ_g). FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). Меньше FID → лучше. StyleGAN2: FID=2.8 на FFHQ (vs 35 для первого GAN).
Численный пример
DCGAN для генерации MNIST (28×28, binary): Архитектура G: FC(100→256·7·7) → BN → Reshape(256,7,7) → ConvTranspose(128,4,2) → BN → ConvTranspose(1,4,2) → tanh. Архитектура D: Conv(64,4,2) → LeakyReLU → Conv(128,4,2) → BN → LeakyReLU → Flatten → FC → sigmoid.
Обучение 50 эпох (batch=64, Adam lr=0.0002): после 5 эпох — размытые цифры. После 20 — узнаваемые. После 50 — качественные образцы. FID(50 эпох) ≈ 12 (близко к реальным данным). Mode collapse не произошёл благодаря label smoothing (реальные → 0.9, не 1.0).
Задание: Реализуйте DCGAN для MNIST. (1) G: Linear(100)→Unflatten→ConvTranspose×3→Tanh. D: Conv×3→LeakyReLU→Flatten→Linear→Sigmoid. (2) Обучите 50 эпох. Визуализируйте прогресс каждые 5 эпох (grid 8×8). (3) Вычислите FID на 1000 generated samples. (4) Реализуйте interpolation в z-пространстве: линейно интерполируйте между z₁ и z₂ (8 шагов) → визуализируйте «плавный» переход между цифрами.
§ Акт · что дальше