Back to Blog Untitled

Untitled

Paper
TL;DR:Transformer 的自注意力在序列维度上是"自由选择",但在深度维度上只是粗暴的残差累加。Depth-Attention 在注意力模块内部增加了一步"沿深度方向的注意力",让每层可以从前面的层中选择性复用 value 信息。零参数增加、零额外 KV cache、不到 0.01% 额外 FLOPs——在 1.5B/3B Qwen3 风格解码器上全面最优,优于 vanilla Transformer 和所有跨层 baseline。

背景:残差连接是深度维度的信息瓶颈

Transformer 的自注意力机制让每个 token 可以在序列维度上自由选择信息。但跨层信息流动极其粗暴——每一层的输出通过残差连接直接加到 hidden state 上,所有前层信息被压缩成一个向量。后面层想回溯某一特定层的表示?做不到。

近年的跨层方法试图改善这一状况。DenseFormer 用学到的权重平均各层输出;Attention Residuals 将固定残差累加换成 softmax 注意力;Hyper-Connections 维护多条交互流。它们都有效,但有一个共同的代价:需要在 KV cache 之外维护额外的隐藏状态

在现代 LLM 大量使用 GQA 和 MLA 压缩 KV cache 的背景下,这个额外存储的代价越来越不可忽视。DenseFormer 在 128K context 下需要额外 24GB 内存。

核心方法:在注意力内部做深度选择

Depth-Attention 的思路:把跨层选择放在注意力机制内部,而不是残差路径上。

在每一层执行正常的序列注意力之前,当前层的 query 沿着深度方向对前面若干层的 key 做一次 softmax 注意力,然后把它们的 value 混合成一个"深度混合 value"。后续的序列注意力直接读这个混合后的 value,query、key 和因果掩码不变。

Depth-Attention 概览
图 1:Depth-Attention 概览。标准自注意力在序列维度混合信息,Depth-Attention 在深度维度做类似的操作。对 token t 在层 ℓ,当前 query 同时关注当前层和前面间隔 s 的层的 key,用深度注意力权重混合 value。

关键设计决策

只改 value,不动 key 和 query。消融实验(Table 6)表明:只混合 value 效果最好(val loss 2.2115),混合 key-value 次之(2.2198),只改 key 基本无效(2.2354)。直觉上,key 编码"我在找什么"(当前层的任务特征),value 编码"我找到了什么"(语义信息)——混合 value 是在融合语义,混合 key 反而模糊了查询目标。

稀疏深度源(stride L/2)。不是每层都回溯所有前层,而是用步长 s 只取部分层。消融实验显示 s = L/2 最优——太稀疏(只用第 1 层)不够,太密集(quarter-depth 或全连接)反而引入冗余噪声。

复用 KV cache 的 V 插槽。深度混合后的 value 直接替换原 value 存回同一位置,KV cache 大小不变。在 GQA 设置下,depth-attention 在 key-value head 分辨率上运行(而非 query head 分辨率),进一步压缩开销。

效率分析:几乎免费的午餐

方法额外 FLOPs训练开销 (1.5B/3B)128K 预填额外内存
Depth-Attention0.004%+9.1% / +11.2%0 GiB
DenseFormer~0.1%+50.6% / +47.9%24.0 GiB
Attention Residuals~0.09%+47.8% / +75.1%4.0 GiB
mHC~0.9%+321% / +308%1.5 GiB

Depth-Attention 是唯一不增加额外持久推理状态的方法。推理延迟仅 +1.18%,预填内存零增长。训练开销是所有跨层方法中最小的(+9-11%),而且这是未做算子融合的结果——专用 kernel 可以进一步缩小差距。

主要实验结果

在 Qwen3 风格的 1.5B 和 3B 解码器上,Depth-Attention 在困惑度和 8 个下游任务的平均准确率上全面最优,超越了 vanilla Transformer 和所有跨层 baseline。

方法 (3B)PPL ↓Arc-C ↓HellaSwag ↓WinoGrande ↓MMLU ↓Avg ↑
Vanilla7.2553.2355.0191.2061.5354.16
mHC-56.7658.6493.5062.5855.65
Attention Residuals-55.7357.4692.7064.3155.92
DenseFormer-54.1259.2791.9063.8055.59
Depth-Attention最低57.4059.1992.8063.2256.17

在 1.5B 模型上类似:Depth-Attention 的平均准确率 56.17 vs vanilla 54.16,提升 2 个百分点。五 shot 设置下优势保持一致。

Scaling 行为

Scaling 行为
图 2:360M 到 3B 的 scaling 实验。Depth-Attention 在所有尺度上一致降低验证 loss,虚线段表示 vanilla baseline 需要增加多少参数才能匹配 1.5B Depth-Attention 的性能。

从 360M 到 3B,Depth-Attention 一致地将验证 loss 曲线向下推移。另一个有趣的发现:scaling 拟合趋势暗示 vanilla Transformer 需要更多参数才能达到同等的 Depth-Attention 性能。

注意力权重可视化

深度注意力权重热力图
图 3:3B 模型的深度注意力权重热力图(对所有 token、head 和样本取平均)。对角线是当前层 value 的权重,离对角线越远表示对更浅层的回溯。可以看到模型确实在利用中间层的 value 信息。

热力图清楚地表明模型没有退化为"只看自己"——深层 target 层对中间层的 value 分配了显著的注意力权重。这说明深度方向的 value 混合确实被学到了、用到了。

消融实验

Stride 消融
图 4:深度源列表密度的消融。half-depth stride 最优,first-layer-only 略差,quarter-depth 和 full-source 反而更差——说明适度的稀疏性最有利于捕获有用的中间层信息。

三组消融的结论都很清晰:

Stride:L/2 最优。太稀疏(只用第 1 层)不够,太密集引入冗余。但所有 Depth-Attention 变体都优于 vanilla(loss 2.2348)。

混合规则:Softmax 混合(2.2115)优于均匀混合(2.2153)优于 vanilla,说明自适应深度选择确实比固定权重更好。

混合对象:Value-only 最好,改 key 有害。直觉:当前层的 key 编码了当前任务"需要找什么",被浅层 key 污染后反而模糊了查询意图。

Looped Transformer 上的验证

Depth-Attention 可以直接应用于循环 Transformer(参数共享的循环层)。在 500M 3-loop 模型上,验证 loss 从 2.208 降至 2.194,说明深度方向的 value 混合在循环产生的"虚拟深度"中同样有效。

局限性与评价

论文坦承三个局限:只验证到 3B 参数和 32B token(算力限制而非方法本身约束);当前实现未做算子融合(实测开销高于理论值);深度源选择用的是固定 stride(可学习的选择策略留作未来工作)。

但核心思路非常扎实。把跨层信息复用嵌入 attention 内部、完全复用 KV cache 结构、零额外状态——这是一个几乎无脑可加的架构改进。尤其适合已采用 GQA/MLA 的现代 LLM 架构,因为那些架构的 cache 压缩趋势让基于 hidden state 的跨层方法成本越来越高。

从工程角度看,+9% 训练开销换跨层信息选择性,零推理内存增长,这个 trade-off 非常划算。如果能进一步做算子融合把训练开销压到 +5% 以内,基本就是免费的能力提升。

参考资源

论文原文:arXiv 2606.05014

代码仓库:github.com/LUMIA-Group/Depth-Attention

Tags: #Paper