ViT Vision Transformer Image Recognition at Scale

Task: Image Classification / Feature Backbone
Method: Vision Transformer, Patch Embedding
Venue: ICLR
Year: 2021
Paper: https://arxiv.org/abs/2010.11929
Code: https://github.com/google-research/vision_transformer

摘要

Transformer 架构已成为自然语言处理的事实标准,但在计算机视觉中仍受限于与卷积网络结合或替换卷积的局部组件。本文表明这种对 CNN 的依赖并非必要:将纯 Transformer 直接应用于图像 patch 序列即可在分类任务上取得优异性能。在大规模数据上预训练后迁移到多个中小规模基准(ImageNet, CIFAR-100, VTAB 等),Vision Transformer (ViT) 达到了与 SOTA 卷积网络相当甚至更好的结果,同时所需预训练计算资源大幅减少。最佳模型 ViT-H/14 在 ImageNet 上达到 88.55% top-1 准确率,CIFAR-100 上达到 94.55%。

核心论点:视觉识别不需要卷积——将图像切割为 16×16 的 patch 序列,直接输入标准 NLP Transformer,在足够大的数据集(14M-300M 图像)预训练下即可超越 CNN,证明了”大数据 + 大模型”的 scaling law 在视觉领域同样成立。

问题与动机

2020 年前,计算机视觉中 Transformer 的应用面临根本性矛盾:

方案 代表作 核心问题
局部注意力 Image Transformer (Parmar 2018) 仅在局部邻域计算注意力,丧失全局建模能力
稀疏注意力 Sparse Transformer (Child 2019) 需要特殊注意力模式,硬件效率低
CNN + Self-Attention Non-local Nets / DETR 仍依赖 CNN 特征提取,Transformer 仅做后处理
像素级全注意力 理论方案 针对像素的二次复杂度 $O(n^2)$,224×224 图像就有 50176 个 token,完全不可行

当时的认知是:视觉任务需要 CNN 的归纳偏置(locality、translation equivariance、2D 结构),纯 Transformer 在视觉上无法工作。

核心痛点:能否像 NLP 一样,用最简单的标准 Transformer 直接处理视觉数据?CNN 的归纳偏置在大规模预训练下是否真的必要?

核心洞察

洞察 1:图像 = Patch 序列,无需卷积

传统做法:图像 → CNN 提取特征 → 可能用 Transformer 做后处理。

本文做法:将图像 $\mathbf{x} \in \mathbb{R}^{H \times W \times C}$ 切割为 $N$ 个固定大小的 patch $\mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$,其中 $N = HW/P^2$,每个 patch 通过一个线性投影映射为 token embedding:

$$\mathbf{z}_0 = [\mathbf{x}_{\text{class}};\, \mathbf{x}_p^1 \mathbf{E};\, \mathbf{x}_p^2 \mathbf{E};\, \cdots;\, \mathbf{x}_p^N \mathbf{E}] + \mathbf{E}_{pos}$$

其中 $\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}$ 是可学习的线性投影,$\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}$ 是可学习的 1D 位置嵌入。对于 $P=16$, $224 \times 224$ 图像仅产生 $14 \times 14 = 196$ 个 token,计算完全可行。

洞察 2:大数据弥补归纳偏置缺失

传统认知:CNN 的局部性(locality)和平移等变性(translation equivariance)是视觉任务不可或缺的归纳偏置。

本文发现:ViT 几乎不含视觉归纳偏置——仅在 patch 切割时使用了 2D 结构信息。在小数据集(ImageNet-1k)上,ViT 确实不如同等规模的 ResNet。但在大数据集上预训练后,情况逆转:

  • ImageNet-1k 预训练:ViT-B/16 仅 77.91%,不如 ResNet
  • ImageNet-21k 预训练(14M):ViT-L/16 达到 85.15%
  • JFT-300M 预训练(303M):ViT-H/14 达到 88.55%,超越所有 CNN

数据量越大,ViT 的优势越明显。论文在 JFT 子集实验(9M/30M/90M)中证明:CNN 在小数据集上更好但很快饱和,ViT 随数据量增长持续提升。

洞察 3:Transformer 在视觉上的 Scaling 更高效

传统做法:增大 CNN(ResNet-200x3 等)可以提升性能,但计算代价急剧增长。

本文做法:ViT 在 performance/compute trade-off 上显著优于 ResNet。相同精度下,ViT 使用约 2-4× 更少的计算:

模型 预训练 exaFLOPs ImageNet top-1
ResNet152x2 (14ep) 1126 85.56%
ViT-L/16 (14ep) 1567 87.12%
ResNet200x3 (14ep) 3306 87.22%
ViT-H/14 (14ep) 4262 88.08%

