---
name: code-generation
description: Use when generating PyTorch code from paper analysis to ensure correct mapping from paper to code
---

# Code Generation from Papers

## Overview

Guidelines for translating paper descriptions into working PyTorch code.

**Announce at start:** "I'm using the code-generation skill to ensure accurate paper-to-code translation."

## Core Principles

1. **Traceability**: Every code block should reference the paper section or equation it implements
2. **Testability**: Write code that can be unit tested
3. **Readability**: Prefer clarity over cleverness
4. **Modularity**: One component per file

## Paper-to-Code Mapping

### Architecture Diagrams → nn.Module

| Diagram Element | PyTorch Equivalent |
|-----------------|-------------------|
| Box/Block | nn.Module subclass |
| Arrow | forward() call chain |
| Split | Multiple outputs / tuple |
| Merge | torch.cat / torch.add |
| Skip connection | Residual addition |

### Equations → Tensor Operations

| Notation | PyTorch |
|----------|---------|
| $Wx + b$ | `nn.Linear(in_features, out_features)` |
| $\sigma(x)$ | `torch.sigmoid(x)` or `nn.Sigmoid()` |
| $\text{softmax}(x)$ | `F.softmax(x, dim=-1)` |
| $\|x\|_2$ | `torch.linalg.vector_norm(x, ord=2)` (`torch.norm` is deprecated) |
| $x \odot y$ | `x * y` (element-wise) |
| $x^T y$ | `torch.dot(x, y)` for 1-D vectors; `x.T @ y` for 2-D matrices; `x.transpose(-2, -1) @ y` for batched matrices |
| $\sum_i$ | `torch.sum(x, dim=d)`, where `d` is the axis indexed by $i$ |
| $\mathbb{E}[x]$ | `torch.mean(x)` (empirical mean over samples) |

### Loss Functions

| Paper Description | PyTorch |
|-------------------|---------|
| Cross-entropy | `nn.CrossEntropyLoss()` (expects raw logits, not softmax output) |
| MSE / L2 | `nn.MSELoss()` |
| L1 | `nn.L1Loss()` |
| BCE | `nn.BCEWithLogitsLoss()` (expects logits; use `nn.BCELoss()` on probabilities) |
| KL divergence | `nn.KLDivLoss(reduction="batchmean")` (input must be log-probabilities) |
| Custom | Subclass `nn.Module` or write a functional |

## Code Structure Template

```python
"""
{component_name}.py

Implements {what} from "{paper_title}" ({year})

Paper Reference:
- Section: {section_number}
- Equation: ({equation_number})
- Figure: {figure_number}

Author: Auto-generated for paper replication
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List


class {ComponentName}(nn.Module):
    """
    {One-line description}

    From paper: "{exact quote or paraphrase}"

    Args:
        {param1}: {description} (paper: {where specified})
        {param2}: {description}

    Shape:
        - Input: {shape description}
        - Output: {shape description}

    Example:
        >>> layer = {ComponentName}(dim=512)
        >>> x = torch.randn(32, 100, 512)
        >>> out = layer(x)
        >>> out.shape
        torch.Size([32, 100, 512])
    """

    def __init__(
        self,
        {param1}: {type},
        {param2}: {type} = {default},
    ):
        super().__init__()

        # Paper Section X.Y: "{description}"
        self.layer1 = nn.Linear(...)

        # Equation (N): ...
        self.layer2 = nn.LayerNorm(...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass implementing Equation (N).

        Args:
            x: Input tensor of shape (batch, seq, dim)

        Returns:
            Output tensor of shape (batch, seq, dim)
        """
        # Step 1: ... (Eq. N, first term)
        h = self.layer1(x)

        # Step 2: ... (Eq. N, second term)
        out = self.layer2(h)

        return out
```

## Common Patterns

### Residual Connection

```python
# Paper: "We add a residual connection"
out = self.sublayer(x) + x
```

### Layer Normalization

```python
# Paper: "Pre-LN Transformer" (normalize inside the residual branch)
x = x + self.attention(self.norm(x))

# Paper: "Post-LN Transformer" (normalize after the residual addition)
x = x + self.attention(x)
x = self.norm(x)
```

### Multi-Head Attention

```python
# Paper: "Standard multi-head attention with h heads"
# Note: calling this module returns a tuple (attn_output, attn_weights).
self.attention = nn.MultiheadAttention(
    embed_dim=d_model,
    num_heads=h,
    dropout=dropout,
    batch_first=True,
)
```

### Custom Activation

```python
# Paper: "We use GELU activation"
x = F.gelu(x)

# Paper: "We use Swish/SiLU activation"
x = F.silu(x)
```

## Handling Ambiguity

When the paper is unclear:

1. **Check the code repository** if one is available
2. **Follow common practice** for the architecture type
3. **Document the assumption** in a code comment
4. **Add a TODO** for verification

```python
# TODO: Paper unclear on initialization. Using PyTorch default.
# See: https://github.com/paper/repo for reference implementation
self.linear = nn.Linear(in_dim, out_dim)
```

## Verification Checklist

Before completing a module:

- [ ] All equations implemented
- [ ] Shapes documented and verified
- [ ] Paper references in comments
- [ ] Type hints complete
- [ ] Example in docstring works
- [ ] No hardcoded dimensions (use params)
- [ ] Gradient flow verified (no in-place ops breaking autograd)
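
The shape and gradient-flow items in the checklist can be checked mechanically. The following is a minimal sketch and not part of the original skill: `FeedForwardBlock` and `check_shapes_and_gradients` are hypothetical names introduced here for illustration, and the Pre-LN feed-forward block is only a stand-in for whatever module is being verified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple


class FeedForwardBlock(nn.Module):
    """Hypothetical Pre-LN residual feed-forward block, used only to exercise the checks."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-LN residual: x + f(norm(x))
        return x + self.fc2(F.gelu(self.fc1(self.norm(x))))


def check_shapes_and_gradients(
    module: nn.Module,
    input_shape: Tuple[int, ...],
    expected_shape: Tuple[int, ...],
) -> List[str]:
    """Run one forward/backward pass on random input.

    Raises AssertionError if the output shape differs from expected_shape.
    Returns the names of trainable parameters whose gradient is missing
    or non-finite (an empty list means every parameter received a gradient).
    """
    module.zero_grad(set_to_none=True)
    x = torch.randn(*input_shape)
    out = module(x)
    assert out.shape == torch.Size(expected_shape), (
        f"expected {tuple(expected_shape)}, got {tuple(out.shape)}"
    )

    # backward() raises RuntimeError if an in-place op overwrote
    # a tensor that autograd still needs.
    out.sum().backward()

    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and (p.grad is None or not torch.isfinite(p.grad).all())
    ]


block = FeedForwardBlock(dim=16, hidden_dim=32)
problems = check_shapes_and_gradients(block, (2, 5, 16), (2, 5, 16))
print(problems)
```

An empty list indicates that every trainable parameter received a finite gradient. A parameter name in the list usually points to a layer that is defined in `__init__` but never used in `forward`, or to a numerical issue worth investigating before training.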