引言
模型量化的工程目标很简单:用 INT8 替代 FP32,换取更小的模型体积和更快的推理速度。但几乎每一个量化项目都会遭遇同一个问题——精度回退。
精度为什么会掉?表面原因各不相同,但追溯到底层,回退的根源可以归纳为三层:
- 量化本身的信息损失:将连续浮点值映射到离散整数,必然引入截断与舍入误差。
- 训练与部署的计算图不一致:训练阶段的计算图与部署阶段的执行图存在结构差异,导致“训练时没问题、部署后掉点”。
- 校准与量化初始化:QAT 的起点需要通过 calibration 获得合理的量化范围初始值,否则训练不稳定。
QAT(Quantization Aware Training)的价值正在于:它让模型在训练阶段就提前感知量化误差,从而在三个层面上同时进行修正。本文围绕这三层根因,梳理 PyTorch FX Graph Mode 下 QAT 的核心原理与工程实践。
1. 第一层:量化的信息损失
1.1 均匀量化的误差来源
均匀量化将浮点数 $x$ 映射为整数 $x_q$:
$$x_q = \text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil + z,\; 0,\; 2^b - 1\right)$$其中 $s$ 是 scale,$z$ 是 zero point,$b$ 是位宽(通常为 8)。反量化得到 $\hat{x} = s \cdot (x_q - z)$。
误差 $\epsilon = x - \hat{x}$ 来自两个环节:
- 舍入误差:$\lfloor \cdot \rceil$ 将连续值离散化,最大误差为 $s/2$。
- 截断误差:超出 $[0, 2^b - 1]$ 范围的值被 clamp,信息完全丢失。
scale $s$ 的选择决定了两种误差的权衡:$s$ 越小,舍入精度越高,但截断范围越窄;$s$ 越大,覆盖范围越广,但每个量化区间的精度越差。这就是 observer(观测器,负责统计数据分布并计算 scale/zero_point 的组件,常见策略有 MinMax / Histogram / Percentile)要解决的核心问题。
1.2 伪量化(Fake Quant)与 STE
QAT 的核心机制是 伪量化(Fake Quantization):前向传播中执行量化→反量化(模拟部署时的精度损失,但实际计算仍用浮点数),反向传播中用 直通估计器(Straight-Through Estimator, STE) 直接传递梯度。
$$\text{forward: } \hat{x} = s \cdot \left(\text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil + z,\; 0,\; 2^b-1\right) - z\right)$$$$\text{backward: } \frac{\partial \mathcal{L}}{\partial x} \approx \frac{\partial \mathcal{L}}{\partial \hat{x}}$$
STE 是一个近似方法(它假设量化操作的梯度为 1,这显然不精确):梯度在截断区域为零,这意味着被 clamp 掉的值无法通过梯度修正自身。这就是为什么 calibration 阶段的 scale/zero_point 初值如此重要——如果初始量化范围偏差过大,大量值落入截断区域,训练早期的梯度信号就会严重失真。
1.3 QAT 如何修正这层损失
QAT 让模型参数在感知到量化误差的前提下继续优化。经过若干轮 fine-tune,权重会自适应地调整到量化友好的分布——峰值更集中、长尾被压缩、关键通道的数值范围更紧凑。这是 QAT 相比 PTQ(Post-Training Quantization)精度更高的核心原因。
2. 第二层:训练与部署的计算图不一致
2.1 为什么训练和部署会不一致
这里的“一致性”指:训练时伪量化模拟的行为,必须与部署时 INT8 算子的实际计算完全一致。但在工程中,这种对齐很容易被破坏:
- 算子融合差异:训练时
Conv + BN + ReLU是三个独立算子(BN 在 train 模式下需要统计均值方差),部署时被融合为一个ConvBnReLU或进一步折叠 BN 到 Conv 权重中。如果伪量化节点的插入位置没有考虑融合后的语义,量化范围就会偏移。 - 后端配置不匹配:训练用
fbgemm(x86),部署用qnnpack(ARM),两者的量化 scheme 可能不同(对称 vs 非对称)。 - 图结构遗漏:手工在代码中插入伪量化节点时容易遗漏中间节点,导致部分路径在训练时未被量化但部署时被量化。
2.2 FX 如何解决这个问题
FX Graph Mode 的核心价值是把量化从“手工在代码中插入量化节点”提升为自动化的计算图变换:
- 自动算子融合:自动识别
Conv-BN-ReLU等组合并在正确位置插入伪量化节点。 - 自动观测器插入:根据量化配置规则(qconfig mapping),确保每个需要量化的节点都被覆盖。
- 转换一致性:
convert_fx将伪量化图直接转换为 INT8 执行图,训练和部署使用同一张图的不同形态。
这使得计算图一致性从“人工保证”变为“框架保证”。
2.3 FX-QAT 标准流程
FP32 Checkpoint
|
v
Load Float Model ---------------+
| |
v |
Build QAT Graph (FX) |
| |
v |
Load Calibration Checkpoint <---+
|
v
QAT Fine-tune (FakeQuant Active)
|
v
convert_fx
|
v
INT8 Deployable Model
2.4 关键代码
标准流程(单 checkpoint + 在线 calibration):
import torch |
双 checkpoint 变体(生产环境优化):
# 建图后加载预先保存的 calibration state,跳过在线 calibration |
3. 第三层:校准与量化初始化
3.1 标准流程:一个 checkpoint + 在线 calibration
标准 QAT 只需要一个 FP32 预训练 checkpoint。calibration 是建图后的在线步骤,不需要额外的 checkpoint 文件:
- 加载 FP32 预训练权重
prepare_qat_fx建图(插入 observer + fake quant)- 在线 calibration:用代表性数据跑若干轮前向推理,observer 自动统计激活值的 min/max 分布,计算初始 scale/zero_point
- QAT fine-tune
convert_fx→ INT8
calibration 的本质是让 observer “看到”真实数据分布,从而给 fake quant 一个合理的初始量化范围。如果跳过 calibration 直接训练,observer 的 scale/zero_point 处于默认值,fake quant 会在第一次前向传播中引入极大噪声,导致训练震荡或发散。
3.2 工程优化:双 checkpoint 模式
在生产环境(如 OpenMMLab 等框架)中,calibration 步骤常被独立出来并保存为第二个 checkpoint,形成”双 checkpoint”模式:
| checkpoint | 内容 | 用途 |
|---|---|---|
| float checkpoint | FP32 训练收敛的权重 | 提供参数起点 |
| calibration checkpoint | 模型参数 + observer 统计量 + scale/zero_point | 提供量化初始范围 |
这样做的好处是:多次 QAT 实验(调超参、换数据集)不需要每次重新跑 calibration,节省时间。加载顺序为:先 prepare_qat_fx 建图,再 load_state_dict 加载 calibration state(因为建图后 state_dict 才包含 observer 和 fake quant 的 key)。
双 checkpoint 是工程最佳实践,不是 QAT 的前置必要条件。 标准流程中,在线 calibration 完全可以替代第二个 checkpoint。
4. PyTorch 量化模式定位
PyTorch 的执行/导出模式体系可以用一条主轴来理解:
动态性主轴:Eager -> FX -> torch.export -> Backend IR,动态性逐步降低,结构确定性逐步提升。
在量化场景中,常用组合是 Eager -> FX(QAT) -> convert_fx -> 部署;跨框架交付时补充 ONNX 导出。选型规则:
| 阶段 | 推荐模式 | 原因 |
|---|---|---|
| 模型开发与训练 | Eager | 完整 Python 语义,调试友好 |
| 量化改图与插入量化/反量化节点 | FX | 自动算子匹配,计算图一致性 |
| 自动编译与后端优化 | torch.compile | 算子融合,硬件适配 |
| 跨框架模型交付 | ONNX | 标准交换格式 |
从生态趋势看,torch.export + pt2e(PyTorch 2 Export Quantization,基于 torch.export 的新一代量化框架)路线正在成为官方推荐的新量化路径,但短期内 FX 仍是大量存量项目的核心。
5. 工程决策与风险
5.1 消融映射
| 设置 | float ckpt | calibration | 现象 | 对应根因 |
|---|---|---|---|---|
| A | ✓ | ✓(在线或 ckpt) | 收敛稳定 | 三层均受控 |
| B | ✓ | ✗(跳过) | 训练震荡 | 第一层:scale 偏差导致 STE 梯度失真 |
| C | ✗ | ✓ | 收敛慢,上限受限 | 第三层:参数起点不在可优化区域 |
| D | Eager 手工 | N/A | 可用但维护成本高 | 第二层:计算图一致性靠人保证 |
5.2 失效场景
- 动态图路径复杂:模型含大量 if/else 分支,FX 跟踪时只能覆盖其中一条路径,导致量化节点插入不完整 → 第二层问题。
- 后端配置不一致:训练用 fbgemm、部署用 qnnpack,量化方案不对齐(如对称 vs 非对称量化) → 第二层问题。
- 校准数据失配:calibration 用的数据分布与实际推理数据差异大,scale 偏移 → 第一层问题。
5.3 决策树
目标是量化部署? |
5.4 复现要点
- 加载顺序:先 float checkpoint →
prepare_qat_fx建图 → 在线 calibration(或加载 calibration state)。 - 后端一致:训练、convert、部署三阶段的
torch.backends.quantized.engine必须相同。 - 分段验证:FP32 / fake-quant / INT8 三段精度都要可观测,定位回退发生在哪一层。
- BN 冻结:QAT 后期建议冻结 BN 统计量,避免小 batch 下均值方差抖动干扰量化范围。