FlashAttention 简化版:Tiling + Online Softmax
本文手写 FlashAttention 的 forward pass,聚焦两个核心算法思想:online softmax(用动态 rescaling 单遍计算 softmax)和 tiling(分块避免 O(N²) 显存)。目标是理解算法原理——online softmax 中 correction = exp(m_old - m_new) 这一行代码为什么能替代标准 softmax 的两遍遍历,以及分块后显存如何从 130 MB 降到 1 MB。全程用纯 Python/PyTorch 实现,不涉及 CUDA kernel。
阅读建议:先通读 §1 理解 O(N²) 瓶颈的来源,再逐段读 §2 的 online softmax 推导(这是全文最核心的数学),最后对照 §3 的伪代码和 §4 的实测数据验证理解。
1. 背景
1.1 Attention 的 O(N²) 显存问题
标准 self-attention 的计算流程:
S = Q @ K^T / √d [N, N] ← 注意力分数矩阵
P = softmax(S) [N, N] ← 注意力概率矩阵
O = P @ V [N, d] ← 输出
其中 S 和 P 都是 [N, N] 的矩阵。当序列长度 N = 4096、FP32 精度时:
- S 占用:4096² × 4 bytes = 64 MB
- P 占用:4096² × 4 bytes = 64 MB
- 合计:128 MB 仅为了存储中间结果
对于长序列(N = 16384),两个矩阵各占 1 GB,GPT-3 级别的 2048 序列长度下仅 attention 中间结果就需要 ~16 GB。这就是 attention 的 O(N²) 显存瓶颈。
1.2 FlashAttention 的核心思想
FlashAttention (Dao et al., 2022) 提出:不把 N×N 的矩阵写回 HBM,在 on-chip SRAM 中分块计算、当场消费。
两个关键技术:
- Tiling(分块):将 Q、K、V 切成小块,每次只加载一块到 SRAM,计算完立刻释放
- Online Softmax(在线 softmax):不存储完整 softmax,用动态更新的 running max 和 running sum 增量计算
结果:显存从 O(N²) 降到 O(N),同时因为减少 HBM 读写(SRAM 带宽远高于 HBM),实际速度反而更快。
1.3 本次实现范围
只做 forward pass,用纯 Python/PyTorch 实现 tiling + online softmax 算法。目标是理解算法原理,不追求生产级性能(需要 CUDA kernel 级别的 SRAM 管理)。Backward 涉及重计算 + 分段梯度,复杂度是 forward 的 3-5 倍,超出本次学习范围。
2. Online Softmax 的数学原理
2.1 标准 Softmax 需要两遍
标准的 numerically stable softmax:
第一遍: 找最大值
m = max(x₁, x₂, ..., xₙ)
第二遍: 计算 exp 并求和
numerator_i = exp(x_i - m)
l = Σ numerator_i
输出: softmax(x)_i = numerator_i / l
为什么需要两遍?因为 softmax 中的 exp(x_i) 很容易溢出(x_i = 100 → exp(100) ≈ 2.7e43)。减去最大值 m 后,所有 x_i - m ≤ 0,exp 值在 (0, 1] 之间,数值稳定。
但这个算法要求先遍历一遍找到 m,再遍历一遍算 exp,无法与分块计算兼容。
2.2 Online Softmax:增量更新
假设数据分两块到达:block₁ = {x₁, x₂, x₃},block₂ = {x₄, x₅, x₆}。
处理 block₁:
m₁ = max(x₁, x₂, x₃) # 当前最大值
l₁ = Σ exp(x_i - m₁) # 当前 sum(exp)
处理 block₂——block₂ 中有更大的值怎么办?
m₂ = max(m₁, max(x₄, x₅, x₆)) # 新最大值(可能更大)
如果 m₂ > m₁:
# block₁ 的结果需要"打折"——因为之前减的是 m₁,
# 现在应该减更大的 m₂
correction = exp(m₁ - m₂) # ≤ 1,因为 m₂ ≥ m₁
l₁_corrected = correction * l₁
# block₂ 的新增部分
l₂_new = Σ exp(x_j - m₂) # j = 4,5,6
# 合并
l_total = correction * l₁ + l₂_new
核心洞察: exp(x_i - m_new) = exp(x_i - m_old) * exp(m_old - m_new)。这个 exp(m_old - m_new) 就是 correction 因子——之前的结果不需要重新计算,只需要乘以这个因子进行 rescale。
2.3 应用到 Attention
在 FlashAttention 中,online softmax 需要同时维护三个运行状态:
对于每个 Q block(外层循环):
m = -∞ ← 运行 max
l = 0 ← 运行 sum(exp)
O = 0 ← 运行 softmax-weighted V
对于每个 K/V block(内层循环):
S_block = Q_block @ K_block^T / √d
m_new = max(m_old, row_max(S_block))
correction = exp(m_old - m_new) # ≤ 1
P_block = exp(S_block - m_new) # 当前 block 的 softmax 分子
l_new = correction * l_old + row_sum(P_block)
O_new = correction * O_old + P_block @ V_block
m = m_new; l = l_new; O = O_new
最终: O_output = O / l ← 除以 sum(exp) 完成归一化
这就是全文最核心的 5 行伪代码。correction 项的引入使得我们可以在不知道全局 max 的情况下增量计算 softmax,从而分块处理、不需要存储 N×N 矩阵。
3. Tiling 策略
§2 解决了”怎么增量计算 softmax”,本节解决”怎么分块”——把 Q、K、V 切成小块,确保每块的计算都在 on-chip 内存中完成。
3.1 分块伪代码
Q [N, d] → Tr 块,每块 Br 行
K [N, d] → Tc 块,每块 Bc 行
V [N, d] → Tc 块,每块 Bc 行
伪代码:
for i in 0..Tr-1: (外层: Q blocks)
Q_block = Q[i*Br : (i+1)*Br, :]
m, l, O = -∞, 0, 0 (重置 online softmax 状态)
for j in 0..Tc-1: (内层: K/V blocks)
K_block = K[j*Bc : (j+1)*Bc, :]
V_block = V[j*Bc : (j+1)*Bc, :]
S_block = Q_block @ K_block^T / √d [Br, Bc]
# online softmax 更新 (见 §2.3)
m_new = max(m, max of S_block)
correction = exp(m - m_new)
P_block = exp(S_block - m_new)
l = correction * l + sum(P_block)
O = correction * O + P_block @ V_block
m = m_new
O[i*Br : (i+1)*Br, :] = O / l
两层循环的计算总量与标准 attention 完全相同(每个 Q_block 都会与所有 K/V blocks 计算一次)。区别在于:标准 attention 一次性算出完整 N×N 矩阵再 softmax,FlashAttention 在每个 block 内完成 S→P→O 的完整计算,S_block 和 P_block 只在 block 范围内存在,不写回 HBM。
3.2 显存分析
| 方法 | 最大单次分配 | 总中间结果 |
|---|---|---|
| 标准 Attention | N×N 矩阵 (128 MB @ N=4096) | ~256 MB (S + P + intermediates) |
| FlashAttention | Br×Bc 矩阵 (0.016 MB @ 64×64) | ~0.03 MB (仅 S_block + P_block) |
标准 attention 的 128 MB 来自两个 N×N 矩阵(S 和 P),且它们是同时存在的——S 算完后 P 覆盖其上,但 FP32 下峰值仍为 128 MB。FlashAttention 的 0.016 MB 来自一个 Br×Bc = 4096 个元素的 FP32 矩阵,配合三个 running states(m、l、O,合计约 Br×d × 3 ≈ 12KB),总峰值 < 1 MB。矩阵尺寸的节省比 = (N² / (Br×Bc)),当 N=4096、Br=Bc=64 时理论值为 4096×。注意这是纯矩阵元素数量的比值——实测 HBM 峰值节省约 130×(130 MB → 1 MB,见 §4.2),差距来自 running states(m, l, O)和框架内存开销。随着 N 增大,这些固定开销占比下降,实际节省比会趋近理论值。
3.3 为什么 Python 实现不加速?
真正的 FlashAttention 把 tiling 写进 CUDA kernel,在 GPU/NPU 的 SRAM 中完成所有计算和中间存储。Python 的 for 循环开销(每个 iteration 都是一次 Python→C++→NPU kernel 调用)远大于节省的 HBM 访存时间。但显存节省是真实的——我们的实现确实避免了分配 N×N 矩阵。
从 Python 到生产级需要做什么: 将两层 for 循环融合为一个 CUDA/Triton kernel;手动管理 SRAM 的加载和驱逐(double buffering);利用 warp-level primitives 做线程间通信。这些优化需要 C++/Triton 级别的编程,超出了本次学习范围,但核心算法逻辑(online softmax + tiling)完全一致。
4. 测试结果
测试环境:Ascend 910B3, CANN 8.0.1, NPU 7。
4.1 精度验证
| 配置 (B, N, d) | max_diff | 结果 |
|---|---|---|
| (1, 256, 64) | 1.19e-07 | ✓ |
| (1, 512, 64) | 1.34e-07 | ✓ |
| (1, 1024, 64) | 1.79e-07 | ✓ |
| (4, 512, 64) | 2.09e-07 | ✓ |
所有配置的 max_diff < 1e-6,远低于 1e-3 的目标。online softmax 的数值精度与标准两遍 softmax 完全等价(差异仅来自浮点舍入误差的顺序不同)。
4.2 显存对比 (N=4096, d=64)
| 方法 | 峰值 HBM | 说明 |
|---|---|---|
| 标准 Attention | 130 MB | 接近理论值 128 MB(S:64MB + P:64MB) |
| FlashAttention | 1 MB | 仅存 Br×Bc=64×64 的 block + running states |
| 节省 | 95% | O(N²) → O(N) 的显存优势充分体现 |
4.3 速度对比
Python 实现的 FlashAttention 比标准 Attention 慢(341ms vs 0.1ms @ N=2048),这是预期行为。真正的 FlashAttention 通过以下工程优化获得加速:
- Kernel Fusion:整个 attention 计算融合为一个 CUDA kernel,消除 kernel launch 开销
- SRAM 管理:手动管理 on-chip SRAM 的加载/驱逐,最大化重用
- Warp-level 优化:利用 GPU 的 warp 调度减少同步开销
这些优化需要在 CUDA C++ 或 Triton 层面实现,Python 无法做到。但 Python 实现的显存节省是真实的,而且算法逻辑(online softmax + tiling)与生产级实现完全一致。
5. 代码结构
flash_attention.py 约 190 行,分为三组:核心算法(standard attention + flash attention forward)、验证工具(精度/显存/速度对比)、CLI 入口。
09_flash_attention/
└── flash_attention.py # FlashAttention 简化版(~190 行)
├── standard_attention() — 标准 PyTorch attention (baseline)
├── flash_attention_forward() — tiled + online softmax forward
├── compare_and_verify() — 数值精度对比
├── profile_memory() — HBM 峰值对比
└── benchmark_speed() — 执行速度对比
6. 与之前 phase 的联系
本 phase 的 FlashAttention 不是孤立的算法实验——它与之前多个 phase 直接关联:
| Phase | 关联 |
|---|---|
| Phase 8 (Mini-GPT) | CausalSelfAttention 中的标准 attention 是本 phase 的 baseline |
| Phase 7 (Profiling) | profiling 方法用于验证 HBM 占用差异 |
| Phase 1 (Hello NPU) | Q@K^T 矩阵乘法的性能决定了 attention 的效率 |
7. 后续扩展
本实验只实现了 FlashAttention 的 forward pass,以下四个方向是最自然的延伸:
- Backward pass:实现重计算 + 分段梯度,理解训练时的显存节省(代码量约为 forward 的 3 倍)
- Causal Mask 集成:在 tiling 过程中融入 causal mask,用于 GPT 等 decoder 模型
- Triton 实现:用 Triton 语言写 NPU kernel,获得接近原生 CUDA 的性能
- Multi-Query Attention (MQA) / Grouped-Query Attention (GQA):进一步减少 K/V 的显存开销