Beyond Heuristic Algorithm: dLLM Decoding Optimization

金义杰

2026-03-17

Learning to Parallel: Accelerating Diffusion Large Language Models via Learnable Parallel Decoding

Learn2PD: 数据构建 (Data Construction)

核心目标

学会判断「当前 Token 是否已经稳了」

数据来源

  • FLAN 数据集:66 类任务 × 40 样本 = 2,640 个训练样本
  • 上帝视角解码:Prompt + 参考答案,用 EGP 并行解码 跑完整轨迹

特征与标签

  • 特征:每步各 Token 的 置信度分布
  • 标签:比对参考答案,对记 1(Unmask),错记 0(继续 Mask)

收集成本

4×A6000,约 3 小时 完成全部数据收集

Learn2PD: 训练策略 (Training Strategy)

核心思想

不动 dLLM,只训练一个 2 层 MLP 作为过滤器

模型架构 (Filter Model \(f_\theta\))

  • 输入:Token 的置信度
  • 输出:该 Token「已正确」的概率 Logit
  • 通过 Sigmoid 得到 \(\sigma(z_i)\)

损失函数 (BCE Loss)

\(L_{\text{BCE}} = -\frac{1}{m} \sum_{i=1}^m \left[ y_i \log \sigma(z_i) + (1 - y_i) \log(1 - \sigma(z_i)) \right]\)

  • \(y_i\):数据标签(1=保留,0=重塑)
  • \(z_i\):MLP 输出的 Logit

训练开销(极小)

  • Diffusion LLM 参数冻结,仅更新 MLP
  • Block Size 32 时,可训练参数仅约 2,112 个
  • 单张 T4:6 分钟 / 5000 Epoch 即可收敛

一句话总结

把复杂的解码控制压缩成几千参数的 二分类小头,便宜又好训。

Learn2PD: 推理执行 (Inference Process)

无参考答案场景

训练好的 MLP 过滤器 \(f_\theta\) 充当「裁判」

执行流程

  1. 并行生成:Diffusion LLM 生成整块 Token 及其置信度
  2. 快速裁决:置信度输入 2 层 MLP,零额外延迟 得到 \(\sigma(z_i)\)
  3. 阈值判定\(\tau = 0.96\)):
    • \(\sigma(z_i) > \tau\)Unmask,后续不再改动
    • \(\sigma(z_i) \le \tau\):置为 [MASK],继续 refinement
  4. Block 终止:所有 Token 都 Unmask 时,该 Block 解码完成

核心优势

高优策略造数据 → 超小 MLP 训裁判 → 毫秒级裁决,避免无意义反复解码

Efficient Diffusion LLMs via Temporal-Spatial Parallel Decoding and Confidence Extrapolation

1. 时空并行解码 (TSPD)

  • 痛点:传统方法只看「当前一步」的置信度是否超阈值,单步、静态判断脆弱,忽略了时间轨迹和位置差异(靠后的 token 稳定更晚)。
  • 机制:lightweight 序列控制器(2 层 LSTM)综合时间轨迹特征(置信度、熵、动量)+ 空间相对位置,每步直接输出二元决策(固定/继续降噪)。
  • 优势:决策更鲁棒,准确识别已收敛或延迟稳定的 token,减少重复计算。

2. 置信度外推 (CE)

  • 痛点:现有加速多是「被动等待」,置信度不够就只能继续降噪。
  • 机制:Training-free,利用状态空间模型(类似卡尔曼滤波)预测 token 未来几步置信度走向及不确定性。
  • 风险控制 (Risk-aware Horizon):根据左侧上下文完成度及预测不确定性,动态决定可预测多远,历史可靠时才启用外推。
  • 优势:化被动为主动,趋势稳定即可提前固定(Look-ahead),砍掉多余等待步数。

现有启发式方案的不足

被动等待

  • 不要死等:趋势一致时置信度可预测,外推法能提前锁定;“Potential to save”即死等 0.9 的浪费
  • 并非个例:约 44.9% 的步骤里 Token 已定型,模型仍在反复确认

其他关键不足

  1. 太死板(阈值脆弱):固定 \(\tau\) 缺乏输入自适应性,A 任务好用换 B 任务可能过激或过保守
  2. 没耐心(延迟稳定):无法区分“大器晚成”与“彻底错误”,低置信度直接重算导致不必要重复掩码
  3. 不合群(系统兼容):与 KV Cache 结合差,打破缓存连续性反而增加单步延迟

总结

现有方案将扩散解码视为独立阈值测试,而非动态控制问题;在处理非对齐置信度与位置异质性时乏力。

TSPD (Temporal and Spatial Parallel Decoding)

