背景与动机
大模型的训练受制于一个根本瓶颈:端到端反向传播需要保存所有层的中间激活值。网络越深,显存消耗越大。这是阻碍模型规模扩展的核心因素。
逐块训练(block-wise training)一直是理论上诱人的方案——把网络切成 B 块,每次只训练一块,显存降到 1/B,还可以并行。但现实是骨感的:Hinton 的 Forward-Forward 用对比学习目标,准确率只有 7.85%(vs 端到端 60%);其他方法也缺乏理论支撑,基本只能在分类任务上勉强跑。
Sakana AI 这篇论文的突破在于:找到了一个有严格理论基础的方法,而且适用于分类、生成、文本等多种任务。
核心方法
为什么残差连接 = 扩散步?
残差网络的更新公式是 z_l = z_(l-1) + f(z_(l-1))。扩散模型概率流 ODE 的欧拉离散化形式几乎一致——残差块就是去噪器,中间层输出就是不同噪声级别的状态。
三步转换
Step 1 — 分块:把 L 层网络分成 B 个块,每块包含若干连续层。
Step 2 — 分配噪声范围:用 log-normal 噪声分布,按等概率划分策略切分噪声范围为 B 个区间。这不是均匀切噪声值,而是让每块覆盖相同的概率质量——中间噪声级别(图像结构涌现的区域)分到更窄的区间,因为那里去噪最困难。
Step 3 — 加噪声条件:扩展输入为 (x, z_σ),通过 AdaLN 注入噪声级别 σ,让块知道自己在扩散过程的哪个位置。
关键实验结果
图像分类 — ViT / CIFAR-100
| 方法 | 准确率 | 同时训练层数 |
|---|---|---|
| ViT(端到端) | 60.25% | 12 |
| + Forward-Forward | 7.85% | 4 |
| + DiffusionBlocks | 59.30% | 4 |
在 ViT 上,DiffusionBlocks 以 1/3 的显存代价达到了几乎相同的准确率。Forward-Forward 的对比学习目标完全崩了。
图像生成 — DiT
| 数据集 | 方法 | FID(train / test) |
|---|---|---|
| CIFAR-10 | DiT | 32.84 / 39.83 |
| CIFAR-10 | + DiffusionBlocks | 30.59 / 37.20 |
| ImageNet | DiT | 9.01 / 12.09 |
| ImageNet | + DiffusionBlocks | 9.00 / 10.63 |
不仅在持平的精度下实现了 3x 显存节省,推理时每个去噪步骤只需激活对应块,也有额外加速。
自回归文本生成 — Llama-2
| 数据集 | 方法 | MAUVE ↑ | PPL (Llama-2) ↓ |
|---|---|---|---|
| LM1B | AR | 0.50 | 14.58 |
| LM1B | + DiffusionBlocks | 0.71 | 12.32 |
| OWT | AR | 0.85 | 15.05 |
| OWT | + DiffusionBlocks | 0.82 | 14.99 |
即使不是扩散原生架构,自回归 Transformer 用 DiffusionBlocks 训练后性能反而提升了。MAUVE 从 0.50 跳到 0.71。
循环深度模型 — Huginn
| 方法 | MAUVE ↑ | PPL (GPT2-XL) ↓ |
|---|---|---|
| Huginn(8-step BPTT) | 0.49 | 46.73 |
| + DiffusionBlocks | 0.70 | 42.43 |
把 32 次迭代训练变成单次前向传播,性能还更好了。对于循环深度模型,DiffusionBlocks 不仅是内存优化,而是根本性地改变了训练范式。
消融:等概率划分 vs 均匀划分
| 划分策略 | 层分布 | FID ↓ |
|---|---|---|
| 均匀划分 | [4,4,4] | 43.53 |
| 等概率划分 | [4,4,4] | 38.03 |
等概率划分比均匀划分好 5.5 个 FID 点,验证了噪声分布自适应的重要性。
局限性与未来方向
维度匹配限制:目前要求输入输出维度相同,U-Net 这类变维度架构暂不支持。
规模验证不足:最大实验是 DiT-L/2 (ImageNet 256x256),离千卡工业级大模型还有距离。
推理开销:虽然每步只激活一个块,但扩散采样需要多步,总计算量未必比标准推理少。
总结
DiffusionBlocks 的思路极其优雅——不需要新架构、不需要特殊损失函数,只利用"残差连接 = 扩散步"这个数学等价关系,就把反向传播从必需品变成了可选项。等概率划分的设计也巧妙地解决了块间负载均衡的问题。
虽然目前规模还不够大,但作为方向验证非常扎实。如果能在更大模型上复现,这对"想训大模型但卡不够"的场景是根本性的突破。代码已开源:GitHub。