Compute-Aware Hybrid Attention Architecture Search:
Selective Linear Attention Replacement via Layer-wise Knowledge Distillation
Videet Mehta Sanjith Udupa Vineet Sharma
Final project for 6.7960, MIT  |  December 2025
Outline

Introduction & Motivation

Related Work

Linear Attention

Stage 0 Training

Stage 1 Training

Results

Limitations & Future Work

Conclusion

References

Introduction and Motivation

The computational demands of large language models continue to grow, driven largely by the quadratic complexity of self-attention. For a sequence of length \(n\), standard attention requires \(O(n^2)\) time and memory. This makes long-context inference expensive and limits practical deployment.

Linear attention mechanisms reduce this complexity to \(O(n)\), but replacing all attention layers with linear variants typically degrades model quality. This raises a question: are all attention layers equally important, or can we selectively replace only those layers where linear attention performs well?

Our goal was to develop a method for identifying which layers in a transformer can be safely replaced with linear attention, and which layers should retain full softmax attention.

Hypothesis: Transformer layers exhibit varying sensitivity to attention mechanism replacement. By measuring how well a linear attention module can replicate each layer's behavior through distillation loss, we can identify an optimal hybrid architecture that improves efficiency while preserving model quality.

We developed a two-stage approach. In Stage 0, we train linear attention replacements for each layer independently using knowledge distillation. This gives us a measure of how "replaceable" each layer is. In Stage 1, we assemble a hybrid model by selecting the best-performing layers for replacement and fine-tune the complete model using knowledge distillation with a decoupled top-k KL divergence objective.

In Stage 0, we used pretrained layers from Qwen3-1.7B [8] (28 layers, hidden size 2048) to train layer replacements for two linear attention mechanisms from the Flash Linear Attention library: Gated Linear Attention (GLA) and RWKV7. We then constructed a hybrid model by selecting among these trained layer replacements and used Qwen3-8B as the teacher for Stage 1 distillation. We evaluated a hybrid model that used only GLA replacements. Our results show clear patterns in layer sensitivity, with middle layers being significantly harder to replace than early or late layers.







The quadratic scaling becomes particularly problematic for applications like document understanding and multi-turn dialogue that require long context windows.













Our approach treats architecture search as a measurement problem rather than a search problem.

Linear Attention Mechanisms

We investigated two linear attention mechanisms as candidates for layer replacement: Gated Linear Attention (GLA) and RWKV7.

Gated Linear Attention (GLA)

Gated Linear Attention (GLA) [4] replaces softmax attention with a gated recurrent formulation. The key insight is that attention can be computed as a weighted sum over a recurrent state, with data-dependent gates controlling information flow.

For each position \(t\), GLA maintains a state matrix \(S_t\) and computes:

\[S_t = G_t \odot S_{t-1} + k_t^T v_t\] \[o_t = q_t S_t\]

where \(G_t\) is a learned gate matrix that controls how much of the previous state to retain. This gating mechanism provides expressivity beyond simple linear attention by allowing the model to selectively forget or retain information from previous positions.

The recurrent formulation enables \(O(n)\) complexity since each position only needs to update and query a fixed-size state matrix. GLA also supports efficient parallel training through a chunked computation that processes multiple positions simultaneously.
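For illustration, a minimal PyTorch sketch of this recurrence is shown below. It is a naive O(n) reference loop, assuming the gate acts per key dimension and is broadcast across the value dimension; the actual FLA kernels use the chunked parallel formulation instead.

```python
import torch

def gla_recurrent(q, k, v, g):
    """Naive recurrent form of Gated Linear Attention (illustrative sketch).

    q, k: (batch, seq_len, d_k); v: (batch, seq_len, d_v);
    g:    (batch, seq_len, d_k) per-position gates in (0, 1).
    Returns outputs of shape (batch, seq_len, d_v).
    """
    B, T, d_k = q.shape
    d_v = v.shape[-1]
    S = torch.zeros(B, d_k, d_v, device=q.device, dtype=q.dtype)  # recurrent state
    outputs = []
    for t in range(T):
        # Gate the previous state, then add the rank-1 update k_t^T v_t.
        G_t = g[:, t].unsqueeze(-1)  # (B, d_k, 1), broadcast over the value dimension
        S = G_t * S + k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(1)
        # Query the state: o_t = q_t S_t.
        outputs.append(torch.einsum('bk,bkv->bv', q[:, t], S))
    return torch.stack(outputs, dim=1)
```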

RWKV7

RWKV7 [7] is a linear attention variant that combines aspects of RNNs and transformers. It uses a time-mixing mechanism with learned decay rates to control information flow across positions. RWKV7 maintains a recurrent state that is updated at each position:

\[s_t = \text{diag}(w) \cdot s_{t-1} + k_t^T v_t\] \[o_t = \sigma(r_t) \odot (q_t \cdot s_t)\]

where \(w\) is a learned decay vector, \(r_t\) is a receptance gate, and \(\sigma\) is a nonlinearity. The decay mechanism allows RWKV7 to model different time scales of information retention.
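A correspondingly simplified sketch of this decay-based update is shown below, under the same caveats as the GLA sketch; the real RWKV7 block includes additional components (such as token-shift mixing and dynamic state updates) that are omitted here.

```python
import torch

def rwkv_style_recurrent(q, k, v, w, r):
    """Simplified decay-gated recurrence matching the equations above (illustrative only).

    w: (d_k,) learned per-channel decay in (0, 1).
    r: (batch, seq_len, d_v) receptance gate applied to the output.
    """
    B, T, d_k = q.shape
    d_v = v.shape[-1]
    S = torch.zeros(B, d_k, d_v, device=q.device, dtype=q.dtype)
    outputs = []
    for t in range(T):
        # diag(w) decay on the state, plus the rank-1 update k_t^T v_t.
        S = w.view(1, d_k, 1) * S + k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(1)
        o_t = torch.einsum('bk,bkv->bv', q[:, t], S)   # q_t . s_t
        outputs.append(torch.sigmoid(r[:, t]) * o_t)   # sigma(r_t) elementwise gate
    return torch.stack(outputs, dim=1)
```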

We use implementations from the Flash Linear Attention (FLA) library, which provides optimized CUDA kernels for both GLA and RWKV7 for training and inference.











The key difference from standard attention is that these linear mechanisms never materialize the full \(n \times n\) attention matrix.









GLA's gating mechanism makes it more expressive than simple linear attention while maintaining linear complexity.









RWKV7's decay mechanism provides a different inductive bias, allowing explicit control over how quickly information fades.

Stage 0: Layer-wise Linear Attention Training

The first stage trains linear attention replacements for each layer independently. We trained both GLA and RWKV7 variants to measure how well each layer's behavior can be approximated by linear attention, which tells us which layers are good candidates for replacement.

Training Setup

We modified the Qwen3 architecture to create a Qwen3ForNAS class that outputs the hidden states before and after any specified layer. For each layer \(l\), we train separate GLA and RWKV7 blocks to match the teacher layer's behavior by minimizing the MSE between the student's and the teacher's output hidden states:

\[\mathcal{L}_{\text{MSE}}^{(l)} = \mathbb{E}_{x \sim \mathcal{D}} \left[ \left\| f_{\text{student}}^{(l)}(h^{(l-1)}) - f_{\text{teacher}}^{(l)}(h^{(l-1)}) \right\|_2^2 \right]\]
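A minimal sketch of one Stage 0 training step is shown below. The layer call signature is simplified (actual Qwen3 decoder layers also take position embeddings and attention masks), and the function name and clipping value are illustrative rather than our exact training code.

```python
import torch
import torch.nn.functional as F

def stage0_step(teacher_layer, student_layer, optimizer, h_prev, clip_norm=1.0):
    """One per-layer distillation step: match the frozen teacher layer's output
    hidden states with MSE. h_prev: (batch, seq_len, hidden) input to layer l."""
    with torch.no_grad():
        h_teacher = teacher_layer(h_prev)      # frozen teacher target f_teacher(h^{l-1})
    h_student = student_layer(h_prev)          # GLA / RWKV7 replacement block
    loss = F.mse_loss(h_student, h_teacher)
    optimizer.zero_grad()
    loss.backward()
    # Gradient clipping to prevent early training divergence, as in the report.
    torch.nn.utils.clip_grad_norm_(student_layer.parameters(), clip_norm)
    optimizer.step()
    return loss.detach()
```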

Training Configuration

We trained each layer on approximately 40 million tokens from the DCLM-baseline-1.0 dataset with batch size 32, sequence length 1024, learning rate 1e-3 with cosine decay to 1e-5, and 10% warmup. To prevent early training divergence, we apply gradient clipping. All training was performed on NVIDIA A6000 GPUs. We trained GLA for layers 1-10 and independently trained RWKV7 for layers 1-28. Training time scaled superlinearly with layer depth.

Normalized Loss

After training the individual layers, we needed a way to compare how well each student layer reproduces the teacher's hidden states. However, different layers have different activation magnitudes, making raw MSE values incomparable. We therefore compute a normalized loss that measures the student's error relative to the magnitude of the teacher's update:

\[\mathcal{L}_{\text{norm}}^{(l)} = \frac{\text{MSE}(f_{\text{student}}^{(l)}(h^{(l-1)}), h^{(l)})}{\text{MSE}(h^{(l-1)}, h^{(l)})}\]

A normalized loss of 0.0 means perfect reproduction; 1.0 means the student is no better than passing through the input unchanged.
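In code, the metric is a ratio of two MSE terms (tensor names here are illustrative):

```python
import torch.nn.functional as F

def normalized_loss(h_student, h_prev, h_teacher):
    """Normalized Stage 0 loss: student error relative to the size of the teacher's
    update. 0.0 = perfect reproduction of layer l's output; 1.0 = no better than
    passing the input through unchanged."""
    return F.mse_loss(h_student, h_teacher) / F.mse_loss(h_prev, h_teacher)
```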

Layer-wise Training Results


Figure 1: Training curves for GLA student layers showing normalized MSE loss over training.


Figure 2: Training curves for RWKV7 student layers showing normalized MSE loss over training.

RWKV7 Issues

We encountered persistent numerical instability with RWKV7 during Stage 1 training (full model distillation), stemming from mixed precision mismatches between the FLA library's RWKV7 implementation and our teacher model's precision requirements. These floating point arithmetic errors proved difficult to resolve within our compute constraints. As a result, we were unable to successfully complete Stage 1 distillation for any hybrid model configurations that included RWKV7 layers. We therefore proceeded with only GLA layers for the final hybrid model configuration used in Stage 1.

Identifying the Best Layers

After training each layer, we ranked the layers by their final normalized loss. Based on our experiments, we constrained GLA to layers 1-10. For RWKV7, we trained layers 1-28 and observed Stage 0 results comparable to GLA, with similar normalized loss values, indicating that RWKV7 layers could also approximate the teacher's behavior in Stage 0. However, we did not include RWKV7 layers in the Stage 1 hybrid model because Stage 1 distillation consistently failed with them due to the numerical instability described above. For both GLA and RWKV7, Stage 0 revealed a clear pattern: middle layers (roughly layers 8-14) have significantly higher normalized loss than early or late layers, suggesting that this difficulty pattern holds across different linear attention mechanisms.

This finding aligns with prior work on transformer interpretability. Clark et al. [5] showed that early layers primarily detect local syntactic patterns while middle layers handle more complex semantic relationships. Tenney et al. [6] demonstrated that the "classical NLP pipeline" emerges across layers, with syntax in early layers and semantics in middle layers. Our results suggest that the complex semantic operations in middle layers fundamentally require the expressivity of softmax attention, while the more local computations in early layers can be approximated by linear attention.

Rank  | Layer | Normalized Loss
1     | 3     | 0.0003
2     | 9     | 0.1059
3     | 10    | 0.1186
4     | 7     | 0.2129
5     | 8     | 0.2396
6     | 1     | 0.3459
7     | 2     | 0.3585
...   | ...   | ...
Worst | 6     | 0.4688

Table 1: GLA layers ranked by normalized distillation loss. Lower loss indicates better suitability for replacement.

The standout result is layer 3, which achieves a normalized loss of just 0.0003. This layer appears to perform computations that are nearly linear in nature, making it an ideal candidate for replacement.































The normalized loss is crucial for fair comparison. Without it, we would incorrectly conclude that layers with smaller activations are easier to replace.





















The pattern of middle layers being hardest to replace aligns with findings from interpretability research showing that middle layers handle more abstract reasoning.

Stage 1: Full Model Distillation with Decoupled Top-k KL Divergence

After identifying which layers to replace, we assemble the hybrid model and fine-tune it using knowledge distillation from Qwen3-8B. We use a two-GPU configuration with NVIDIA A6000 GPUs: the 8B teacher runs on one GPU while the 1.7B hybrid student runs on the other. The precision for each model was chosen to fit within GPU memory constraints: the teacher uses FP16 and the student uses BF16. We specifically chose to distill from a model larger than the base architecture to obtain the best possible downstream accuracy while keeping training efficient with this distributed setup.
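A sketch of this setup is shown below, assuming the public Hugging Face Qwen3 checkpoints. For brevity the student is loaded as the stock Qwen3-1.7B model, whereas in practice the selected layers are swapped for their trained GLA replacements before Stage 1 begins.

```python
import torch
from transformers import AutoModelForCausalLM

# Frozen FP16 teacher on one GPU.
teacher = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B", torch_dtype=torch.float16
).to("cuda:0").eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# BF16 student on the other GPU; the selected layers would be replaced with GLA here.
student = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.bfloat16
).to("cuda:1")
student.gradient_checkpointing_enable()  # as in the report's training setup
```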

Hybrid Model Configuration

Based on our Stage 0 results, we construct hybrid models by selecting the layers with the lowest Stage 0 distillation losses up to a target percentile. Our configuration generator takes the Stage 0 losses for each layer type (GLA and RWKV7), filters layers by allowed ranges (GLA: layers 1-10; RWKV7: layers 11-22), sorts them by loss, and selects the top N% of layers for replacement, where N is the target percentile. Table 2 lists configurations generated for the top 10%, 25%, and 50% of layers, along with a modified version of the top10 configuration that uses only GLA layers (top10_gla); a sketch of the selection logic follows Table 2.

The top10_gla configuration is what we ultimately used for Stage 1 distillation. Because of the RWKV7 instability described above, we restricted the final model to GLA layers, and our compute budget allowed us to move forward with only one configuration. We still ran performance analysis on all configurations to estimate downstream inference speed (time to first token and token throughput), but measured accuracy only for top10_gla, which we treat as representative of the overall process. The top10_gla configuration replaces layers 1, 2, 3, 7, and 9 with GLA and keeps the remaining 23 layers as full attention; these layers showed the lowest normalized distillation loss, indicating they can be well approximated by linear attention.

Configuration     | GLA Layers       | RWKV7 Layers             | # Full Attention Layers
control           | none             | none                     | 28
top10             | [3]              | [21]                     | 26
top10_gla (final) | [1, 2, 3, 7, 9]  | none                     | 23
top25             | [3, 9]           | [20, 21, 22]             | 23
top50             | [3, 7, 8, 9, 10] | [17, 18, 19, 20, 21, 22] | 17

Table 2: Hybrid model configurations generated by percentile-based layer selection; top10_gla is the final configuration used for Stage 1.
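The selection logic can be sketched as follows (illustrative; the Stage 0 loss values for layers 4 and 5 are hypothetical placeholders, since Table 1 omits them):

```python
def select_layers(stage0_losses, allowed, fraction):
    """Pick the given fraction of allowed layers with the lowest normalized
    Stage 0 loss (a sketch of the configuration generator described above)."""
    candidates = sorted((l for l in allowed if l in stage0_losses),
                        key=lambda l: stage0_losses[l])
    n_select = max(1, round(fraction * len(candidates)))
    return sorted(candidates[:n_select])

# GLA losses from Table 1; values for layers 4 and 5 are hypothetical placeholders.
gla_losses = {1: 0.3459, 2: 0.3585, 3: 0.0003, 4: 0.40, 5: 0.43,
              6: 0.4688, 7: 0.2129, 8: 0.2396, 9: 0.1059, 10: 0.1186}
print(select_layers(gla_losses, allowed=range(1, 11), fraction=0.10))  # -> [3]
print(select_layers(gla_losses, allowed=range(1, 11), fraction=0.50))  # -> [3, 7, 8, 9, 10]
```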

Decoupled Top-k KL Divergence

Standard KL divergence over the full vocabulary (150,000+ tokens) is dominated by low-probability tokens in the tail of the distribution. To address this, we use a decoupled top-k KL divergence that focuses on the most important tokens, building on the idea of top-k distillation from previous work [3].

Let \(p_T\) and \(p_S\) denote the teacher and student probability distributions over the vocabulary \(\mathcal{V}\). For a given top-k set \(\mathcal{T}_k\) containing the k tokens with highest teacher probability, we define:

\[\rho_T = \sum_{v \in \mathcal{T}_k} p_T(v), \quad \rho_S = \sum_{v \in \mathcal{T}_k} p_S(v)\]

These represent the total probability mass that each distribution assigns to the top-k tokens.

The decoupled KL divergence consists of two components. First, a Bernoulli KL term that measures how well the student matches the teacher's allocation of probability mass between the top-k set and the rest:

\[\mathcal{L}_{\text{Bern}} = \rho_T \log\frac{\rho_T}{\rho_S} + (1-\rho_T) \log\frac{1-\rho_T}{1-\rho_S}\]

Second, a categorical KL term computed only over the top-k tokens. We renormalize the distributions within the top-k set and compute KL divergence with temperature scaling:

\[\tilde{p}_T(v) = \frac{p_T(v)}{\rho_T}, \quad \tilde{p}_S(v) = \frac{p_S(v)}{\rho_S} \quad \text{for } v \in \mathcal{T}_k\]

\[\mathcal{L}_{\text{top-k}} = T^2 \cdot \rho_T \cdot \text{KL}(\tilde{p}_T \| \tilde{p}_S)\]

where \(T\) is the temperature parameter. The \(\rho_T\) factor weights this term by how much probability mass is in the top-k, and \(T^2\) compensates for gradient magnitude reduction from temperature scaling.

The final KL loss combines both components:

\[\mathcal{L}_{\text{KL}} = \mathcal{L}_{\text{top-k}} + \mathcal{L}_{\text{Bern}}\]

The total distillation loss is:

\[\mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{CE}} + (1-\alpha) \cdot \mathcal{L}_{\text{KL}}\]

We use \(\alpha = 0.5\), \(T = 4.0\), and \(k = 32\) tokens.
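A minimal sketch of this objective is shown below. It is illustrative rather than our exact training code: it assumes the temperature is applied to both sets of logits before the softmax, adds a small epsilon for numerical stability, and assumes labels are already aligned for next-token prediction.

```python
import torch
import torch.nn.functional as F

def decoupled_topk_kl(student_logits, teacher_logits, k=32, temperature=4.0, eps=1e-8):
    """Decoupled top-k KL: a Bernoulli term over in-set vs out-of-set mass plus a
    temperature-scaled categorical KL over the teacher's top-k tokens."""
    T = temperature
    p_t = F.softmax(teacher_logits / T, dim=-1)
    p_s = F.softmax(student_logits / T, dim=-1)

    # Top-k set T_k is defined by the teacher distribution.
    topk_p_t, topk_idx = p_t.topk(k, dim=-1)
    topk_p_s = p_s.gather(-1, topk_idx)

    # rho_T, rho_S: probability mass each distribution assigns to the top-k set.
    rho_t = topk_p_t.sum(dim=-1)
    rho_s = topk_p_s.sum(dim=-1)

    # Bernoulli KL over (in top-k, not in top-k).
    bern = rho_t * torch.log((rho_t + eps) / (rho_s + eps)) \
         + (1 - rho_t) * torch.log((1 - rho_t + eps) / (1 - rho_s + eps))

    # Categorical KL over the renormalized top-k distributions, weighted by rho_T
    # and rescaled by T^2 to compensate for the temperature scaling.
    pt_tilde = topk_p_t / (rho_t.unsqueeze(-1) + eps)
    ps_tilde = topk_p_s / (rho_s.unsqueeze(-1) + eps)
    cat = (pt_tilde * torch.log((pt_tilde + eps) / (ps_tilde + eps))).sum(dim=-1)
    topk_term = (T ** 2) * rho_t * cat

    return (topk_term + bern).mean()

def total_distill_loss(student_logits, teacher_logits, labels, alpha=0.5):
    """L_total = alpha * CE + (1 - alpha) * KL, with alpha = 0.5 as in the report.
    labels are assumed already shifted for next-token prediction."""
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))
    kl = decoupled_topk_kl(student_logits, teacher_logits)
    return alpha * ce + (1 - alpha) * kl
```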

Training Configuration

Stage 1 distillation proceeded in two phases. First, we trained for 3,000 steps with batch size 8 and sequence length 512. Then, because linear attention can struggle with longer contexts, we continued training for another 5,000 steps with sequence length 1,024 and batch size 4. This totals 8,000 steps and approximately 32.8 million tokens (3,000 × 8 × 512 + 5,000 × 4 × 1,024 = 32,768,000). We used learning rate 5e-5 with cosine annealing, 200 warmup steps, BF16 precision, gradient checkpointing, and gradient clipping to prevent early training divergence.





























































The Bernoulli term ensures the student allocates similar probability mass to the top-k set as the teacher. The categorical term ensures the relative ordering within top-k is preserved.

Results

Training Dynamics: Full KL vs Top-k KL

We compared training with full KL divergence versus our decoupled top-k KL divergence approach.


Figure 3: Training loss curves comparing full KL divergence vs top-k KL divergence.

Important: The total loss and KL loss values cannot be directly compared between the two methods because they operate on different scales. Full KL divergence sums over 150,000+ vocabulary items, while top-k KL operates on only 32 tokens. What we can compare is how well the student learns the language modeling task, as measured by the cross-entropy component. We observe that top-k KL divergence leads to better CE loss convergence, suggesting that focusing on high-probability tokens provides a cleaner learning signal.

Why Middle Layers Have the Highest Normalized Loss

Our Stage 0 results show that middle layers are significantly harder to replace than early or late layers. We interpret this as reflecting the functional organization of transformers discussed earlier: early layers perform more local, nearly linear computations that a fixed-size recurrent state can reproduce, while middle layers carry out the more complex semantic processing that appears to require the full expressivity of softmax attention.

Throughput Comparison


Figure 4: Throughput comparison across model configurations.

Time to First Token Comparison


Figure 5: Time to First Token comparison across model configurations (log scale).

Configuration | GLA Layers       | RWKV7 Layers             | Time to First Token (s) | Throughput (tok/s) | Throughput Speedup | TTFT Speedup
control       | none             | none                     | 1.31                    | 5.44               | 1.00x              | 1.00x
top10_gla     | [1, 2, 3, 7, 9]  | none                     | 0.06                    | 21.09              | 3.88x              | 22.64x
top25         | [3, 9]           | [20, 21, 22]             | 0.05                    | 25.58              | 4.70x              | 24.49x
top50         | [3, 7, 8, 9, 10] | [17, 18, 19, 20, 21, 22] | 0.05                    | 28.40              | 5.22x              | 27.68x

Table 3: Inference speed metrics (time to first token and throughput) for each hybrid configuration.

Qualitative Evaluation

To assess whether the hybrid model maintains coherent generation, we inspected sampled completions. Unfortunately, the hybrid model did not achieve good text-prediction quality in this qualitative evaluation: generated outputs were often incoherent or failed to follow the prompt context appropriately.

















The scale difference between full and top-k KL makes direct loss comparison meaningless. The CE component is the fair comparison.























The middle-layer difficulty pattern was consistent across experiments, suggesting it reflects fundamental properties of the transformer architecture.

Limitations and Future Work

Limitations

Our work has two primary limitations:

Limited training compute: Our knowledge distillation training used only approximately 33 million tokens (3,000 × 8 × 512 + 5,000 × 4 × 1,024 ≈ 32,768,000 tokens). This is significantly less than the billions of tokens typically used for pretraining, and also far below the 350-700 million tokens used by RADLADS [9] for their successful linear attention conversions. Additionally, we observed that training time for GLA layers scaled superlinearly with layer depth, which limited our ability to train all layers within the available compute budget. With more compute and training tokens, the hybrid model would likely achieve better performance and more closely match the teacher model's capabilities.

KV cache incompatibility: Standard transformer inference relies heavily on key-value (KV) caching to avoid recomputing attention for previous tokens. However, linear attention mechanisms like GLA use a fundamentally different computation paradigm based on recurrent state updates rather than explicit key-value storage. This means the standard KV cache approach does not directly apply to hybrid models. To achieve strong inference performance, we may need to implement a separate caching mechanism specifically designed for linear attention layers, which adds engineering complexity and may not provide the same speedups as traditional KV caching.

Future Work

Several directions could extend this work: scaling Stage 1 distillation toward the hundreds of millions of tokens used by RADLADS [9]; resolving the RWKV7 precision issues so that mixed GLA/RWKV7 hybrids can be distilled; distilling and evaluating the top25 and top50 configurations; and implementing a recurrent-state cache for the linear layers so hybrid models can fully benefit from caching at inference time.

Conclusion

We presented a compute-aware architecture search method for constructing hybrid attention transformers. By training linear attention replacements for each layer independently and measuring normalized distillation loss, we can identify which layers are good candidates for replacement.

Our key finding is that layer position strongly predicts amenability to linear attention replacement. Early layers can often be replaced with minimal quality loss, while middle layers are significantly harder to approximate. This suggests that middle layers perform more complex computations that fundamentally require the expressivity of softmax attention.

We also adopted decoupled top-k KL divergence as an improved objective for knowledge distillation, focusing on high-probability tokens rather than matching the full distribution over 150,000+ vocabulary items.

As language models continue to grow, efficient attention mechanisms become increasingly important. Our work provides a principled methodology for navigating the efficiency-quality tradeoff and contributes toward making large language models more practical to deploy.

































The KV cache limitation is a significant practical concern for deploying hybrid models in production settings.

References

[1] Gu, Y., et al. (2025). Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search. arXiv preprint.

[2] Mercat, J., Vasiljevic, I., Keh, S., et al. (2024). Linearizing Large Language Models. arXiv preprint.

[3] Liquid AI. (2024). LFM2 Technical Report. arXiv preprint.

[4] Yang, S., Wang, B., Shen, Y., et al. (2024). Gated Linear Attention Transformers with Hardware-Efficient Training. ICML 2024.

[5] Clark, K., Khandelwal, U., Levy, O., & Manning, C. D. (2019). What Does BERT Look At? An Analysis of BERT's Attention. BlackboxNLP 2019.

[6] Tenney, I., Das, D., & Pavlick, E. (2019). BERT Rediscovers the Classical NLP Pipeline. ACL 2019.

[7] Peng, B., et al. (2024). Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence. arXiv preprint.

[8] Qwen Team. (2025). Qwen3 Technical Report. arXiv preprint.

[9] Goldstein, D., Alcaide, E., Lu, J., & Cheah, E. (2025). RADLADS: Rapid Attention Distillation to Linear Attention Decoders at Scale. arXiv preprint.


Code available at: [GitHub Repository]