Модуль IV·Статья II·~3 мин чтения
Стохастическая оптимизация и современные методы
Выпуклая оптимизация для ML
Превратить статью в подкаст
Выберите голоса, формат и длину — AI запишет аудио
Стохастическая оптимизация в глубоком обучении
Современные нейронные сети обучаются на миллиардах примеров и имеют миллиарды параметров. Вычисление точного градиента невозможно — нужны стохастические методы. Понимание их теоретических свойств позволяет настраивать обучение и диагностировать проблемы.
Стохастический градиентный спуск: теория
Постановка: min_θ f(θ) = (1/n) Σᵢ fᵢ(θ). Стохастический градиент: gₜ = ∇fᵢₜ(θₜ), где iₜ выбирается случайно. Ключевые свойства: E[gₜ] = ∇f(θₜ) (несмещённость), Var[gₜ] = σ² (конечная дисперсия).
Оптимальные learning rate schedules:
- Убывающий: αₜ = α₀/√t → сходимость O(σ/√T) (конвексный случай)
- Постоянный: αₜ = α → сходимость к окрестности, но не к оптимуму
- Убывающий для SC: αₜ = 2/(μ(t+1)) → O(σ²/(μT)) (сильно выпуклый)
Mini-batch: gₜ = (1/|B|)Σᵢ∈Bₜ ∇fᵢ(θₜ). Дисперсия уменьшается: Var[gₜ] = σ²/|B|. Линейное ускорение до критического размера батча B_crit ≈ σ²/||∇f||² — дальше параллелизм помогает только по времени, не по итерациям.
Adam: теоретический анализ
Adam (Kingma & Ba, 2014) — де-факто стандарт для обучения нейронных сетей:
mₜ = β₁ mₜ₋₁ + (1−β₁) gₜ (сглаженное среднее градиента) vₜ = β₂ vₜ₋₁ + (1−β₂) gₜ² (сглаженное среднее квадрата) m̂ₜ = mₜ/(1−β₁ᵗ), v̂ₜ = vₜ/(1−β₂ᵗ) (коррекция смещения) θₜ₊₁ = θₜ − α · m̂ₜ/(√v̂ₜ + ε)
Расшифровка: m̂ₜ/√v̂ₜ ≈ sgn(gₜ) при стационарном режиме — Adam делает шаги фиксированного размера в направлении знака градиента, адаптируя lr по каждому параметру. Параметры с историей большого градиента получают меньший lr.
Теоретические проблемы: Reddi et al. (2018) построили пример, где Adam не сходится даже для выпуклых функций. Причина: v̂ₜ может «забыть» информацию о прошлых больших градиентах.
AMSGrad (Reddi, 2018): использует максимум v̂: v̂ₜᵐᵃˣ = max(v̂ₜ₋₁ᵐᵃˣ, v̂ₜ), обновляет θ через v̂ᵐᵃˣ. Гарантирована сходимость.
AdamW (Loshchilov & Hutter, 2019): Adam + правильная weight decay. Стандартный Adam применяет L2-регуляризацию к градиенту (через m̂/√v̂), что отличается от weight decay. AdamW: θ ← θ(1−αλ) − α·m̂/√v̂. Де-факто стандарт для трансформеров.
Variance Reduction: SVRG и SARAH
Проблема SGD: дисперсия σ² не стремится к нулю вблизи оптимума → oscillation, нельзя брать большой lr.
SVRG (Johnson & Zhang, 2013): периодически (каждые m шагов) вычисляем полный градиент ∇f(x̃). Уточнённый стохастический градиент:
gₜ = ∇fᵢ(xₜ) − ∇fᵢ(x̃) + ∇f(x̃)
Дисперсия → 0 при xₜ → x* (обе части сходятся к одному значению). Результат: линейная сходимость O(exp(−t)) для L-гладкой μ-SC задачи — как у детерминированного GD!
SARAH (Nguyen et al., 2017): рекурсивное variance reduction: gₜ = ∇fᵢ(xₜ) − ∇fᵢ(xₜ₋₁) + gₜ₋₁. Теоретически ещё лучше SVRG.
Практика: SVRG/SARAH эффективнее SGD для выпуклых задач с множеством сумм (logistic regression, SVM). Для нейронных сетей Adam с lr-scheduling практичнее (нелинейность нарушает теоретические гарантии).
Федеративное обучение
Мотивация: данные на устройствах пользователей нельзя централизовать (privacy). Хотим обучить глобальную модель без доступа к сырым данным.
FedAvg (McMahan et al., 2017):
- Сервер рассылает модель θ клиентам K
- Каждый клиент k обучает E эпох на локальных данных: θₖ ← θ − α·∇L_k(θ)
- Сервер усредняет: θ ← (1/K) Σₖ θₖ
Коммуникационная эффективность: E локальных эпох вместо 1 → уменьшаем число раундов в E раз.
Проблемы: Data heterogeneity (non-IID): если у клиентов разные распределения данных — FedAvg расходится. FedProx добавляет регуляризацию: клиент минимизирует L_k(θ) + μ/2||θ−θ_global||². Дифференциальная privacy: добавляем шум Гаусса к градиентам перед отправкой → (ε,δ)-DP гарантии.
Численный пример
Обучение BERT-base (110M параметров) на A100 GPU (FP16):
- Batch size = 256, lr = 2e-4, warmup = 10000 шагов
- Adam: β₁=0.9, β₂=0.999, ε=1e-8, weight decay=0.01
- После 1M шагов (~10 дней на 8×A100): val perplexity = 3.8
При увеличении batch size до 2048 (linear scaling rule: lr = 8·2e-4 = 1.6e-3): те же результаты за меньшее число итераций (в 8×), ускорение ≈4× (не 8× из-за коммуникации).
Задание: Сравните SGD, Adam, SVRG на MNIST с логистической регрессией (10 классов, L2 λ=0.001). Для SGD и Adam: найдите оптимальный lr через grid search. Для SVRG: m=n/10 (частота полного градиента). Постройте: val accuracy vs число gradient evaluations (честное сравнение). Реализуйте FedAvg для MNIST: 10 клиентов с non-IID data (каждый видит только 2 класса). Как деградирует качество по сравнению с centralized обучением?
§ Акт · что дальше