ViT-L/16 用不到 ResNet200x3 一半的计算就接近其精度。且 ViT 在测试范围内未出现饱和趋势,暗示继续 scaling 还能获益。

要记住的 3 个数字

  • **88.55%**:ViT-H/14 在 ImageNet 上的 top-1 准确率(JFT-300M 预训练),刷新 SOTA
  • 16×16:标准 patch 大小,将 224×224 图像转化为 196 个 token
  • 2-4×:ViT 相比 ResNet 达到相同精度所需的更少计算量

方法设计

4.1 整体架构

核心流程:输入图像 → Patch 切割 + 线性投影 → 加位置嵌入 + class token → Transformer Encoder → class token 输出 → MLP 分类头

ViT 架构图
  输入 (224×224×3)
       │
  ┌────┴─────────┐
  │  Patch 切割   │  P=16 → 14×14 = 196 个 patch
  │  + 线性投影   │  (P²·C) × D 映射
  └────┬─────────┘
       │  (196 × D)
  ┌────┴─────────┐
  │  [CLS] token │  前置可学习分类 token
  │ + Pos Embed  │  (N+1) × D 可学习位置嵌入
  └────┬─────────┘
       │  (197 × D)
  ┌────┴─────────────────┐
  │  Transformer Encoder │  L 层
  │  ┌─────────────────┐ │
  │  │ LN → MSA → +res │ │  多头自注意力
  │  │ LN → MLP → +res │ │  2 层 MLP (GELU)
  │  └─────────────────┘ │
  │  × L 层              │
  └────┬─────────────────┘
       │
  ┌────┴────┐
  │ LN      │  取 [CLS] token 输出
  │ MLP Head│  → 分类概率
  └─────────┘

4.2 关键组件

ViT 模型变体

变体 层数 L 隐藏维度 D MLP 维度 注意力头数 参数量
ViT-Base 12 768 3072 12 86M
ViT-Large 24 1024 4096 16 307M
ViT-Huge 32 1280 5120 16 632M

Transformer Encoder 各层组件

组件 功能 公式
Layer Norm 每个子层前归一化 LN(x)
Multi-Head Self-Attention 全局 token 交互 $\text{MSA}(\text{LN}(\mathbf{z}{\ell-1})) + \mathbf{z}{\ell-1}$
MLP 逐 token 非线性变换 $\text{MLP}(\text{LN}(\mathbf{z}’{\ell})) + \mathbf{z}’{\ell}$
[CLS] token 聚合全局信息用于分类 $\mathbf{y} = \text{LN}(\mathbf{z}_L^0)$

4.3 关键代码

以下为 Google 官方 JAX/Flax 实现中的核心组件。

Transformer Encoder Block(来源:vit_jax/models_vit.py

📄 点击展开 Encoder1DBlock 代码
class Encoder1DBlock(nn.Module):
"""Transformer encoder layer."""
mlp_dim: int
num_heads: int
dtype: Dtype = jnp.float32
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1

@nn.compact
def __call__(self, inputs, *, deterministic):
# inputs.shape = (batch, seq_len, hidden_dim)
assert inputs.ndim == 3

# 注意力子层:LN → MSA → Dropout → 残差
x = nn.LayerNorm(dtype=self.dtype)(inputs)
x = nn.MultiHeadDotProductAttention(
dtype=self.dtype,
kernel_init=nn.initializers.xavier_uniform(),
broadcast_dropout=False,
deterministic=deterministic,
dropout_rate=self.attention_dropout_rate,
num_heads=self.num_heads)(x, x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = x + inputs # 残差连接

# MLP 子层:LN → MLP(GELU) → Dropout → 残差
y = nn.LayerNorm(dtype=self.dtype)(x)
y = MlpBlock(
mlp_dim=self.mlp_dim,
dtype=self.dtype,
dropout_rate=self.dropout_rate)(y, deterministic=deterministic)

return x + y # 残差连接

VisionTransformer 主体(来源:vit_jax/models_vit.py

📄 点击展开 VisionTransformer 代码
class VisionTransformer(nn.Module):
"""VisionTransformer."""
num_classes: int
patches: Any # patch 大小配置
transformer: Any # Transformer 超参数
hidden_size: int
representation_size: Optional[int] = None
classifier: str = 'token' # 'token' 使用 [CLS],'gap' 使用全局平均池化

@nn.compact
def __call__(self, inputs, *, train):
x = inputs
n, h, w, c = x.shape

# Patch Embedding:等价于一个大 kernel 的卷积
# kernel_size = stride = patch_size,将每个 patch 线性投影到 D 维
x = nn.Conv(
features=self.hidden_size,
kernel_size=self.patches.size, # (16, 16)
strides=self.patches.size, # (16, 16)
padding='VALID',
name='embedding')(x)

n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c]) # (B, N, D)

# 前置 [CLS] token
if self.classifier in ['token', 'token_unpooled']:
cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
cls = jnp.tile(cls, [n, 1, 1])
x = jnp.concatenate([cls, x], axis=1) # (B, N+1, D)

# Transformer Encoder(内含位置嵌入添加)
x = self.encoder(name='Transformer', **self.transformer)(
x, train=train)

# 分类:取 [CLS] token
if self.classifier == 'token':
x = x[:, 0]
elif self.classifier == 'gap':
x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))