Part 1 时空特征向量 \(r_i^{(t)} = [p_i^{(t)}, H_i^{(t)}, \Delta \bar{p}_i^{(t)}, \phi(i)]\) - 置信度、熵、动量/趋势、相对位置,综合时间轨迹与空间异质性

Part 2 序列感知控制器 \(h_i^{(t)} = f_{\psi}(h_i^{(t-1)}, r_i^{(t)})\), \(z_i^{(t)} = W h_i^{(t)} + b\) - 2 层 LSTM(~2k 参数)记忆历史轨迹,输出 Logit

Part 3 动作决策 \(a_i^{(t)} = I(\sigma(z_i^{(t)}) \geq 0.5)\) - \(a_i=1\) 锁定;\(a_i=0\) 继续去噪;用 STE 解决离散可微

训练 \(L_{gate} = \text{BCEWithLogits}(z_i^{(t)}, y_i^{(t)})\)\(y_i^{(t)}\) 为 Oracle 标签(当前预测与最终一致则为 1)

对比:传统方案用 \(p_i^{(t)} > \tau\) 静态阈值,单步快照、易误判;TSPD 用 \(f_{\psi}(\text{Trace}, \text{Position}) \to \{0,1\}\) 动态控制,看轨迹、更鲁棒

亮点

  • 极轻量:额外开销约 0.3%
  • 高性能:GSM8K 上 5.0x(无缓存)~ 11.2x(有缓存)
  • 通用性:LLaDA、Dream-7B 等架构验证

5k epoch:损失由 ~0.7 速降至 ~0.28 后平稳收敛;验证略高于训练且无分叉,表明轻量控制器有效收敛、泛化偏差小、未过拟合。

TSPD 输入信号 \(r_i^{(t)}\) 详解

原始版 (4 维) 扩展版 (6 维)
目的 验证有效性 生产用稳健版
差异 基础置信度+位置 + \(\bar{p}\) 平滑 + \(u\) 不确定性

A. 基础瞬时

  1. \(p_i^{(t)} = \max_{v \in [V]} p_\theta(v \mid c, x^{(t)})\) — 置信度,模型当前步的“自信”
  2. \(H_i^{(t)} = -\sum_v p_i^{(t)}(v) \log p_i^{(t)}(v)\) — 熵,低熵=集中、高熵=犹豫

B. 时间趋势

  1. \(\bar{p}_i^{(t)} = \alpha p_i^{(t)} + (1-\alpha)\bar{p}_i^{(t+1)}\)\(\alpha=0.25\))— EWMA 平滑去噪
  2. \(\Delta \bar{p}_i^{(t)} = \bar{p}_i^{(t)} - \bar{p}_i^{(t+1)}\) — 动量,正值=收敛、负值=分歧

C. 空间与不确定性

  1. \(\phi(i) = \frac{i}{L-1}\) — 相对位置,右侧 Token 需更久稳定(Fig. 2)
  2. \(u_i^{(t)} = \text{Var}(\delta_{i,t+h})\) — CE 预测方差,大则谨慎锁定

Confidence Extrapolation (CE):风险感知的预测性加速

1. 动机:被动等待 vs 主动预测;用历史轨迹预测未来稳定态,实现「抢跑」

2. 状态空间模型 - 状态 \(x_i^{(t)} = [\delta_i^{(t)}, \dot{\delta}_i^{(t)}]^\top\)(边际 \(\delta\) + 变化率) - 转移 \(x_i^{(t-1)} = A x_i^{(t)} + \epsilon\)\(A = \begin{bmatrix} 1 & 1 \\ 0 & 1 \end{bmatrix}\)(常速) - 观测 \(o_i^{(t)} = \delta_i^{(t)} + \eta\);Kalman 在线估计

3. \(h\) 步预测 - \(\hat{x}_{t+h|t} = A^h \hat{x}_t\)\(P_{t+h|t} = A^h P_t (A^h)^\top + Q_h\) - 步数越远,\(\sigma^2_{t+h}\) 越大

4. 风险感知决策 - 保守下界 \(\text{LB}_{t+h} = \hat{\ell}_{t+h} - z \cdot \sigma_{t+h}\)\(z\) 风险系数) - 动态地平线 \(h_{\text{pot}} = H \cdot r_t(j)\)\(r_t\) 由 left_coverage 决定

5. 实现与价值 - Training-free;输入 \(\delta_i^{(t)}\)(Top-1 与 Top-2 边际);开销 0.13% - 加速比再提升 0.5x~1.5x(GSM8K: 4.4x→5.0x)

Algorithm 1 推理工作流解读

实验结果 (一)

实验结果 (二)

Q & A