Compute-Aware Hybrid Attention Architecture Search: Selective Linear Attention Replacement via Layer-wise Knowledge Distillation

Videet Mehta | Sanjith Udupa | Vineet Sharma

Final project for 6.7960, MIT | December 2025
The computational demands of large language models continue to grow, driven largely by the quadratic complexity of self-attention. For a sequence of length \(n\), standard attention requires \(O(n^2)\) time and memory. This makes long-context inference expensive and limits practical deployment.
Linear attention mechanisms reduce this complexity to \(O(n)\), but replacing all attention layers with linear variants typically degrades model quality. This raises a question: are all attention layers equally important, or can we selectively replace only those layers where linear attention performs well?
Our goal was to develop a method for identifying which layers in a transformer can be safely replaced with linear attention, and which layers should retain full softmax attention.
We developed a two-stage approach. In Stage 0, we train linear attention replacements for each layer independently using knowledge distillation. This gives us a measure of how "replaceable" each layer is. In Stage 1, we assemble a hybrid model by selecting the best-performing layers for replacement and fine-tune the complete model using knowledge distillation with a decoupled top-k KL divergence objective.
We conducted experiments using pretrained layers from Qwen3-1.7B [8] (28 layers, 2048 hidden size) to train layer replacements for two different linear attention mechanisms: Gated Linear Attention (GLA) and RWKV7 from the Flash Linear Attention library in Stage 0. We then constructed a hybrid model picking between these trained layer replacements and used Qwen3-8B as the teacher for Stage 1 distillation. We evaluated a hybrid model that used GLA. Our results show clear patterns in layer sensitivity, with middle layers being significantly harder to replace than early or late layers.
We investigated two linear attention mechanisms as candidates for layer replacement: Gated Linear Attention (GLA) and RWKV7.
Gated Linear Attention (GLA) [4] replaces softmax attention with a gated recurrent formulation. The key insight is that attention can be computed as a weighted sum over a recurrent state, with data-dependent gates controlling information flow.
For each position \(t\), GLA maintains a state matrix \(S_t\) and computes:
\[S_t = G_t \odot S_{t-1} + k_t^T v_t\] \[o_t = q_t S_t\]
where \(G_t\) is a learned gate matrix that controls how much of the previous state to retain. This gating mechanism provides expressivity beyond simple linear attention by allowing the model to selectively forget or retain information from previous positions.
The recurrent formulation enables \(O(n)\) complexity since each position only needs to update and query a fixed-size state matrix. GLA also supports efficient parallel training through a chunked computation that processes multiple positions simultaneously.
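To make the recurrence concrete, here is a minimal numpy sketch of the per-step update. The dense per-step gate `g` is a simplification: real GLA parameterizes the gate in a low-rank form, and the FLA library computes the same recurrence with chunked parallel CUDA kernels.

```python
import numpy as np

def gla_recurrent(q, k, v, g):
    """Naive O(n) recurrent form of the GLA update above.

    q, k: (n, d_k); v: (n, d_v); g: (n, d_k, d_v) per-step gate matrices.
    Illustrative only -- not the FLA implementation.
    """
    n, d_k = q.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))                  # fixed-size state, independent of n
    out = np.empty((n, d_v))
    for t in range(n):
        S = g[t] * S + np.outer(k[t], v[t])   # S_t = G_t (elementwise) S_{t-1} + k_t^T v_t
        out[t] = q[t] @ S                     # o_t = q_t S_t
    return out
```

With all gates set to one, this reduces to plain (ungated) linear attention, i.e. a running sum of outer products queried at each step.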
RWKV7 [7] is a linear attention variant that combines aspects of RNNs and transformers. It uses a time-mixing mechanism with learned decay rates to control information flow across positions. RWKV7 maintains a recurrent state that is updated at each position:
\[s_t = \text{diag}(w) \cdot s_{t-1} + k_t^T v_t\] \[o_t = \sigma(r_t) \odot (q_t \cdot s_t)\]
where \(w\) is a learned decay vector, \(r_t\) is a receptance gate, and \(\sigma\) is a nonlinearity. The decay mechanism allows RWKV7 to model different time scales of information retention.
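A similarly minimal sketch of the decay-plus-receptance update follows. The tensor shapes and the placement of the receptance gate are our illustrative assumptions; the full RWKV7 block in the FLA library is considerably more involved.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def rwkv7_style_recurrent(q, k, v, r, w):
    """Simplified RWKV7-style recurrence from the equations above.

    q, k: (n, d_k); v, r: (n, d_v); w: (d_k,) per-channel decay.
    A sketch of the update in the text, not the real RWKV7 layer.
    """
    n, d_k = q.shape
    d_v = v.shape[1]
    s = np.zeros((d_k, d_v))
    out = np.empty((n, d_v))
    for t in range(n):
        s = w[:, None] * s + np.outer(k[t], v[t])  # s_t = diag(w) s_{t-1} + k_t^T v_t
        out[t] = sigmoid(r[t]) * (q[t] @ s)        # o_t = sigma(r_t) (elementwise) (q_t s_t)
    return out
```

Setting `w` near 1 retains long-range state, while `w` near 0 makes each output depend only on the current position, which is the sense in which the decay models different time scales.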
We use implementations from the Flash Linear Attention (FLA) library, which provides optimized CUDA kernels for both GLA and RWKV7 for training and inference.
The first stage trains linear attention replacements for each layer independently. We trained both GLA and RWKV7 variants to measure how well each layer's behavior can be approximated by linear attention, which tells us which layers are good candidates for replacement.
We modified the Qwen3 architecture to create a Qwen3ForNAS class that outputs the hidden states before and after any specified layer. For each layer \(l\), we train separate GLA and RWKV7 blocks to match the teacher's behavior by minimizing the MSE between the output hidden states of the student and teacher models:
\[\mathcal{L}_{\text{MSE}}^{(l)} = \mathbb{E}_{x \sim \mathcal{D}} \left[ \left\| f_{\text{student}}^{(l)}(h^{(l-1)}) - f_{\text{teacher}}^{(l)}(h^{(l-1)}) \right\|_2^2 \right]\]
We trained each layer on approximately 40 million tokens from the DCLM-baseline-1.0 dataset with batch size 32, sequence length 1024, learning rate 1e-3 with cosine decay to 1e-5, and 10% warmup. To prevent early training divergence, we apply gradient clipping. All training was performed on NVIDIA A6000 GPUs. We trained GLA for layers 1-10 and independently trained RWKV7 for layers 1-28. Training time scaled superlinearly with layer depth.
After training the individual layers, we needed a way to compare how well each student layer reproduces its teacher layer's hidden states. However, different layers have different activation magnitudes, making raw MSE values incomparable across layers. We therefore compute a normalized loss that measures the student's error relative to the magnitude of the teacher's update:
\[\mathcal{L}_{\text{norm}}^{(l)} = \frac{\text{MSE}(f_{\text{student}}^{(l)}(h^{(l-1)}), h^{(l)})}{\text{MSE}(h^{(l-1)}, h^{(l)})}\]
A normalized loss of 0.0 means perfect reproduction; 1.0 means the student is no better than passing through the input unchanged.
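The metric is straightforward to compute; a small numpy helper makes the two boundary cases explicit:

```python
import numpy as np

def normalized_layer_loss(student_out, teacher_in, teacher_out):
    """Student MSE relative to the magnitude of the teacher's update.

    0.0 = perfect reproduction of the teacher layer's output;
    1.0 = no better than passing the input through unchanged.
    """
    mse = lambda a, b: float(np.mean((a - b) ** 2))
    return mse(student_out, teacher_out) / mse(teacher_in, teacher_out)
```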
Figure 1: Training curves for GLA student layers showing normalized MSE loss over training.
Figure 2: Training curves for RWKV7 student layers showing normalized MSE loss over training.
We encountered persistent numerical instability with RWKV7 during Stage 1 training (full model distillation), stemming from mixed precision mismatches between the FLA library's RWKV7 implementation and our teacher model's precision requirements. These floating point arithmetic errors proved difficult to resolve within our compute constraints. As a result, we were unable to successfully complete Stage 1 distillation for any hybrid model configurations that included RWKV7 layers. We therefore proceeded with only GLA layers for the final hybrid model configuration used in Stage 1.
After training each layer, we ranked layers by their final normalized loss. Based on our experiments, we constrained GLA to layers 1-10. For RWKV7, we trained layers 1-28 and observed Stage 0 normalized losses comparable to GLA's, indicating that RWKV7 layers were also capable of approximating the teacher's behavior in Stage 0. We nevertheless excluded RWKV7 layers from the Stage 1 hybrid model because Stage 1 distillation consistently failed with them, due to the numerical instability described above. Across both mechanisms, Stage 0 revealed a clear pattern: middle layers (roughly layers 8-14) have significantly higher normalized loss than early or late layers, and the fact that both GLA and RWKV7 exhibit it suggests the pattern holds across different linear attention mechanisms.
This finding aligns with prior work on transformer interpretability. Clark et al. [5] showed that early layers primarily detect local syntactic patterns while middle layers handle more complex semantic relationships. Tenney et al. [6] demonstrated that the "classical NLP pipeline" emerges across layers, with syntax in early layers and semantics in middle layers. Our results suggest that the complex semantic operations in middle layers fundamentally require the expressivity of softmax attention, while the more local computations in early layers can be approximated by linear attention.
| Rank | Layer | Normalized Loss |
|---|---|---|
| 1 | 3 | 0.0003 |
| 2 | 9 | 0.1059 |
| 3 | 10 | 0.1186 |
| 4 | 7 | 0.2129 |
| 5 | 8 | 0.2396 |
| 6 | 1 | 0.3459 |
| 7 | 2 | 0.3585 |
| ... | ... | ... |
| Worst | 6 | 0.4688 |
Table 1: GLA layers ranked by normalized distillation loss. Lower loss indicates better suitability for replacement.
The standout result is layer 3, which achieves a normalized loss of just 0.0003. This layer appears to perform computations that are nearly linear in nature, making it an ideal candidate for replacement.
After identifying which layers to replace, we assemble the hybrid model and fine-tune it using knowledge distillation from Qwen3-8B. We use a two-GPU configuration with NVIDIA A6000 GPUs: the 8B teacher runs on one GPU while the 1.7B hybrid student runs on another. The precision for each model was chosen to fit within GPU memory constraints—the teacher uses FP16 and the student uses BF16. We specifically chose to distill from a larger model than the base architecture to get the best possible downstream accuracy, while still being able to maintain efficient training with our distributed training approach.
Based on our Stage 0 results, we construct hybrid models by selecting the top layers within a given percentile of Stage 0 training losses. Our configuration generator takes the Stage 0 distillation losses for each layer type (GLA and RWKV7), filters layers by allowed ranges (GLA: layers 1-10; RWKV7: layers 11-22), sorts them by loss, and selects the top N% of layers for replacement, where N is the target percentile. Below are hybrid configurations generated at the top 10%, 25%, and 50% levels, plus a modified top10 configuration that uses only GLA layers (top10_gla). Given resource constraints, only one configuration could be trained, and for the RWKV7 stability reasons discussed above we chose top10_gla for Stage 1 distillation. We still ran inference-speed analysis (time to first token, token throughput) on all configurations, but measured accuracy only for top10_gla, which we judged representative of the process as a whole. The top10_gla configuration replaces layers 1, 2, 3, 7, and 9 with GLA, keeping the remaining 23 layers as full attention; these layers showed the lowest normalized distillation loss, indicating they can be well approximated by linear attention.
| Configuration | GLA Layers | RWKV7 Layers | # Full Attention Layers |
|---|---|---|---|
| control | none | none | 28 |
| top10 | [3] | [21] | 26 |
| top10_gla (Final) | [1, 2, 3, 7, 9] | none | 23 |
| top25 | [3, 9] | [20, 21, 22] | 23 |
| top50 | [3, 7, 8, 9, 10] | [17, 18, 19, 20, 21, 22] | 17 |
Table 2: Hybrid model configurations; top10_gla is the final configuration used for Stage 1 distillation.
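A simplified sketch of the percentile-based selection follows. The function name and the `max(1, floor(...))` rounding rule are illustrative assumptions, not the exact generator we ran.

```python
import math

def select_layers(stage0_losses, allowed, percentile):
    """Pick the lowest-loss layers within an allowed range.

    stage0_losses: {layer_index: normalized Stage 0 loss}
    allowed: layer indices eligible for this mechanism
    percentile: fraction of eligible layers to replace (e.g. 0.10)
    Rounding rule is an illustrative assumption.
    """
    allowed = set(allowed)
    candidates = sorted(
        (l for l in stage0_losses if l in allowed),
        key=lambda l: stage0_losses[l],
    )
    n = max(1, math.floor(percentile * len(candidates)))
    return candidates[:n]
```

Running this per mechanism (GLA over layers 1-10, RWKV7 over layers 11-22) and taking the union yields a hybrid configuration like those in Table 2.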
Standard KL divergence over the full vocabulary (150,000+ tokens) is dominated by low-probability tokens in the tail of the distribution. To address this, we use a decoupled top-k KL divergence that focuses on the most important tokens, building on the idea of top-k distillation from previous work [3].
Let \(p_T\) and \(p_S\) denote the teacher and student probability distributions over the vocabulary \(\mathcal{V}\). For a given top-k set \(\mathcal{T}_k\) containing the k tokens with highest teacher probability, we define:
\[\rho_T = \sum_{v \in \mathcal{T}_k} p_T(v), \quad \rho_S = \sum_{v \in \mathcal{T}_k} p_S(v)\]
These represent the total probability mass that each distribution assigns to the top-k tokens.
The decoupled KL divergence consists of two components. First, a Bernoulli KL term that measures how well the student matches the teacher's allocation of probability mass between the top-k set and the rest:
\[\mathcal{L}_{\text{Bern}} = \rho_T \log\frac{\rho_T}{\rho_S} + (1-\rho_T) \log\frac{1-\rho_T}{1-\rho_S}\]
Second, a categorical KL term computed only over the top-k tokens. We renormalize the distributions within the top-k set and compute KL divergence with temperature scaling:
\[\tilde{p}_T(v) = \frac{p_T(v)}{\rho_T}, \quad \tilde{p}_S(v) = \frac{p_S(v)}{\rho_S} \quad \text{for } v \in \mathcal{T}_k\]
\[\mathcal{L}_{\text{top-k}} = T^2 \cdot \rho_T \cdot \text{KL}(\tilde{p}_T \| \tilde{p}_S)\]
where \(T\) is the temperature parameter. The \(\rho_T\) factor weights this term by how much probability mass is in the top-k, and \(T^2\) compensates for gradient magnitude reduction from temperature scaling.
The final KL loss combines both components:
\[\mathcal{L}_{\text{KL}} = \mathcal{L}_{\text{top-k}} + \mathcal{L}_{\text{Bern}}\]
The total distillation loss is:
\[\mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{CE}} + (1-\alpha) \cdot \mathcal{L}_{\text{KL}}\]
We use \(\alpha = 0.5\), \(T = 4.0\), and \(k = 32\) tokens.
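Putting the pieces together, here is a numpy sketch of the decoupled loss for a single token position. Where exactly the temperature enters is our reading (we soften both logit vectors before the softmax), and the `eps` terms guarding the logarithms are an implementation convenience.

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def decoupled_topk_kl(teacher_logits, student_logits, k=32, T=4.0, eps=1e-9):
    """Decoupled top-k KL for one token position (sketch)."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    top = np.argsort(p_t)[-k:]                  # T_k: top-k tokens by teacher prob
    rho_t, rho_s = p_t[top].sum(), p_s[top].sum()
    # Bernoulli KL over mass inside vs. outside the top-k set
    bern = (rho_t * np.log((rho_t + eps) / (rho_s + eps))
            + (1 - rho_t) * np.log((1 - rho_t + eps) / (1 - rho_s + eps)))
    # Categorical KL over the renormalized top-k distributions
    pt_k, ps_k = p_t[top] / rho_t, p_s[top] / rho_s
    kl_topk = np.sum(pt_k * np.log((pt_k + eps) / (ps_k + eps)))
    return T**2 * rho_t * kl_topk + bern
```

When the student matches the teacher exactly, both components vanish; any mismatch in either the mass split or the top-k shape contributes a positive penalty.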
Stage 1 distillation proceeded in two phases. First, we trained for 3,000 steps with batch size 8 and sequence length 512. Then, because linear attention can struggle with longer contexts, we continued training for another 5,000 steps with sequence length 1,024 and batch size 4. This totals approximately 8,000 steps and 32.8 million tokens (3,000 × 8 × 512 + 5,000 × 4 × 1,024 = 32,768,000). We used learning rate 5e-5 with cosine annealing, 200 warmup steps, BF16 precision, gradient checkpointing, and gradient clipping to prevent early training divergence.
We compared training with full KL divergence versus our decoupled top-k KL divergence approach.
Figure 3: Training loss curves comparing full KL divergence vs top-k KL divergence.
Important: The total loss and KL loss values cannot be directly compared between the two methods because they operate on different scales. Full KL divergence sums over 150,000+ vocabulary items, while top-k KL operates on only 32 tokens. What we can compare is how well the student learns the language modeling task, as measured by the cross-entropy component. We observe that top-k KL divergence leads to better CE loss convergence, suggesting that focusing on high-probability tokens provides a cleaner learning signal.
Our Stage 0 results show that middle layers are significantly harder to replace than early or late layers. We interpret this as reflecting the functional organization of transformers: early layers perform largely local, near-linear computations that linear attention can approximate well, while middle layers carry out more complex semantic operations that appear to require the expressivity of softmax attention.
Figure 4: Throughput comparison across model configurations.
Figure 5: Time to First Token comparison across model configurations (log scale).
| Configuration | GLA Layers | RWKV7 Layers | Time to First Token (s) | Throughput (tok/s) | Throughput Speedup | TTFT Speedup |
|---|---|---|---|---|---|---|
| control | none | none | 1.31 | 5.44 | 1.00x | 1.00x |
| top10_gla | [1, 2, 3, 7, 9] | none | 0.06 | 21.09 | 3.88x | 22.64x |
| top25 | [3, 9] | [20, 21, 22] | 0.05 | 25.58 | 4.70x | 24.49x |
| top50 | [3, 7, 8, 9, 10] | [17, 18, 19, 20, 21, 22] | 0.05 | 28.40 | 5.22x | 27.68x |
Table 3: Model configurations with inference latency (TTFT) and throughput metrics.
To verify that the hybrid model maintains coherent generation, we performed a qualitative evaluation of its text predictions. Unfortunately, the hybrid model did not perform well: generated outputs were often incoherent or failed to follow the prompt context appropriately.
Our work has two primary limitations:
Limited training compute: Our knowledge distillation training used only approximately 33 million tokens (3,000 × 8 × 512 + 5,000 × 4 × 1,024 ≈ 32,768,000 tokens). This is significantly less than the billions of tokens typically used for pretraining, and also far below the 350-700 million tokens used by RADLADS [9] for their successful linear attention conversions. Additionally, we observed that training time for GLA layers scaled superlinearly with layer depth, which limited our ability to train all layers within the available compute budget. With more compute and training tokens, the hybrid model would likely achieve better performance and more closely match the teacher model's capabilities.
KV cache incompatibility: Standard transformer inference relies heavily on key-value (KV) caching to avoid recomputing attention for previous tokens. However, linear attention mechanisms like GLA use a fundamentally different computation paradigm based on recurrent state updates rather than explicit key-value storage. This means the standard KV cache approach does not directly apply to hybrid models. To achieve strong inference performance, we may need to implement a separate caching mechanism specifically designed for linear attention layers, which adds engineering complexity and may not provide the same speedups as traditional KV caching.
Several directions could extend this work, most directly scaling up the distillation token budget and designing a caching mechanism suited to the linear attention layers.
We presented a compute-aware architecture search method for constructing hybrid attention transformers. By training linear attention replacements for each layer independently and measuring normalized distillation loss, we can identify which layers are good candidates for replacement.
Our key finding is that layer position strongly predicts amenability to linear attention replacement. Early layers can often be replaced with minimal quality loss, while middle layers are significantly harder to approximate. This suggests that middle layers perform more complex computations that fundamentally require the expressivity of softmax attention.
We also adopted decoupled top-k KL divergence as an improved objective for knowledge distillation, focusing on high-probability tokens rather than matching the full distribution over 150,000+ vocabulary items.
As language models continue to grow, efficient attention mechanisms become increasingly important. Our work provides a principled methodology for navigating the efficiency-quality tradeoff and contributes toward making large language models more practical to deploy.
Code available at: [GitHub Repository]