Back to Blog Oryx:在序列维度上自由切换 Attention 和线性循环

Oryx:在序列维度上自由切换 Attention 和线性循环

Paper
核心贡献:提出序列轴混合(sequence-axis hybridization)——在序列的不同片段灵活切换 softmax attention 和 linear RNN(Mamba-2 / Gated DeltaNet)。Tied K/V 投影共享 90%+ 参数,一套表示同时更新 KV cache 和 recurrent state。linear prefill + attention generation 即可追平 Transformer 的检索能力,算力省得多。

核心问题

现有混合架构都是静态的——层间交替或层内融合,每个 token 走的 mixer 在架构定义时就定了。Oryx 问的是:能不能在推理中动态切换?长上下文用 linear RNN 省算力,精确检索时切 attention。

设计:共享 K/V,独立 Query

从 associative memory 统一视角出发:attention、Mamba-2、Gated DeltaNet 底层都是 query-key-value 关联记忆查询。Oryx 绑定 K/V 投影,一套表示同时更新 attention KV cache 和 RNN state。Query 不共享(实验证明共享掉点)。

Chunked Mixed-Mode Training

序列切成 128 token chunk,每个 chunk 随机分配模式(1:3 attention:linear 最佳)。Chunk-level 比 sequence-level 训练切换更稳定,表示兼容性更强。

x1.png
图1: 三种混合架构对比 — (a) 层间交替、(b) 层内融合、(c) Oryx 的序列轴混合:序列内不同片段可用不同 mixer
x2.png
图2: Oryx block 结构 — tied K/V 投影共享表示,各自维护 KV cache 和 RNN state,灵活切换模式
x3.png
图3: 切换 mixer 时的 token-level perplexity — 切换后 perplexity 迅速收敛到 no-switch baseline
x4.png
图4: chunk vs sequence 训练的对比 — chunk 训练使表示兼容性更强,模式切换更稳定
x5.png
图5: 各规模模型的语言建模性能对比 — 1.4B 时 Oryx 各模式均超越对应单 mixer baseline

关键实验结果

语言建模:1.4B 时 Oryx 两种模式单独跑,平均准确率都比对应单 mixer baseline 高 ≥0.7pp。

检索:linear prefill + attention generation 的混合模式,real-world 检索比 Mamba-2 baseline 高 13.5pp,NIAH 高 38.6pp,追平 Transformer。用不到 10% 的 attention token 量就达到 Transformer 级别的检索。

工程意义

  • 长文档 QA:linear 处理 prefill,attention 处理 generation
  • CoT 推理:中间步骤 linear(快),最终验证 attention(准)
  • Token 预算管理:attention 预算花在最需要的片段

下一步:learned routing(RL 训练动态 router,按 FLOPs 节省做奖励),Oryx 的共享设计天然适配。

Tags: #CMU#Google Research