# 分类头
if self.representation_size is not None:
x = nn.Dense(features=self.representation_size,
name='pre_logits')(x)
x = nn.tanh(x)

if self.num_classes:
x = nn.Dense(
features=self.num_classes,
name='head',
kernel_init=nn.initializers.zeros)(x)
return x

实验与分析

5.1 主要结果

与 SOTA 对比(JFT-300M 预训练,多数据集迁移):

模型 ImageNet ImageNet ReaL CIFAR-100 VTAB (19 tasks) 预训练计算 (TPUv3-core-days)
BiT-L (ResNet152x4) 87.54 90.54 93.51 76.3 9.9k
Noisy Student (EfficientNet-L2) 88.4 90.55 - - 12.3k
ViT-L/16 87.76 90.54 93.90 76.28 0.68k
ViT-H/14 88.55 90.72 94.55 77.63 2.5k

关键发现

  1. ViT-H/14 以 2.5k TPUv3-core-days 的预训练成本超越了 BiT-L 的 9.9k 和 Noisy Student 的 12.3k,计算效率优势显著
  2. 即使用开源 ImageNet-21k 预训练的 ViT-L/16 也表现出色,仅需 8 个 TPU 核约 30 天即可训练
  3. ViT 在 VTAB 的 Natural 和 Structured 任务组上优于 BiT,在 Specialized 上持平

5.2 消融实验:验证三个洞察

数据规模实验(验证洞察 2:大数据弥补归纳偏置):

预训练数据 ViT-B/16 ViT-L/16 趋势 验证洞察
ImageNet-1k 77.91 76.53 ViT-L 不如 ViT-B(过拟合) 洞察 2
ImageNet-21k 83.97 85.15 ViT-L 开始超越 ViT-B 洞察 2
JFT-300M 84.15 87.12 ViT-L 大幅领先 洞察 2

位置嵌入消融(验证洞察 1:序列化足够):

位置编码方式 ImageNet 5-shot linear
无位置嵌入 0.61382
1D 位置嵌入 0.64206
2D 位置嵌入 0.64001
相对位置嵌入 0.64032

注:有无位置嵌入差距大,但 1D/2D/相对之间差异极小,说明 patch 级别输入的空间分辨率已经足够低(14×14),不同编码方式同样容易学习。

Scaling 实验(验证洞察 3:计算效率优势):

模型 预训练 exaFLOPs ImageNet top-1 验证洞察
ResNet50x1 50 77.54 洞察 3
ViT-B/32 55 80.73 洞察 3
ResNet152x2 563 84.97 洞察 3
ViT-L/16 783 86.30 洞察 3
R50+ViT-B/16 (混合) 274 85.58 洞察 3
  • 洞察 1 验证:仅用简单 1D 位置嵌入即可有效编码空间信息,2D 感知编码无额外收益
  • 洞察 2 验证:ViT-L 在 ImageNet-1k 上不如 ViT-B,但在 JFT-300M 上大幅领先——数据越多,ViT 越强
  • 洞察 3 验证:相同 exaFLOPs 预算下,ViT 系统性地优于 ResNet;混合模型在小规模时有优势,大规模时差距消失

5.3 性能瓶颈分析

  • 在 ImageNet-1k(无额外数据)上训练 ViT-B/16,即使加入 dropout 和 label smoothing,也不如 ResNet-50(归纳偏置在小数据下仍重要)
  • 自监督预训练(masked patch prediction)的 ViT-B/16 达到 79.9%,虽然比从头训练好 2%,但仍落后有监督预训练 4%
  • 注意力距离分析表明:低层部分 head 已有全局注意力(替代了 CNN 早期的局部接收域),但也有 head 保持局部关注,说明网络自适应地学习了局部和全局模式

