# Learnable Attention Head Routing Extension

This extension implements a **Learnable Attention Head Routing** mechanism for NanoGPT that dynamically routes input tokens to a subset of attention heads per forward pass, inspired by Mixture-of-Experts (MoE) architectures.

## Quick Demo: Text Generation Comparison

Here's a side-by-side comparison of text generation from both models:

**Prompt**: "He was a just a phd who"

**Baseline Model** (90.4M parameters):
```
[First sample]
He was a just a phd who characterised [[poaching]] of [[Michael Blaus]], and was reprinted as a result of an [[understanding]] of [[Archerian literature]] that the phantoms had no coverage of the [[Professional world]].

[Second sample]
He was a just a phd who was the good of excessive and close to the [[Earth]]. He is the predominant movement who lived as a certain nation of the [[artificial intelligence]] and [[film production]], and the predominant stat
```

**Routing (Our) Model** (94.0M parameters):
```
[First sample]
He was a just a phd who chanted his life to approach better awareness than the passion, and presented his life to the motion picture in a video game introduced by [[Ben Konrad]]. He is the man who lived in the [[United Sta

[Second sample]
He was a just a phd who destroyed the Council of Europe. The neck attempted to include a coup attempt by stopping with the government in a way that was present at the first Commission on European [[August 19]], [[1905]]. Th
```

## Overview

The routing mechanism allows the model to learn which attention heads are most relevant for processing each token, potentially improving efficiency and performance by:
- Reducing computational overhead by using only a subset of heads per token
- Encouraging specialization among attention heads
- Maintaining the causal structure of the original transformer

## Implementation Details

### Core Changes

The implementation is contained in `model_novel.py` and includes the following key modifications:

#### 1. GPTConfig Extensions

```python
@dataclass
class GPTConfig:
    # ... existing parameters ...
    use_routing: bool = False
    top_k_heads: Optional[int] = None  # number of heads to select (k out of n_head)
    entropy_reg_coef: float = 0.01     # entropy regularization coefficient
```

#### 2. CausalSelfAttention with Routing

```python
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        # ... existing initialization ...

        # Add a gating network for routing
        self.use_routing = config.use_routing
        if config.use_routing:
            self.gate_net = nn.Sequential(
                nn.Linear(config.n_embd, config.n_embd // 2),
                nn.GELU(),
                nn.Linear(config.n_embd // 2, config.n_head)
            )
            self.top_k_heads = config.top_k_heads
            self.entropy_reg_coef = config.entropy_reg_coef

    def forward(self, x):
        # ... existing attention computation producing y: (B, n_head, T, head_dim) ...

        if self.use_routing and self.training:
            # Compute per-token routing logits, one per head: (B, T, n_head)
            gate_logits = self.gate_net(x)

            # Optional top-k selection: non-selected heads are masked to -inf
            # so the softmax assigns them exactly zero weight
            if self.top_k_heads is not None:
                _, top_k_indices = torch.topk(gate_logits, self.top_k_heads, dim=-1)
                mask = torch.full_like(gate_logits, float('-inf')).scatter_(-1, top_k_indices, 0.0)
                gate_logits = gate_logits + mask

            # Softmax over heads yields the routing weights
            routing_gates = F.softmax(gate_logits, dim=-1)

            # Scale each head's output by its gate: (B, T, n_head) -> (B, n_head, T, 1)
            y = y * routing_gates.transpose(1, 2).unsqueeze(-1)

            # Entropy regularization: adding the gate entropy to the training loss
            # penalizes uniform routing and encourages head specialization
            entropy_loss = -(routing_gates * torch.log(routing_gates + 1e-8)).sum(dim=-1).mean() * self.entropy_reg_coef

            return y, entropy_loss
        else:
            return y, None
```
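
The top-k gating step can be exercised in isolation with dummy tensors (a minimal sketch; the sizes below are illustrative, not the model's actual dimensions):

```python
import torch
import torch.nn.functional as F

B, T, n_head, top_k = 2, 4, 12, 6  # batch, sequence length, heads, heads kept

# Stand-in for the gate network output: one logit per head, per token
gate_logits = torch.randn(B, T, n_head)

# Keep the top-k logits per token; mask the rest to -inf so that
# softmax assigns them exactly zero probability
_, top_k_indices = torch.topk(gate_logits, top_k, dim=-1)
mask = torch.full_like(gate_logits, float('-inf')).scatter_(-1, top_k_indices, 0.0)
routing_gates = F.softmax(gate_logits + mask, dim=-1)

# Exactly top_k heads receive non-zero weight for every token,
# and the weights still sum to 1
assert (routing_gates > 0).sum(dim=-1).eq(top_k).all()
assert torch.allclose(routing_gates.sum(dim=-1), torch.ones(B, T))
```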

#### 3. Block and GPT Modifications

The `Block` and `GPT` classes are modified to propagate the entropy loss from the attention layers and add it to the main loss function.

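A minimal sketch of that bookkeeping (the `ToyBlock`/`ToyModel` names and losses are illustrative stand-ins, not the actual `model_novel.py` code): each layer returns an optional auxiliary loss, and the model sums them into the main objective.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in for a transformer block whose attention returns (y, aux_loss)."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        aux_loss = x.pow(2).mean() * 0.01  # placeholder for the entropy term
        return x + self.proj(x), aux_loss

class ToyModel(nn.Module):
    def __init__(self, dim, n_layer):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(n_layer))

    def forward(self, x, targets):
        total_aux = x.new_zeros(())
        for block in self.blocks:
            x, aux = block(x)
            if aux is not None:
                total_aux = total_aux + aux  # accumulate per-layer entropy losses
        main_loss = (x - targets).pow(2).mean()  # placeholder for cross-entropy
        return x, main_loss + total_aux          # auxiliary term joins the objective

model = ToyModel(dim=8, n_layer=3)
x = torch.randn(2, 5, 8)
_, loss = model(x, torch.zeros_like(x))
loss.backward()  # gradients flow through both the main and auxiliary losses
```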
### Training Configuration

Two training configurations are provided:

#### Routing Model (`config/train_enwik8_routing.py`)
```python
# Routing-specific parameters
use_routing = True
top_k_heads = 6  # use 6 out of 12 heads
entropy_reg_coef = 0.01
max_iters = 200000
```

#### Baseline Model (`config/train_enwik8_routing_ablation.py`)
```python
# Baseline parameters (routing disabled)
use_routing = False
top_k_heads = None
entropy_reg_coef = 0.0
max_iters = 200000
```
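
These config files are plain Python that overrides the training script's module-level defaults, in the style of NanoGPT's `configurator.py`. A simplified sketch of that mechanism (the real configurator also parses `--key=value` command-line overrides):

```python
# Defaults defined at the top of the training script
defaults = {
    "use_routing": False,
    "top_k_heads": None,
    "entropy_reg_coef": 0.0,
    "max_iters": 200000,
}

# Stand-in for the contents of config/train_enwik8_routing.py
config_text = """
use_routing = True
top_k_heads = 6
entropy_reg_coef = 0.01
"""

overrides = {}
exec(config_text, {}, overrides)          # run the config file's assignments
unknown = set(overrides) - set(defaults)  # reject keys the script doesn't define
assert not unknown, f"Unknown config key(s): {unknown}"
defaults.update(overrides)

print(defaults["use_routing"], defaults["top_k_heads"])  # → True 6
```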

## Training Setup

### Dataset
- **Dataset**: enwik8 (first 100M characters of Wikipedia)
- **Splits**: 90M train / 5M validation / 5M test
- **Format**: Binary tokenized data

### Model Architecture
- **Base Model**: GPT-2 style transformer
- **Parameters**: ~90M parameters
- **Attention Heads**: 12 heads
- **Routing**: Top-6 head selection (6 out of 12 heads)

### Training Details
- **Iterations**: 200,000 iterations for both models
- **Learning Rate**: Cosine decay schedule
- **Batch Size**: Optimized for available GPU memory
- **Entropy Regularization**: 0.01 coefficient to encourage specialization

## Results

### Performance Comparison

| Model | Loss (nats) | BPC (base 2) | Parameters<sup>†</sup> | Improvement |
|----------|-------------|--------------|------------------------|-------------|
| Baseline | 0.873679 | 1.26045 | 90.397M | - |
| Routing | 0.844327 | 1.218107 | 93.997M | **3.36%** |

<sup>†</sup>Parameter counts include positional embeddings.
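
The BPC column follows directly from the per-character loss in nats: bits-per-character = loss / ln 2. A quick check of the table's numbers:

```python
import math

def nats_to_bpc(loss_nats: float) -> float:
    """Convert a per-character cross-entropy loss from nats to bits."""
    return loss_nats / math.log(2)

baseline_bpc = nats_to_bpc(0.873679)
routing_bpc = nats_to_bpc(0.844327)

print(f"{baseline_bpc:.5f}")  # → 1.26045
print(f"{routing_bpc:.5f}")   # → 1.21811

# Relative improvement reported in the table
improvement = (baseline_bpc - routing_bpc) / baseline_bpc * 100
print(f"{improvement:.2f}%")  # → 3.36%
```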

### Key Findings

1. **Performance Improvement**: The routing model achieves a 3.36% improvement in BPC (1.218 vs 1.260)
2. **Parameter Efficiency**: Only a 4.0% parameter increase (3.6M additional parameters)
3. **Computational Efficiency**: Routes each token through only 6 of 12 attention heads (heads are gated to zero rather than skipped in this implementation)
4. **Convergence**: Both models trained for 200K iterations with stable convergence

### Training Curves

The training progress is logged and visualized through:

- **Loss Curves**: Training and validation loss over iterations
- **Learning Rate**: Cosine decay schedule visualization
- **BPC Metrics**: Bits-per-character in base 2
- **Overfitting Analysis**: Training vs validation gap

#### Individual Model Curves

**Routing Model Training Curves:**

**Baseline Model Training Curves:**

#### Model Comparison

The comparison shows the routing model achieving better validation BPC throughout training:

### Training Results

After 200,000 iterations of training, the models achieved the following final validation performance:

| Model | Final Validation Loss (nats) | Final Validation BPC (base 2) | Parameters |
|-------|------------------------------|-------------------------------|------------|
| Baseline | 0.862398 | 1.244177 | 90.397M |
| Routing | 0.844890 | 1.218919 | 93.997M |

**Key Results:**
- **BPC Improvement**: 2.03% improvement in validation BPC (1.244 → 1.219)
- **Parameter Overhead**: 3.6M additional parameters (4.0% increase)
- **Training Stability**: Both models converged smoothly over 200K iterations
- **Routing Efficiency**: Gates each token through only 6 of 12 attention heads

### Important Notes

**Educational Project Limitations:**
This repository was implemented as an educational side project to explore learnable attention head routing mechanisms. Due to limited computational resources and time constraints, no hyperparameter tuning was carried out; the results presented here are from a single training run with default configurations.

**Evaluation Results:**
The evaluation results shown are based on batch sampling rather than full test set evaluation, due to computational time constraints. Full test set evaluation would give more accurate numbers but requires significantly more time; the batch results may be slightly better or worse than true full test set performance.

**Potential Improvements:**
- Hyperparameter optimization (learning rate, batch size, model architecture)
- Different routing configurations (varying top-k values, entropy regularization)
- Longer training runs with more iterations
- Full test set evaluation for more accurate metrics
- Ablation studies on different routing mechanisms

## Usage

### Training the Routing Model

```bash
# Train the routing model
python train_novel.py config/train_enwik8_routing.py

# Train the baseline model (ablation study)
python train_novel.py config/train_enwik8_routing_ablation.py
```

### Evaluation

```bash
# Evaluate both models on test data
python evaluate_models_binary.py \
    --routing_checkpoint out-enwik8-routing/ckpt.pt \
    --baseline_checkpoint out-enwik8-routing-ablation/ckpt.pt \
    --test_data data/enwik8/test.bin
```

### Visualization

```bash
# Auto-detect and plot both models (recommended)
python plot_training_curves.py

# Plot individual models
python plot_training_curves.py --log_file out-enwik8-routing/training_log.json
python plot_training_curves.py --log_file out-enwik8-routing-ablation/training_log.json
```

### Text Generation

Generate text samples from both models for comparison:

```bash
# Basic text generation
python sample_both_models.py --start "The quick brown fox"

# Generate longer text with different settings
python sample_both_models.py --start "He was a just a phd who" --max_new_tokens 300 --temperature 0.8

# Generate multiple samples
python sample_both_models.py --start "In the beginning" --num_samples 5 --max_new_tokens 200

# Use different generation parameters
python sample_both_models.py --start "The future of AI" --temperature 1.2 --top_k 100
```

**Generation Parameters:**
- `--start`: Starting text or prompt
- `--max_new_tokens`: Number of tokens to generate (default: 200)
- `--temperature`: Controls randomness (0.8 = balanced, 1.2 = more random)
- `--top_k`: Number of top tokens to consider (default: 200)
- `--num_samples`: Number of samples to generate (default: 3)

## Technical Details

### Routing Mechanism

1. **Gate Network**: Lightweight MLP that predicts routing weights for each token
2. **Top-K Selection**: Optional mechanism to select only the k most relevant heads
3. **Entropy Regularization**: Encourages specialization by penalizing uniform distributions
4. **Causal Preservation**: Maintains the autoregressive structure of the transformer
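
To see why adding the gate entropy to the loss discourages uniform routing, compare the entropy of a uniform gate distribution over 12 heads with a peaked one (a standalone illustration, not model code):

```python
import math

def entropy(p):
    """Shannon entropy in nats, matching the regularizer -(p * log p).sum()."""
    return -sum(x * math.log(x + 1e-8) for x in p)

uniform = [1 / 12] * 12           # every head weighted equally
peaked = [0.9] + [0.1 / 11] * 11  # one dominant head

print(round(entropy(uniform), 3))  # ≈ 2.485 (= log 12, the maximum)
print(round(entropy(peaked), 3))
# The uniform gates incur the larger entropy penalty, so minimizing the
# total loss pushes the router toward peaked, specialized gate patterns.
```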

### Efficiency Analysis

- **Memory**: Minimal overhead from the gate network parameters
- **Computation**: Top-k selection gates head outputs to zero; in this implementation all heads are still computed, so skipping non-selected heads remains a potential optimization
- **Scalability**: The mechanism scales with model size and can be applied to larger models
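
The quoted 3.6M parameter overhead is consistent with the gate network's size, assuming GPT-2-small-like dimensions (n_embd = 768, n_head = 12, 12 layers; these dimensions are an assumption inferred from the "12 heads, ~90M parameters" description):

```python
n_embd, n_head, n_layer = 768, 12, 12  # assumed GPT-2-small-like dimensions

# Gate network per layer: Linear(n_embd -> n_embd//2) + Linear(n_embd//2 -> n_head),
# each carrying a bias term
hidden = n_embd // 2
per_layer = (n_embd * hidden + hidden) + (hidden * n_head + n_head)
total = per_layer * n_layer

print(per_layer)            # → 299916
print(f"{total/1e6:.2f}M")  # → 3.60M, matching the reported overhead
```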

### Compatibility

- **Backward Compatible**: Falls back to standard multi-head attention when routing is disabled
- **Training Loop**: Compatible with existing NanoGPT training infrastructure
- **Checkpointing**: Supports model checkpointing and resuming

## Conclusion

The learnable attention head routing extension demonstrates promising results, with a 3.36% improvement in BPC at a modest (4.0%) parameter cost. The mechanism encourages specialization among attention heads through top-k gating and entropy regularization while preserving the causal structure of the transformer architecture.