type
Post
status
Published
date
Mar 26, 2026
slug
Knowledge_1
summary
tags
数据结构
LLMs
category
数据结构与算法
icon
password
1.常见的激活函数
1. GELU (Gaussian Error Linear Unit)
数学定义:
GELU(x) = x Φ(x),其中 Φ(x) 是标准正态分布的累积分布函数。
在工程中常使用 Tanh 或代数近似来加速计算。
优点:
- 平滑性:在所有点上处处可导,且在负数区域是非单调的,这使得它比 ReLU 能提供更丰富的梯度信息,有助于更深层网络的优化。
- 融合了随机正则化:它将 Dropout、Zoneout 和 ReLU 的思想结合在一起(可以理解为对输入 x 乘以一个服从伯努利分布的 0 或 1,而这个概率由 x 的高斯分布决定)。
缺点:
- 计算复杂度相对较高(涉及误差函数 erf 或指数计算),在推理时会有一定的性能开销。
使用场景:
- BERT、GPT-2、GPT-3、RoBERTa 等经典大模型
- ViT 的 FFN(前馈神经网络)层
2. Swish / SiLU (Sigmoid Linear Unit)
数学定义:
Swish(x) = x · σ(βx),当 β = 1 时称为 SiLU。
优点:
- 无上界有下界:避免梯度爆炸,同时下界引入微小正则化效果。
- 平滑且非单调:在 x < 0 的某一区域内输出为负值且存在极小值,增强模型表达能力。
缺点:
- 计算成本高于 ReLU(需要计算 Sigmoid)。
使用场景:
- 现代大模型中广泛使用
- GLU(门控线性单元)的变体基础
3. SwiGLU (Swish Gated Linear Unit) —— 当前大模型主流
数学定义:
SwiGLU(x, W, V) = Swish(xW) ⊗ (xV)
(本质是一种门控结构,而非单一激活函数)
优点:
- 性能上限高:在 PaLM、LLaMA 等论文中证明,相同参数量下可显著降低困惑度并提升性能。
- 动态信息流:门控机制可控制信息流动。
缺点:
- 参数量增加:通常将隐藏层从 4d 调整为 8/3 d 以保持参数量不变。
- 计算量高于 GELU。
使用场景:
- LLaMA 系列(1/2/3)
- PaLM、Qwen、Mistral 等主流大模型
4. ReLU (Rectified Linear Unit)
数学定义:
ReLU(x) = max(0, x)
优点:
- 计算高效:仅需阈值比较
- 缓解梯度消失:正区间梯度恒为 1
- 引入稀疏性:部分神经元输出为 0
缺点:
- Dying ReLU 问题:如果学习率设置过大,导致大梯度更新后权重变为负数,输入恒小于 0,则该神经元的梯度永远为 0,无法再更新。
- 非零中心化:可能导致优化震荡(Zig-Zag)
使用场景:
- CNN(如 ResNet)
- 传统 MLP
- 边缘设备模型
5. Leaky ReLU / PReLU
数学定义:
- x ≥ 0:f(x) = x
- x < 0:f(x) = αx (Leaky ReLU 中 α 为常数,PReLU 中 α 可学习)
优点:
- 缓解 Dying ReLU 问题
缺点:
- 性能提升不稳定
- PReLU 增加参数量
使用场景:
- GAN 判别器
- ReLU 失效时的替代方案
6. Sigmoid
数学定义:
σ(x) = 1 / (1 + e^{-x})
优点:
- 输出范围 (0,1),具有概率意义
缺点:
- 梯度消失严重(两端饱和)
- 计算慢(指数运算)
- 非零中心化
使用场景:
- 二分类输出层
- LSTM / GRU 门控机制
7. Tanh (Hyperbolic Tangent)
数学定义:
Tanh(x) = (e^x - e^{-x}) / (e^x + e^{-x})
优点:
- 输出范围 (-1,1)
- 零中心化,优化更稳定
缺点:
- 仍存在梯度消失问题
- 包含指数计算
使用场景:
- RNN / LSTM 隐藏状态
- 早期生成模型
二、 常见的优化器
在预训练和 SFT(监督微调)阶段,由于Transformer对超参数极其敏感,且参数量巨大,我们对优化器的选择主要在收敛稳定性和显存开销之间权衡。
1. AdamW (Adam with Decoupled Weight Decay) —— 目前 LLM 的绝对标配
- 原理:Adam 的改进版。传统 Adam 将 L2 正则化等效为权重衰减,这在自适应学习率下是错误的。AdamW 将权重衰减(Weight Decay)与梯度更新解耦,直接作用于权重本身。
- 优点:
- 泛化能力更强:解耦权重衰减后,能更有效地限制权重膨胀,提升模型在复杂任务(如语言建模)上的泛化性能。
- 收敛极快且稳定:结合了一阶动量(平滑梯度)和二阶动量(自适应学习率),非常适合 Transformer 这种复杂的非凸优化面。
- 缺点:
- 极其消耗显存:对于每个参数,除了梯度外,还需要额外保存一阶动量 $m$ (FP32, 4 bytes) 和二阶动量 $v$ (FP32, 4 bytes)。这意味着在混合精度训练中,优化器状态需要占用 $8 \times$ 模型参数量的显存。
- 使用场景:GPT 系列、LLaMA 系列、Qwen 等几乎所有主流大模型的预训练与微调。
2. Adafactor
- 原理:为了解决 Adam 显存占用过大的问题,Adafactor 对二阶动量矩阵进行了低秩分解(近似),并且默认去除了一阶动量。
- 优点:
- 显存极其友好:将二阶动量的空间复杂度从 $O(n \times m)$ 降到了 $O(n + m)$,大大减少了 Optimizer State 的显存占用。
- 缺点:
- 收敛速度/稳定性妥协:由于丢弃了一阶动量并做了矩阵近似,在某些任务上收敛速度比 AdamW 慢,且对学习率调度器(LR Scheduler)的要求更为苛刻。
- 使用场景:Google 的 T5、PaLM 等超大规模模型的预训练(当算力集群显存受限,且不用 ZeRO-3 时的高效平替)。
3. 8-bit Adam / Paged AdamW (结合量化与系统工程的优化器)
- 原理:并非全新的数学算法,而是工程上的极致优化。由 bitsandbytes 库提出,将 Adam 的一阶和二阶动量从 32-bit 浮点数量化到 8-bit(非线性量化)。Paged 版本则结合了统一内存(Unified Memory)机制。
- 优点:
- 打破显存瓶颈:将优化器状态显存占用降低了约 75%,使得单卡(如 24G RTX3090/4090)微调较大模型(如 7B/13B)成为可能。
- Paged 版本能在显存不足时将优化器状态 Offload 到 CPU 内存,防止 OOM。
- 缺点:在量化和反量化过程中会引入极小规模的额外计算开销(通常可忽略);极端情况下可能带来微小的精度损失。
- 使用场景:QLoRA 等参数高效微调(PEFT)场景的主力优化器。
二、 经典基础优化器(深度学习通用)
1. SGD (Stochastic Gradient Descent) 及其动量变体 (Momentum)
- 原理:最基础的梯度下降,每次使用一个 Batch 的数据计算梯度。结合 Momentum 后,会累积之前的梯度方向(模拟惯性)。
- 优点:
- 显存占用极小(标准 SGD 几乎无额外状态,带动量版只需存一个动量)。
- 理论上在长序列训练后,SGD 往往能找到更平缓的局部最优点,泛化性能上限高。
- 缺点:
- 所有参数共享同一个全局学习率,对稀疏数据极不友好。
- 遇到鞍点(Saddle Point)或病态曲率(Hessian 矩阵条件数大)时,收敛极慢。
- 使用场景:经典的图像分类 CNN 网络(如 ResNet 在 ImageNet 上的训练)。
2. AdaGrad
- 原理:为每个参数分配独立的学习率。通过累积历史梯度的平方和,作为学习率的分母。
- 优点:适合处理稀疏特征(如早期的 NLP 词向量任务),高频特征更新步长小,低频特征更新步长大。
- 缺点:学习率单调递减,随着训练推进,分母越来越大,学习率趋近于 0,导致网络提前停止学习。
- 使用场景:早期的稀疏数据推荐系统,目前已极少在深层网络中使用。
3. RMSProp
- 原理:为了解决 AdaGrad 学习率递减的问题,采用指数移动平均(EMA)来计算历史梯度平方,而不是简单累加。
- 优点:缓解了学习率下降过快的问题,非常适合处理非平稳目标(RNN 等存在梯度剧烈波动的网络)。
- 缺点:依然缺乏一阶动量的加持,在复杂面上的寻优速度不及 Adam。
- 使用场景:在 Transformer 出现前的 RNN/LSTM 时代是标配,目前也是部分强化学习算法(如 A2C)的默认选择。
解析 RoPE (Rotary Position Embedding 旋转位置编码)
RoPE(由苏剑林等人在 RoFormer 中提出)是目前大模型(如 LLaMA, Qwen, ChatGLM, PaLM 等)最为主流的位置编码方案。
1. 核心思想与数学直觉
在 RoPE 之前,绝对位置编码(如 Sinusoidal、可学习位置编码)无法很好地表达 Token 之间的相对距离;而传统的相对位置编码(如 T5, ALiBi)计算复杂度较高或需要在 Attention 矩阵上做加法,破坏了 FlashAttention 等底层算子的优化。
RoPE 的极其优雅之处在于:通过绝对位置的旋转操作,自然地实现了相对位置的表达。
假设有查询向量 $q$ 在位置 $m$,键向量 $k$ 在位置 $n$:
- 绝对位置编码:RoPE 将位置 $m$ 映射为一个旋转矩阵 $R_m$(在复平面上相当于乘以 $e^{im\theta}$),然后与向量 $q$ 相乘,即 $q_m = R_m q$。
- 相对位置浮现:在计算 Attention Score 时,需要求内积。神奇的数学性质出现了: $$q_m^T k_n = (R_m q)^T (R_n k) = q^T R_m^T R_n k = q^T R_{n-m} k$$ 这意味着,两个经过 RoPE 处理的向量的内积,仅仅依赖于它们之间的相对距离 $(n-m)$,而不依赖于它们的绝对位置。
2. RoPE 的优点
- 融合了绝对与相对优势:既保留了绝对位置的顺序信息,又在 Attention 计算时精确体现了相对距离。
- 远程衰减特性:随着相对距离 $(n-m)$ 的增加,RoPE 控制的内积期望值会自然震荡并呈衰减趋势,这符合自然语言中“距离越远,关联度通常越弱”的直觉。
- 极其强大的长度外推潜力(Length Extrapolation):这是 RoPE 在大模型时代封神的根本原因。当我们需要将模型从 4K 上下文扩展到 32K 甚至 128K 时,基于 RoPE 衍生出了一系列免训练或少训练的扩展算法,如 位置插值(Linear Scaling)、NTK-aware Scaling、YaRN 等。通过修改旋转频率 $\theta$ 的 base 值,可以极为平滑地扩展模型的上下文窗口。
3. RoPE 的缺点(工程实现代价)
- 计算量相比于不加位置编码或标量相加(如 ALiBi)略大,需要在 Q、K 乘积前进行复数乘法或分块旋转计算。不过在现代 GPU 上,通过编写融合算子(Fused Rotary Embedding Kernel),这部分开销已被大幅抹平。
常见的 Normalization 方式及其优缺点
在深度学习中,Normalization 主要是为了缓解内部协变量偏移(Internal Covariate Shift),平滑 Loss 空间,防止梯度消失或爆炸。
1. BatchNorm (BN - 批归一化)
- 原理:在 Batch 维度上对每一个特征通道(Channel)求均值和方差进行归一化。
- 优点:
- 在 CNN(如 ResNet)中表现极佳,能加速收敛,并带有一定的正则化效果(类似于 Dropout)。
- 缺点:
- 严重依赖 Batch Size:当 Batch Size 太小(如显存受限时)或分布极不稳定时,统计的均值方差偏差极大,导致模型崩溃。
- 极不适合 NLP / 变长序列:由于同一个 Batch 内句子的长度不同,Padding 部分会严重干扰统计量,且无法有效处理测试时出现比训练集更长序列的情况。大模型中已完全淘汰。
2. LayerNorm (LN - 层归一化)
- 原理:在 Hidden 维度(即每个 Token 的特征向量维度)上独立求均值和方差进行归一化。
- 优点:
- 摆脱 Batch Size 依赖:每次归一化只针对当前样本的单个 Token,非常适合动态 Batch Size 和变长文本序列。
- Transformer 的经典标配:原版 Attention is All you Need 就采用这种方式。
- 缺点:
- 需要同时计算均值(Mean)和方差(Variance),其中减去均值的操作(Mean-Centering)在分布式计算中需要一定的同步开销,计算复杂度相对较高。
3. RMSNorm (Root Mean Square Normalization - 均方根归一化) —— 当前 LLM 的绝对主流
- 原理:LayerNorm 的“极简降本版”。作者发现,LayerNorm 中“减去均值将数据中心化”这一步其实对模型训练的增益微乎其微,真正起作用的是“除以方差进行缩放”。因此,RMSNorm 直接舍弃了求均值操作,只用均方根(RMS)进行归一化: $$RMS(a) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} a_i^2 + \epsilon} \quad \Rightarrow \quad \bar{a} = \frac{a}{RMS(a)} \odot g$$
- 优点:
- 计算极其高效:省去了均值的计算,减少了访存和计算量,整体速度比 LayerNorm 快 10% ~ 40%(具体取决于硬件和 Kernel 实现)。
- 性能不降反升:在极大规模语料的预训练中,其实际效果与 LayerNorm 几乎无异,甚至在部分场景下收敛更为稳定。
- 缺点:理论上打破了平移不变性(Shift Invariance),但在大语言模型的高维稀疏空间中,这并未成为实际的负面影响。
- 使用场景:LLaMA 1/2/3、Qwen、Gemma、Mistral 等几乎所有现代主流开源大模型。
GRPO
GRPO(组相对策略优化) 是由 DeepSeek 团队(在 DeepSeekMath 中首次提出,并在轰动业界的 DeepSeek-V3 和 DeepSeek-R1 中大放异彩)提出的一种极其优雅且高效的强化学习对齐算法。
1. 为什么需要 GRPO?(解决 PPO 的痛点)
在传统的 PPO(Proximal Policy Optimization)训练中,为了稳定更新,我们需要计算优势函数(Advantage)。这通常需要一个 Critic 模型(价值模型) 来预估当前状态的基线价值(Baseline)。
在 LLM 时代,这意味着我们在显存中需要同时驻留 4个大模型:
- Actor (策略模型,需更新)
- Critic (价值模型,需更新,通常与Actor同等参数量)
- Reference (参考模型,冻结)
- Reward (奖励模型,冻结)
痛点:Critic 模型极其消耗显存,导致在有限算力下,很难对超大参数模型(如 70B 甚至千亿参数)进行完整的 PPO 训练。
2. GRPO 的核心机制("杀掉" Critic)
GRPO 的核心思想是:用组内相对评分,替代绝对的 Critic 基线预估。
- 采样生成(Rollout):对于同一个 Prompt $q$,Actor 模型生成一组(Group)共 $G$ 个不同的回答(例如 $G=4$ 或 $8$)。
- 奖励打分:使用 Reward 模型(或在数学/代码任务中直接使用规则校验器 Rule-based Verifier)对这 $G$ 个回答打分,得到 $\{r_1, r_2, ..., r_G\}$。
- 组内标准化(计算 Advantage):直接计算这 $G$ 个分数的均值 $\mu$ 和标准差 $\sigma$。每个回答的优势函数即为: $$A_i = \frac{r_i - \mu}{\sigma}$$
- 策略更新:将上述计算出的 $A_i$ 代入类似 PPO 的裁剪损失函数(Clipping Loss)中进行梯度回传,同时加入 KL 散度惩罚以防止偏离 Reference 模型过远。
3. GRPO 的优缺点
- 优点:
- 极简且省显存:彻底砍掉了庞大的 Critic 模型,节省了整整一个 LLM 的显存开销(以及其对应的优化器状态),极大地降低了 RLHF 的硬件门槛。
- 天然的 Prompt 级别基线:同一组生成的回答面对的是同一个 Prompt 的难度,组内归一化天然消除了不同 Prompt 之间难度不同带来的奖励方差问题,训练非常稳定。
- 极度契合逻辑推理任务(R1 的成功秘诀):配合基于规则的验证器(代码编译结果、数学答案比对),完美实现了零额外显存成本的超长思维链(CoT)强化学习强化。
- 缺点:
- 生成开销大:在 Rollout 阶段,每个 Prompt 需要生成 $G$ 个回答,这对推理系统的吞吐量(尤其是长文本生成)提出了极高要求(需要极致的 vLLM / 融合算子支持)。
重要性采样
一、 数学直觉:什么是重要性采样?
重要性采样的核心目的,是用一个分布(行为分布 $q$)中采样得到的数据,去估计另一个分布(目标分布 $p$)的期望。
假设我们想计算函数 $f(x)$ 在分布 $p(x)$ 下的期望 $\mathbb{E}{x \sim p}[f(x)]$,但从 $p(x)$ 中采样很难,或者我们手里只有从另一个分布 $q(x)$ 采样的历史数据。我们可以做如下恒等变换:
$$ \mathbb{E}{x \sim p}[f(x)] = \int p(x)f(x)dx = \int q(x) \frac{p(x)}{q(x)} f(x)dx = \mathbb{E}_{x \sim q} \left[ \frac{p(x)}{q(x)} f(x) \right] $$
这里,$\rho(x) = \frac{p(x)}{q(x)}$ 就是“重要性权重(Importance Weight)”。它在修正概率分布的差异:如果某个样本在 $p$ 中概率大但在 $q$ 中概率小,它的权重就会被放大;反之亦然。
二、 在强化学习中的绝妙推导:消去环境动态概率
在强化学习中,我们的目标是最大化策略 $\pi_\theta$ 生成的轨迹 $\tau = (s_0, a_0, s_1, a_1...)$ 的期望总奖励 $R(\tau)$。
如果在每一次更新策略(变成 $\pi_{new}$)后,都要重新去环境中交互采样,那是极其昂贵的(On-policy 的痛点,比如早期的 REINFORCE 算法)。为了提高样本利用率(Sample Efficiency),我们希望复用旧策略 $\pi_{old}$ 收集的数据来更新当前的 $\pi_\theta$。
这就用到了重要性采样。我们将目标函数改写为:
$$ J(\theta) = \mathbb{E}{\tau \sim \pi{old}} \left[ \frac{P(\tau | \pi_\theta)}{P(\tau | \pi_{old})} R(\tau) \right] $$
这里最美妙的数学抵消出现了!一条轨迹的概率 $P(\tau | \pi)$ 是由“策略执行动作的概率”和“环境状态转移的概率(Environment Dynamics)”共同决定的:
$$ P(\tau | \pi) = P(s_0) \prod_{t=0}^T \pi(a_t|s_t) P(s_{t+1}|s_t, a_t) $$
当我们计算重要性权重 $\frac{P(\tau | \pi_\theta)}{P(\tau | \pi_{old})}$ 时,与环境相关的转移概率 $P(s_{t+1}|s_t, a_t)$ 和初始状态概率 $P(s_0)$ 刚好在分子分母中被完美约掉了!
最终只剩下动作概率的比值:
$$ \rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{old}(a_t|s_t)} $$
核心意义:这意味着我们不需要知道环境的模型(Model-free),仅仅依靠比较新旧策略对历史动作的概率输出比值,就能对策略进行梯度更新。
三、 致命的缺陷:方差爆炸 (Variance Explosion)
理论很完美,但现实很骨感。重要性采样在实际工程中有一个致命弱点:极高的方差。
- 连乘效应:在完整的轨迹推导中,权重是 $\prod_{t} \frac{\pi_\theta}{\pi_{old}}$。如果轨迹很长,且 $\pi_\theta$ 与 $\pi_{old}$ 差异变大,这个连乘积会迅速趋于 $0$ 或 $\infty$。
- 训练崩溃:在某一刻,如果重要性权重 $\rho$ 变得极大,就会产生一个极大的梯度,直接把神经网络更新“飞”了,导致模型参数崩塌(Loss Spike)。
四、 大模型算法工程师视角的破局:PPO 的裁剪机制
为了解决方差爆炸,大模型时代 RLHF 的标准算法 PPO(Proximal Policy Optimization) 对重要性采样做了一个天才般的工程改造。
PPO 的核心不是抛弃重要性采样,而是约束重要性权重,不让新旧策略偏离太远。这就是著名的 Clipping(裁剪)目标函数:
$$ L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta)\hat{A}t, \, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}t \right) \right] $$
其中,$r_t(\theta) = \frac{\pi\theta(a_t|s_t)}{\pi{old}(a_t|s_t)}$ 就是重要性权重,$\hat{A}_t$ 是优势函数(Advantage),$\epsilon$ 通常设为 $0.1$ 或 $0.2$。
工程师视角的原理解析:
- 复用数据:大模型生成(Rollout)回答非常消耗显存和时间。通过重要性采样 $r_t(\theta)$,我们收集到一批回答后,可以进行多次(比如 4个 Epoch)的梯度下降。
- 安全更新:
- 如果某个动作表现好($\hat{A}_t > 0$),我们会增大 $r_t(\theta)$。但如果它大过 $1+\epsilon$,clip 函数就会将其截断,防止单次更新迈的步子过大。
- 如果表现差($\hat{A}_t < 0$),我们会减小 $r_t(\theta)$。如果小到低于 $1-\epsilon$,同样被截断。
- 信任域(Trust Region)思想:这确保了 $\pi_\theta$ 在更新时,始终保持在 $\pi_{old}$ 的附近,从根本上解决了重要性采样的方差爆炸问题。
常见的损失函数
按照任务类型,我们可以将常见损失函数分为三大类:
1. 分类任务 (Classification)
- Cross-Entropy Loss (交叉熵损失 CE):
- 原理:衡量模型预测的概率分布与真实标签分布(通常是 One-hot 编码)之间的差异。数学本质是最小化两个分布的 KL 散度。
- 优点:梯度计算稳定,配合 Softmax 能有效拉大不同类别之间的概率差距。
- 场景:几乎所有的多分类任务,包括语言模型的词表预测。
- Focal Loss:
- 原理:在交叉熵的基础上增加了权重因子 $(1 - p_t)^\gamma$。
- 优点:专门解决类别极度不平衡的问题,让模型自动降低“容易分类样本”的权重,将注意力集中在“困难样本”上。
- 场景:目标检测(如 RetinaNet)、长尾分布的 NLP 分类任务。
2. 回归与排序任务 (Regression & Ranking)
- MSE (均方误差 / L2 Loss) & MAE (平均绝对误差 / L1 Loss):
- 原理:衡量预测连续值与真实值之间的距离。MSE 对异常值敏感(惩罚大),MAE 对异常值鲁棒但梯度恒定。
- 场景:传统的连续值预测。在 LLM 领域,**Reward Model(奖励模型)的打分头(Value Head)**初始化时有时会参考这类连续值回归的思想。
- Bradley-Terry (BT) Loss / Ranking Loss:
- 原理:在大模型的 RLHF (PPO/DPO) 阶段,奖励模型并不关注绝对分数,而是关注“偏好对(Chosen vs Rejected)”的相对排序。
- 公式:$-\log(\sigma(r_{chosen} - r_{rejected}))$
- 场景:专门用于训练大模型的 Reward Model 或者是直接用于 DPO 算法中。
3. 表征与对比学习 (Representation & Contrastive)
- InfoNCE Loss (对比损失):
- 原理:拉近正样本对(如同一句话的不同视角)在潜空间的距离,推开负样本(Batch 内的其他句子)的距离。
- 场景:训练向量数据库依赖的 Embedding 模型(如 BGE、SimCSE),以及多模态对齐(如 CLIP)。
大模型 SFT(监督微调)中使用的损失函数
在 LLM 的 SFT 阶段,最核心的损失函数依然是 Cross-Entropy Loss (交叉熵损失),但它在具体的任务构造和工程实现上,有着非常严格且巧妙的设计。我们称之为 Causal Language Modeling (CLM) Loss 或 Next-Token Prediction Loss。
1. 数学表达
对于一段长度为 $N$ 的文本序列 $X = (x_1, x_2, ..., x_N)$,大模型是自回归(Auto-regressive)的,即通过前 $i-1$ 个 Token 预测第 $i$ 个 Token。其标准损失函数为:
$$ \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log P(x_i | x_{1}, ..., x_{i-1}; \Theta) $$
模型输出 logits 后,经过 Softmax 转化为概率 $P$,然后与真实的下一个词的 One-hot 标签计算交叉熵。
2. SFT 阶段的致命细节:Loss Masking (掩码机制) —— 面试核心考点
虽然数学公式很简单,但在 SFT 工程实践中(如使用 PyTorch/HuggingFace),绝对不能对整段文本(Prompt + Response)计算 Loss。
- 为什么要 Mask Prompt? SFT 的目的是让模型学习“如何根据人类指令回答问题”,而不是让模型去“死记硬背人类的提问”。如果对 Prompt 部分也计算 Loss,模型会把算力浪费在预测用户输入上,甚至导致对话时出现严重的“自言自语”或“复读机”现象。
- 工程实现:
在构建数据时,我们会计算出 Prompt 部分的长度。在 PyTorch 的
CrossEntropyLoss中,有一个关键参数ignore_index(默认值为100)。 我们会把 Prompt 部分和 Padding 部分的标签(Labels)全部替换为100。
3. SFT 损失函数的前沿优化与系统挑战
作为大模型工程师,不仅要会写
-100,还要关注损失函数在分布式训练中的表现:- 通信瓶颈 (Megatron-LM 视角的优化):
大模型的词表通常很大(如 LLaMA 的 32k,Qwen 的 152k)。在计算最后一层 Logits 并求 Softmax 交叉熵时,显存占用极大。我们通常会使用 张量并行 (Tensor Parallelism) 把词表维度切分到不同的 GPU 上,这被称为
VocabParallelCrossEntropy。它能避免把庞大的 Logits 汇聚到一张卡上再算 Loss,极大节约显存。
- Flash Cross Entropy: 为了进一步提高计算速度和降低显存峰值,现代框架(如 Unsloth, Flash-Attention 生态)融合了 Softmax 和 Cross-Entropy 算子(Fused Kernel),避免了中间庞大矩阵的读写操作(Memory-Bound 优化)。
- Z-Loss 正则化 (PaLM/Gemma 等模型的秘籍): 在计算 Softmax 交叉熵时,如果没有约束,Logits 的绝对值可能会漂移得非常大(比如达到几十甚至上百),这在 BF16/FP16 混合精度训练下极易引发数值溢出(NaN)。 因此,Google 在训练中会在标准交叉熵中加入一个辅助损失 Z-Loss: $$ \mathcal{L}_{aux} = \alpha \cdot \log^2 Z $$ 其中 $Z = \sum e^{logits}$ 是 Softmax 的分母。这会迫使 Logits 保持在 $0$ 附近,极大提升了模型在大 Batch Size SFT 时的训练稳定性。