5.4 失效场景分析

  • 数据不足场景:在 ImageNet-1k(1.3M 图像)上从头训练,ViT 显著不如同计算量的 ResNet,缺乏归纳偏置导致需要更大数据量
  • 小 patch 计算爆炸:patch 从 16 减小到 8 时,序列长度增长 4 倍,计算量显著上升
  • 自监督预训练差距:masked patch prediction 仍比有监督预训练落后 4%(当时 MAE 等方案尚未提出)

工程实践

6.1 训练配置

Backbone:      ViT-B/16, ViT-L/16, ViT-H/14
Input: 224×224 预训练, 384/512 微调
Batch Size: 4096 (预训练), 512 (微调)
Optimizer: Adam (β1=0.9, β2=0.999) 预训练; SGD+momentum 微调
LR Schedule: 线性 warmup (10k steps) + 线性/cosine 衰减
Weight Decay: 0.1 (预训练)
Augmentation: 无特殊增强 (仅标准裁剪翻转)
Dropout: 仅 ImageNet-from-scratch 使用 (rate=0.1)
Training Time: ViT-L/16 ImageNet-21k 约 30 天 (8 TPUv3 cores)

6.2 复现要点

  1. 预训练数据必须足够大:ViT 在 ImageNet-1k 上从头训练效果差,至少需 ImageNet-21k(14M)预训练才能发挥优势
  2. Adam 优于 SGD 做预训练:论文发现 Adam 在大规模预训练中优于 SGD,即使对 ResNet 也是如此
  3. 高分辨率微调:预训练用 224×224,微调时用 384 或更高分辨率,位置嵌入通过 2D 插值适配
  4. Weight Decay 设为 0.1:远高于 CNN 常用的 0.0001,对 ViT 的迁移性能至关重要
  5. 微调时替换分类头:移除预训练的 MLP 头,替换为零初始化的单层线性层
  6. [CLS] 与 GAP 效果相同:论文发现二者区别仅在于最佳学习率不同,性能无实质差异

6.3 性能优化方向

精度提升

  • DeiT(Touvron et al., 2021)引入知识蒸馏和更强数据增强,使 ViT 在 ImageNet-1k 上无需大规模预训练也能达到竞争力
  • MAE(He et al., 2022)引入 masked autoencoder 自监督预训练,大幅缩小自监督与有监督预训练的差距

速度优化

  • 使用窗口注意力(Swin Transformer)将 $O(n^2)$ 降为 $O(n)$,适配高分辨率密集预测任务
  • FlashAttention 等硬件友好的注意力实现可在不改变模型的前提下加速 2-3 倍
  • 轻量级变体(MobileViT 等)面向移动端部署

研究启示

7.1 可迁移的思想

  • “图像 = 序列”的统一表示:这一洞察使视觉模型可以直接复用 NLP Transformer 的架构、训练技巧和工程基础设施,催生了多模态统一模型(CLIP, GPT-4V 等)
  • Scaling Law 从 NLP 到 Vision 的验证:证明了”大数据 + 大模型”在视觉中同样适用,归纳偏置并非不可替代,而是数据不足时的一种补偿
  • Patch Embedding 是一种通用序列化手段:不仅适用于图像,也被扩展到视频(ViViT)、点云(PCT)、音频(AST)等模态
  • 预训练-微调范式的视觉扩展:借鉴 BERT 的训练范式,大模型预训练 + 下游任务微调成为视觉领域的标准流程

7.2 方法局限

  • 在小数据集上需要大量正则化才能避免过拟合,丧失了 CNN 天然的局部归纳偏置优势
  • 标准 ViT 无法高效处理高分辨率图像($O(n^2)$ 注意力),不适合密集预测任务(检测/分割)
  • 依赖大规模私有数据集(JFT-300M)预训练,公开可用数据下的复现存在差距

7.3 技术影响

  • 开启了 Vision Transformer 时代:直接催生了 DeiT、Swin Transformer、BEiT、MAE 等一系列后续工作,Transformer 迅速取代 CNN 成为视觉主流架构
  • 统一了 NLP 和 Vision 的技术栈:使得 CLIP、Flamingo、GPT-4V 等多模态大模型成为可能
  • 推动了自动驾驶感知架构变革:BEVFormer、PETR 等 BEV 感知方法的 backbone 从 ResNet 切换为 Swin/ViT,获得显著性能提升
  • 改变了学术界对归纳偏置的认知:从”必须精心设计任务特定的归纳偏置”转向”让模型从数据中学习”,推动了 foundation model 范式