Tianzhu Ye, Li Dong, Yutao Sun, Furu Wei
Notion Link (for better readability)
We compare DIFF V2 with DIFF V1 below:
(For simplicity, we omit the batch dimension and assume that both the input and output of the following flash_attn_func are three-dimensional tensors (tokens, heads, head dimension). Heads belonging to the same GQA group are arranged contiguously in the output)
Note that DIFF V2 subtracts two heads that are in the same GQA group, which means they share the same key and value. This is crucial for performance. See the design ablations section and the GitHub code.
def DiffAttnV1(
    layer_index, q1, q2, k1, k2, v,
    lam_q1, lam_k1, lam_q2, lam_k2,
):
    """
    q1, q2: (N, h/2, d)
    k1, k2: (N, h_kv/2, d)
    v: (N, h_kv/2, 2d)
    lam_*: (d,)
    """
    attn1 = flash_attn_func(q1, k1, v)
    attn2 = flash_attn_func(q2, k2, v)
    lam_init = 0.8 - 0.6 * \
        exp(-0.3 * layer_index)
    lam1 = exp(sum(lam_q1 * lam_k1))
    lam2 = exp(sum(lam_q2 * lam_k2))
    lam = lam1 - lam2 + lam_init
    attn = attn1 - lam * attn2
    attn = rmsnorm(attn)
    attn = attn * (1 - lam_init)
    return attn
def DiffAttnV2(
    q, k, v, lam
):
    """
    q: (N, 2h, d)
    k: (N, h_kv, d)
    v: (N, h_kv, d)
    lam: (N, h, 1)
    """
    attn = flash_attn_func(q, k, v)
    # Pair adjacent heads: heads 2i and 2i+1 are in the same GQA group.
    attn1, attn2 = (attn[:, 0::2],
                    attn[:, 1::2])
    lam_val = sigmoid(lam)
    attn = attn1 - lam_val * attn2
    return attn
Full code at: unilm/Diff-Transformer/Diff-Transformer-V2 at master · microsoft/unilm. In the script, h denotes the number of query heads, h_kv the number of key-value heads, and d the head dimension. The $\lambda$ in DIFF V2 is projected from the input $X$ for each token and each head.
DIFF V2 doubles the number of query heads while keeping the number of key-value heads unchanged; the extra dimension is reduced back to h*d after the differential operation, so the $W_O$ projection remains the same as in the baseline Transformer.
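To make the surrounding projections concrete, here is a minimal runnable sketch of a DIFF V2 attention block. The class and parameter names (DiffAttnV2Block, q_proj, lam_proj, o_proj) and the use of PyTorch's scaled_dot_product_attention with manually repeated KV heads are our own illustration, not the released implementation; see the repository above for the actual code.
import torch
import torch.nn.functional as F
from torch import nn

class DiffAttnV2Block(nn.Module):
    """Illustrative DIFF V2 attention block (not the released implementation).

    h: number of output heads, h_kv: number of KV heads, d: head dimension.
    """
    def __init__(self, h: int, h_kv: int, d: int):
        super().__init__()
        assert (2 * h) % h_kv == 0 and h >= h_kv
        self.h, self.h_kv, self.d = h, h_kv, d
        model_dim = h * d
        # 2h query heads; h_kv key and value heads of dimension d each.
        self.q_proj = nn.Linear(model_dim, 2 * h * d, bias=False)
        self.k_proj = nn.Linear(model_dim, h_kv * d, bias=False)
        self.v_proj = nn.Linear(model_dim, h_kv * d, bias=False)
        # lambda is projected from the input X: one scalar per token per head.
        self.lam_proj = nn.Linear(model_dim, h, bias=False)
        # Output projection keeps the baseline shape (h*d -> h*d).
        self.o_proj = nn.Linear(model_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (N, h*d)
        h, h_kv, d = self.h, self.h_kv, self.d
        N = x.size(0)
        q = self.q_proj(x).view(N, 2 * h, d).transpose(0, 1)   # (2h, N, d)
        k = self.k_proj(x).view(N, h_kv, d).transpose(0, 1)    # (h_kv, N, d)
        v = self.v_proj(x).view(N, h_kv, d).transpose(0, 1)    # (h_kv, N, d)
        # Emulate GQA: repeat each KV head for its group of query heads, so
        # query heads 2i and 2i+1 end up sharing the same key and value.
        rep = 2 * h // h_kv
        k, v = k.repeat_interleave(rep, dim=0), v.repeat_interleave(rep, dim=0)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (2h, N, d)
        attn1, attn2 = attn[0::2], attn[1::2]                           # (h, N, d) each
        lam = torch.sigmoid(self.lam_proj(x)).t().unsqueeze(-1)         # (h, N, 1)
        out = attn1 - lam * attn2
        return self.o_proj(out.transpose(0, 1).reshape(N, h * d))       # (N, h*d)

block = DiffAttnV2Block(h=16, h_kv=4, d=64)
x = torch.randn(10, 16 * 64)
print(block(x).shape)  # torch.Size([10, 1024])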
DIFF V2 introduces additional query heads compared to the baseline Transformer, but does not increase the number of key-value (KV) heads. Since LLM decoding is typically memory-bound, this design allows DIFF V2 to achieve decoding speeds on par with the standard Transformer. Besides, since the head dimension is the same for queries, keys, and values, no custom attention kernel is needed for DIFF V2. In contrast, DIFF V1 can be slower during decoding because the value cache must be loaded twice, and a custom attention kernel is required. DIFF V2 can also increase the arithmetic intensity of the attention module during decoding.
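As a rough back-of-the-envelope check of the decoding argument (our own simplification: only the QK^T and PV matmuls are counted, the KV cache is assumed to be bf16, and softmax and projections are ignored), both models read the same KV cache per decoded token while DIFF V2 performs roughly twice the attention FLOPs, so the attention arithmetic intensity roughly doubles:
# Decode-time arithmetic intensity of the attention kernel (simplified model).
def attn_decode_intensity(n_q_heads, n_kv_heads, d, seq_len, bytes_per_elem=2):
    flops = 4 * n_q_heads * seq_len * d                        # QK^T + PV, 2 flops per MAC
    kv_bytes = 2 * n_kv_heads * seq_len * d * bytes_per_elem   # read K and V caches once
    return flops / kv_bytes

h, h_kv, d, n = 32, 8, 128, 8192                   # illustrative sizes
print(attn_decode_intensity(h, h_kv, d, n))        # baseline: h query heads  -> 4.0
print(attn_decode_intensity(2 * h, h_kv, d, n))    # DIFF V2: 2h query heads  -> 8.0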
During pretraining, when using cutting-edge FlashAttention kernels on H-series and B-series GPUs, the throughput reduction introduced by DIFF V2 is negligible. For long-sequence prefilling, we recommend combining DIFF V2 with techniques such as YOCO (also used in Gemma 3n), which already reduces prefilling complexity to linear time with respect to sequence length.
An alternative perspective is to compare DIFF V2 with a Transformer that has the same query dimension 2h*d. Under this comparison, both models exhibit the same attention kernel speed, while DIFF V2 has fewer parameters and FLOPs in the output projection.
In the standard Scaled Dot-Product Attention (SDPA), let $Q, K, V \in \mathbb{R}^{n \times d}$ be the queries, keys, and values. The context $C$ is defined as:
$$ C = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V = AV $$
where $A \in \mathbb{R}^{n \times n}$ is the attention weight matrix. Let's focus on a single row of $C$, denoted $\mathbf{c}_i$, which is a weighted sum of the value vectors $\mathbf{v}_j$:
$$ \mathbf{c}_i = \sum_{j=1}^{n} a_{ij} \mathbf{v}_j $$
We define the Context RMS (Root Mean Square) to represent the magnitude of this output:
$$ \text{RMS}(\mathbf{c}_i) = \sqrt{\frac{1}{d} \|\mathbf{c}_i\|^2} $$
The weights $a_{ij}$ are non-negative and sum to 1 ($\sum_{j=1}^{n} a_{ij} = 1$). Assuming the value vectors $\mathbf{v}_j$ are uncorrelated and have an RMS of 1, the Context RMS is strictly bounded in the range $[\frac{1}{\sqrt{n}}, 1)$, no matter how the attention distribution changes.
In DIFF V1 we add a per-head RMSNorm on context vectors:
$$ \hat{\mathbf{c}}_i = \frac{\mathbf{c}_i}{\text{RMS}(\mathbf{c}_i)} $$
If the model learns a uniform attention distribution in a head, the Context RMS is approximately $1/\sqrt{n}$. To normalize this back to $1$, RMSNorm must multiply the vector by a scale of $\sqrt{n}$. For $n = 8192$, $\sqrt{n} \approx 90.5$. This means the RMSNorm layer applies a roughly 100x magnification to the output. In large-scale pretraining, we find this leads to massive gradients and numerical instability.
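A quick numerical check of these numbers (purely illustrative: random unit-RMS vectors stand in for real value vectors):
import torch

torch.manual_seed(0)
n, d = 8192, 128
v = torch.randn(n, d)                        # stand-in value vectors, RMS ~= 1
rms = lambda x: x.pow(2).mean().sqrt()

uniform = torch.full((n,), 1.0 / n)          # uniform attention weights
peaked = torch.zeros(n)                      # fully peaked attention weights
peaked[0] = 1.0

print(rms(uniform @ v))                      # ~1/sqrt(n) ~= 0.011
print(rms(peaked @ v))                       # ~1
print(1.0 / rms(uniform @ v))                # RMSNorm scale ~= sqrt(n) ~= 90.5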
A typical phenomenon is that when DIFF V1 is pre-trained with a large learning rate, the gradient norm increases more than that of the Transformer in the later stages of training, along with higher variance. In DIFF V2, after removing the per-head RMSNorm, the gradient norm scale becomes comparable to that of the Transformer, and gradient norm spikes are reduced (discussed further below).
We adopted the per-head RMSNorm design in DIFF V1 primarily because of the doubled value head dimension and the globally shared $\lambda$ across all tokens. Given the modifications made to these two aspects in DIFF V2, we found that removing RMSNorm is now safe.
We demonstrate that DIFF V2 can overcome the Softmax constraint described above. It can also help eliminate attention sinks.
For standard Softmax attention:
$$ a_{ij} = \text{Softmax}(z_{ij}) = \frac{\exp(z_{ij})}{\sum_{k=1}^{n} \exp(z_{ik})} \\ \mathbf{c}_i = \sum_{j=1}^{n} a_{ij} \mathbf{v}_j = \sum_{j=1}^{n} \text{Softmax}(z_{ij}) \mathbf{v}_j \\ \text{RMS}(\mathbf{c}_i) \in \left[\frac{1}{\sqrt{n}}, 1\right) $$
For differential attention in DIFF V2:
$$ \mathbf{c}_i = \sum_{j=1}^{n} \left( \text{Softmax}(z_{ij}^{1}) - \text{sigmoid}(\lambda_i) \cdot \text{Softmax}(z_{ij}^{2}) \right) \mathbf{v}_j \\ \text{RMS}(\mathbf{c}_i) \in \left(0, \sqrt{2}\right) $$
The projected $\lambda_i$ helps to control the context RMS. We observe that lowering the lower bound of the context RMS to zero is particularly important: it helps eliminate attention sinks and improves training stability. For the upper bound, it is enough that the context RMS stays bounded.
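A toy illustration of the two extremes, using hand-constructed attention maps (not learned ones) and random unit-RMS values:
import torch

torch.manual_seed(0)
n, d = 8192, 128
v = torch.randn(n, d)                                  # uncorrelated unit-RMS values
rms = lambda x: x.pow(2).mean().sqrt()
lam = torch.sigmoid(torch.tensor(6.0))                 # sigmoid(lambda) close to 1

# The two softmax maps (nearly) cancel -> context RMS approaches 0.
a1 = torch.softmax(torch.randn(n), dim=-1)
print(rms((a1 - lam * a1) @ v))                        # ~0

# Each map is peaked on a different token -> context RMS approaches sqrt(2).
a_p, a_q = torch.zeros(n), torch.zeros(n)
a_p[0], a_q[1] = 1.0, 1.0
print(rms((a_p - lam * a_q) @ v))                      # ~1.41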
Note that our analysis here considers the RMS before the output projection $W_O$. Although the RMS can be recovered and adjusted after the output projection, the lack of freedom at the Softmax still affects learning performance.
Other recent works alleviate this constraint as well:
Softmax with an extra constant 1 in the denominator:
$$ a_{ij}^{\text{off}} = \frac{\exp(z_{ij})}{1 + \sum_{k=1}^{n} \exp(z_{ik})} \\ \mathbf{c}_i = \sum_{j=1}^{n} a_{ij}^{\text{off}} \mathbf{v}_j = \frac{\sum_{k=1}^{n} \exp(z_{ik})}{1 + \sum_{k=1}^{n} \exp(z_{ik})} \sum_{j=1}^{n} \text{Softmax}(z_{ij}) \mathbf{v}_j \\ \text{RMS}(\mathbf{c}_i) \in \left(0, 1\right) $$
Softmax with a learnable sink logit $s$ in the denominator:
$$ a_{ij}^{\text{oss}} = \frac{\exp(z_{ij})}{\exp(s) + \sum_{k=1}^{n} \exp(z_{ik})} \\ \mathbf{c}_i = \sum_{j=1}^{n} a_{ij}^{\text{oss}} \mathbf{v}_j = \frac{\sum_{k=1}^{n} \exp(z_{ik})}{\exp(s) + \sum_{k=1}^{n} \exp(z_{ik})} \sum_{j=1}^{n} \text{Softmax}(z_{ij}) \mathbf{v}_j \\ \text{RMS}(\mathbf{c}_i) \in \left(0, 1\right) $$
A sigmoid output gate applied to the Softmax output:
$$ \mathbf{c}_i = \text{sigmoid}(\mathbf{g}_i) \odot \sum_{j=1}^{n} \text{Softmax}(z_{ij}) \mathbf{v}_j \\ \text{RMS}(\mathbf{c}_i) \in \left(0, 1\right) $$
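A compact numerical sketch of these variants (our own reference implementations of the formulas above, not the original codebases), contrasting them with standard Softmax when all logits are pushed very low:
import torch

torch.manual_seed(0)
n, d = 8192, 128
v = torch.randn(n, d)                     # uncorrelated unit-RMS value vectors
z = torch.full((n,), -15.0)               # uniformly very negative logits
rms = lambda x: x.pow(2).mean().sqrt()

# Standard Softmax cannot "say nothing": the weights renormalize to sum to 1.
c_std = torch.softmax(z, dim=-1) @ v
# Extra constant 1 in the denominator lets the weights shrink toward zero.
c_off = (torch.exp(z) / (1.0 + torch.exp(z).sum())) @ v
# A learnable sink logit s in the denominator plays the same role.
s = torch.tensor(2.0)
c_oss = (torch.exp(z) / (torch.exp(s) + torch.exp(z).sum())) @ v
# A sigmoid output gate on top of standard Softmax.
g = torch.tensor(-6.0)
c_gate = torch.sigmoid(g) * (torch.softmax(z, dim=-1) @ v)

print(rms(c_std))                            # ~0.011, stuck near 1/sqrt(n)
print(rms(c_off), rms(c_oss), rms(c_gate))   # all can be driven toward 0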
We conduct pretraining experiments on production-scale LLMs, including dense models and a 30A3 MoE, on trillions of tokens, using large learning rates of 6e-4 to 1e-3.
The experiments are still running. What we have observed so far:
What we expect to explore in later stages of training:
In theory, a standard Transformer with $2h$ attention heads can learn the differential operation by learning $W_O^{2i} = -W_O^{2i+1},\ i = 0, 1, \ldots, h-1$, where $W_O^{i}$ denotes the output projection of head $i$, and heads $2i$ and $2i+1$ belong to the same GQA group.
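This equivalence is easy to verify numerically on toy tensors (the sizes and names below are our own choice):
import torch

torch.manual_seed(0)
N, h, d, model_dim = 4, 8, 16, 8 * 16     # toy sizes, illustrative only

attn = torch.randn(N, 2 * h, d)           # per-head attention outputs of a 2h-head Transformer
w_o = torch.randn(h, d, model_dim)        # output projections of the even heads

# 2h-head Transformer whose odd heads are the exact negation of the even heads.
w_o_full = torch.stack([w_o, -w_o], dim=1).reshape(2 * h, d, model_dim)
out_full = torch.einsum('nhd,hdm->nm', attn, w_o_full)

# Explicit differential operation followed by an h-head output projection.
diff = attn[:, 0::2] - attn[:, 1::2]      # (N, h, d)
out_diff = torch.einsum('nhd,hdm->nm', diff, w_o)

print(torch.allclose(out_full, out_diff, atol=1e-4))  # True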
Assumption 1. In practice, such a solution is difficult to learn through optimization, as it requires two sets of parameters to converge to exact negatives of each other.
Assumption 2. The differential operation can be learned by the model, and the model chooses to learn it during training. Then explicitly constructing it before the output projection, as in DIFF V2, saves half of the $W_O$ parameters. The number of saved parameters is non-trivial: under the current GQA setting, the attention-module parameters are dominated by $W_Q$ and $W_O$, so approximately 25% of the attention-module parameters can be saved. The saved parameter budget can then be reallocated to other parts of the model.
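A quick accounting behind the roughly 25% figure, using an illustrative GQA configuration of our own choosing (DIFF V2's small λ projection is ignored):
# Attention parameter counts per layer (weights only; lambda projection ignored).
h, h_kv, d = 32, 8, 128           # illustrative GQA configuration
model_dim = h * d

def attn_params(n_q_heads, n_kv_heads, n_out_heads):
    w_q = model_dim * n_q_heads * d
    w_kv = 2 * model_dim * n_kv_heads * d
    w_o = n_out_heads * d * model_dim
    return w_q + w_kv + w_o

transformer_2h = attn_params(2 * h, h_kv, 2 * h)  # 2h-head Transformer from the argument above
diff_v2 = attn_params(2 * h, h_kv, h)             # DIFF V2: same queries, half-size W_O
print(1 - diff_v2 / transformer_2h)               # ~0.22, approaching 25% as W_Q and W_O dominate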
Even if DIFF V2, after reallocating parameters, does not achieve a lower loss than the baseline but merely matches it, the method is still worthwhile if it provides additional benefits such as improved training stability, better control of outliers, or higher training efficiency. This is analogous to GQA, which matches the loss of MHA while reducing the KV cache as an additional benefit. So the key question becomes empirical performance.
(Same conventions as before: the batch dimension is omitted, flash_attn_func takes and returns three-dimensional tensors of shape (tokens, heads, head dimension), and heads belonging to the same GQA group are arranged contiguously in the output.)
# Ablation 1
# ❌ Wrong Implementation of DIFF V2!
...
attn = flash_attn_func(q, k, v)
nh = attn.size(1)
# Pairs head i with head i + nh//2, which generally lies in a different
# GQA group and therefore attends over different keys and values.
attn1, attn2 = (attn[:, :nh//2],
                attn[:, nh//2:])
...
# DIFF V2
# ✅ Correct Implementation of DIFF V2
...
attn = flash_attn_func(q, k, v)
# Pairs adjacent heads 2i and 2i+1, which share the same GQA group
# and thus the same key and value.
attn1, attn2 = (attn[:, 0::2],
                attn[:, 1::2])
...
In our large-learning-rate setting, ablation 1 exhibits obvious training instability (many more loss and gradient spikes) and higher loss compared to DIFF V2. The value should be shared between the two subtracted heads to construct the differential operation, as discussed in the DIFF V1 paper.
Ablation 2: use attn1 - attn2 instead of attn1 - lam_val * attn2. This results in an excessively small context RMS at initialization.
Ablation 3: remove the sigmoid operation. The context RMS is then unbounded from above.
Both ablation 2 and ablation 3 lead to higher language modeling loss than DIFF V2. Ablation 2 maintains training stability similar to DIFF V2, whereas ablation 3 is less stable (though still more stable than ablation 1).
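For reference, a sketch of what ablations 2 and 3 change relative to the DIFF V2 snippet above (our own reconstruction from the descriptions, not code from the repository):
# Ablation 2 (sketch): drop the projected lambda entirely
...
attn = flash_attn_func(q, k, v)
attn1, attn2 = (attn[:, 0::2],
                attn[:, 1::2])
attn = attn1 - attn2                 # context RMS is excessively small at initialization
...
# Ablation 3 (sketch): keep the projected lambda but remove the sigmoid
...
attn = flash_attn_func(q, k, v)
attn1, attn2 = (attn[:, 0::2],
                attn[:, 1::2])
attn = attn1 - lam * attn2           # context RMS is unbounded from above
...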
Ablation 4: a baseline Transformer with 1.5*h heads, which aligns the parameter count with DIFF V2. Ablation 4 also has higher training loss compared to DIFF V2.