Back to Blog Untitled

Untitled

Paper
核心洞察:残差连接天然等价于扩散去噪过程的欧拉离散化。利用这一等价关系,可以将任意 Transformer 网络拆成独立块,每个块只负责特定噪声范围的去噪任务,训练时显存需求降低 B 倍。

背景与动机

大模型的训练受制于一个根本瓶颈:端到端反向传播需要保存所有层的中间激活值。网络越深,显存消耗越大。这是阻碍模型规模扩展的核心因素。

逐块训练(block-wise training)一直是理论上诱人的方案——把网络切成 B 块,每次只训练一块,显存降到 1/B,还可以并行。但现实是骨感的:Hinton 的 Forward-Forward 用对比学习目标,准确率只有 7.85%(vs 端到端 60%);其他方法也缺乏理论支撑,基本只能在分类任务上勉强跑。

Sakana AI 这篇论文的突破在于:找到了一个有严格理论基础的方法,而且适用于分类、生成、文本等多种任务

核心方法

为什么残差连接 = 扩散步?

残差网络的更新公式是 z_l = z_(l-1) + f(z_(l-1))。扩散模型概率流 ODE 的欧拉离散化形式几乎一致——残差块就是去噪器,中间层输出就是不同噪声级别的状态。

Figure 1
图 1:DiffusionBlocks 概览。左:标准网络需要全层反向传播;中:DiffusionBlocks 将网络分块,每块在分配的噪声范围内独立去噪训练;右:应用场景。
Figure 2
图 2:三步转换流程。Step 1:分块;Step 2:定义噪声分布并划分范围;Step 3:给每个块加噪声条件。

三步转换

Step 1 — 分块:把 L 层网络分成 B 个块,每块包含若干连续层。

Step 2 — 分配噪声范围:用 log-normal 噪声分布,按等概率划分策略切分噪声范围为 B 个区间。这不是均匀切噪声值,而是让每块覆盖相同的概率质量——中间噪声级别(图像结构涌现的区域)分到更窄的区间,因为那里去噪最困难。

Step 3 — 加噪声条件:扩展输入为 (x, z_σ),通过 AdaLN 注入噪声级别 σ,让块知道自己在扩散过程的哪个位置。

Figure 4
图 4:等概率划分示意(B=3)。橙色边界按等概率质量切分,灰色为均匀切分。中间噪声级别区间更窄,分配更多计算资源。
Figure 3
图 3:标准网络(左)vs DiffusionBlocks(右)的训练和推理伪代码。DiffusionBlocks 训练时只需对单块计算梯度。

关键实验结果

图像分类 — ViT / CIFAR-100

方法准确率同时训练层数
ViT(端到端)60.25%12
+ Forward-Forward7.85%4
+ DiffusionBlocks59.30%4

在 ViT 上,DiffusionBlocks 以 1/3 的显存代价达到了几乎相同的准确率。Forward-Forward 的对比学习目标完全崩了。

图像生成 — DiT

数据集方法FID(train / test)
CIFAR-10DiT32.84 / 39.83
CIFAR-10+ DiffusionBlocks30.59 / 37.20
ImageNetDiT9.01 / 12.09
ImageNet+ DiffusionBlocks9.00 / 10.63

不仅在持平的精度下实现了 3x 显存节省,推理时每个去噪步骤只需激活对应块,也有额外加速。

自回归文本生成 — Llama-2

数据集方法MAUVE ↑PPL (Llama-2) ↓
LM1BAR0.5014.58
LM1B+ DiffusionBlocks0.7112.32
OWTAR0.8515.05
OWT+ DiffusionBlocks0.8214.99

即使不是扩散原生架构,自回归 Transformer 用 DiffusionBlocks 训练后性能反而提升了。MAUVE 从 0.50 跳到 0.71。

循环深度模型 — Huginn

方法MAUVE ↑PPL (GPT2-XL) ↓
Huginn(8-step BPTT)0.4946.73
+ DiffusionBlocks0.7042.43

把 32 次迭代训练变成单次前向传播,性能还更好了。对于循环深度模型,DiffusionBlocks 不仅是内存优化,而是根本性地改变了训练范式。

消融:等概率划分 vs 均匀划分

划分策略层分布FID ↓
均匀划分[4,4,4]43.53
等概率划分[4,4,4]38.03

等概率划分比均匀划分好 5.5 个 FID 点,验证了噪声分布自适应的重要性。

局限性与未来方向

维度匹配限制:目前要求输入输出维度相同,U-Net 这类变维度架构暂不支持。

规模验证不足:最大实验是 DiT-L/2 (ImageNet 256x256),离千卡工业级大模型还有距离。

推理开销:虽然每步只激活一个块,但扩散采样需要多步,总计算量未必比标准推理少。

总结

DiffusionBlocks 的思路极其优雅——不需要新架构、不需要特殊损失函数,只利用"残差连接 = 扩散步"这个数学等价关系,就把反向传播从必需品变成了可选项。等概率划分的设计也巧妙地解决了块间负载均衡的问题。

虽然目前规模还不够大,但作为方向验证非常扎实。如果能在更大模型上复现,这对"想训大模型但卡不够"的场景是根本性的突破。代码已开源:GitHub

Figure 5
图 5:块数量对 ImageNet FID 的影响。B=2 时 FID 优于端到端训练(B=1),说明适度分块可以通过专业化提升性能。
Tags: #Paper