PaperTool/.opencode/skills/pytorch-patterns/SKILL.md

14 KiB

name: pytorch-patterns
description: Use when writing PyTorch code to follow best practices and common patterns

PyTorch Patterns Skill

Overview

This skill provides best practices and common patterns for writing PyTorch code. Use it when implementing neural networks, training loops, data loading, checkpointing, and related deep learning infrastructure.

Announce: "I'm using the pytorch-patterns skill for best practice code."


Model Definition

Basic nn.Module Pattern

from dataclasses import dataclass
from typing import NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelConfig:
    """Configuration for the model.

    Declared as a dataclass so that instances get a readable ``repr`` and
    value-based equality for free, while the constructor keeps the same
    parameter names, order and defaults as a hand-written ``__init__``.

    Attributes:
        input_dim: Size of the last dimension of the model input.
        hidden_dim: Width of the hidden layers.
        output_dim: Size of the last dimension of the output logits.
        dropout: Dropout probability handed to ``nn.Dropout``.
        num_layers: Number of hidden layers stacked after the input
            projection.
    """

    input_dim: int = 768
    hidden_dim: int = 512
    output_dim: int = 10
    dropout: float = 0.1
    num_layers: int = 2


class ModelOutput(NamedTuple):
    """Typed output container for model forward pass.

    Because this is a ``NamedTuple``, results can be read by field name
    (``out.logits``) or unpacked positionally in declaration order.
    """
    # Result of the model's output projection (see MyModel.forward).
    logits: torch.Tensor
    # Hidden representation that the logits were computed from.
    hidden_states: torch.Tensor
    # Optional; left as None by models that produce no attention weights.
    attention_weights: torch.Tensor | None = None


class MyModel(nn.Module):
    """Reference residual MLP illustrating the recommended module layout.

    The network applies an input projection followed by LayerNorm and GELU,
    then ``config.num_layers`` residual blocks (Linear -> Dropout -> GELU,
    added back onto the block input), and finally a linear output head.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        width = config.hidden_dim

        # Submodules (registration order determines state_dict key order).
        self.input_proj = nn.Linear(config.input_dim, width)
        hidden_blocks = [nn.Linear(width, width) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(hidden_blocks)
        self.output_proj = nn.Linear(width, config.output_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(width)

        # Overwrite the default initialisation of every submodule.
        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform/zero init to Linear and one/zero init to LayerNorm."""
        for submodule in self.modules():
            if isinstance(submodule, nn.Linear):
                nn.init.xavier_uniform_(submodule.weight)
                if submodule.bias is not None:
                    nn.init.zeros_(submodule.bias)
                continue
            if isinstance(submodule, nn.LayerNorm):
                nn.init.ones_(submodule.weight)
                nn.init.zeros_(submodule.bias)

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Run the network on ``x`` and return a ``ModelOutput``.

        ``attention_weights`` is always ``None`` because this model has no
        attention mechanism.
        """
        # Stem: project -> normalise -> activate.
        state = F.gelu(self.layer_norm(self.input_proj(x)))

        # Residual blocks: the block input is added back after the activation.
        for block in self.layers:
            update = F.gelu(self.dropout(block(state)))
            state = update + state

        return ModelOutput(
            logits=self.output_proj(state),
            hidden_states=state,
            attention_weights=None,
        )

Device Management

Device Property Pattern

class DeviceAwareModule(nn.Module):
    """Module with device awareness."""
    
    @property
    def device(self) -> torch.device:
        """Infer device from model parameters."""
        return next(self.parameters()).device
    
    @property
    def dtype(self) -> torch.dtype:
        """Infer dtype from model parameters."""
        return next(self.parameters()).dtype

Training Script Device Setup

def get_device() -> torch.device:
    """Get the best available device.

    Preference order is CUDA, then Apple MPS, then CPU. The choice is also
    printed to stdout.

    Returns:
        ``torch.device("cuda")``, ``torch.device("mps")`` or
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        # A bare "cuda" device refers to the *current* CUDA device, so report
        # that device's name rather than always naming index 0.
        current = torch.cuda.current_device()
        print(f"Using CUDA: {torch.cuda.get_device_name(current)}")
        return torch.device("cuda")

    # Guard the attribute lookup: builds without the MPS backend do not
    # expose ``torch.backends.mps`` at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using Apple MPS")
        return torch.device("mps")

    print("Using CPU")
    return torch.device("cpu")


# Usage: pick the device once, then build the model and move it there.
device = get_device()
config = ModelConfig()  # default hyperparameters; override fields as needed
model = MyModel(config).to(device)

Training Loop

Standard Training Epoch

def train_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    max_grad_norm: float | None = 1.0,
) -> dict[str, float]:
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        # Move to device
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.logits, targets)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (optional but recommended)
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        # Update weights
        optimizer.step()
        
        # Accumulate metrics (use .item() to prevent memory leak)
        total_loss += loss.item()
        num_batches += 1
    
    return {
        "train_loss": total_loss / num_batches,
    }

Evaluation Function

@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> dict[str, float]:
    """Evaluate the model.

    Runs with gradients disabled and the model switched to eval mode (the
    model is left in eval mode afterwards).

    Args:
        model: Module whose forward output exposes a ``logits`` attribute.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``{"eval_loss": mean of the per-batch loss values,
        "accuracy": fraction of targets matched by the argmax prediction}``.
        Both values are ``nan`` when the dataloader yields no data.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs.logits, targets)

        total_loss += loss.item()
        # Count batches ourselves instead of calling len(dataloader), so that
        # loaders without a length work too.
        num_batches += 1

        # Calculate accuracy
        predictions = outputs.logits.argmax(dim=-1)
        correct += (predictions == targets).sum().item()
        total += targets.size(0)

    nan = float("nan")
    return {
        "eval_loss": total_loss / num_batches if num_batches else nan,
        "accuracy": correct / total if total else nan,
    }

Data Loading

Custom Dataset Pattern

from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Example custom dataset.

    Each element of ``data`` is indexed with the keys ``"features"`` and
    ``"label"``; items are converted to tensors lazily in ``__getitem__``.

    Args:
        data: Sequence of records with ``"features"`` and ``"label"`` entries.
        transform: Optional callable applied to the feature tensor.
    """

    def __init__(self, data: list, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(features, label)`` for record ``idx``.

        Features are converted to ``float32`` and the label to ``long``.
        """
        item = self.data[idx]
        features = torch.tensor(item["features"], dtype=torch.float32)
        label = torch.tensor(item["label"], dtype=torch.long)

        # Compare against None explicitly: a callable that happens to be
        # falsy (e.g. one defining __len__ or __bool__) must still be applied.
        if self.transform is not None:
            features = self.transform(features)

        return features, label

DataLoader Helper

def get_dataloader(
    dataset: Dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = False,
) -> DataLoader:
    """Build a ``DataLoader`` with platform-aware defaults.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Requested number of worker processes. Forced to 0 on
            Windows.
        pin_memory: Request pinned host memory; only honoured when CUDA is
            available.
        drop_last: Whether to drop a trailing incomplete batch.

    Returns:
        The configured ``DataLoader``. Workers are kept alive between epochs
        whenever at least one worker is used.
    """
    import platform

    # Only Windows is special-cased: worker processes are switched off there.
    on_windows = platform.system() == "Windows"
    worker_count = 0 if (num_workers > 0 and on_windows) else num_workers

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": worker_count,
        "pin_memory": pin_memory and torch.cuda.is_available(),
        "drop_last": drop_last,
        "persistent_workers": worker_count > 0,
    }
    return DataLoader(dataset, **loader_options)

Checkpointing

Save Checkpoint

def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    path: str,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    **kwargs,
) -> None:
    """Save a training checkpoint."""
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }
    
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()
    
    # Add any additional data
    checkpoint.update(kwargs)
    
    torch.save(checkpoint, path)
    print(f"Checkpoint saved to {path}")

Load Checkpoint

def load_checkpoint(
    path: str,
    model: nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    device: torch.device | None = None,
) -> dict:
    """Load a training checkpoint."""
    # Use weights_only=True for security (prevents arbitrary code execution)
    checkpoint = torch.load(path, map_location=device, weights_only=True)
    
    model.load_state_dict(checkpoint["model_state_dict"])
    
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    
    print(f"Loaded checkpoint from {path} (epoch {checkpoint.get('epoch', 'unknown')})")
    return checkpoint

Reproducibility

Set Seed Function

import random
import numpy as np
import torch


def set_seed(seed: int = 42, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Value used for every random number generator.
        deterministic: When true, cuDNN is put into deterministic mode with
            benchmarking off (reproducible, possibly slower). When false,
            the opposite settings are applied (faster, non-deterministic).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # covers every visible GPU

    # The two cuDNN switches are always set to opposite values.
    reproducible = bool(deterministic)
    torch.backends.cudnn.deterministic = reproducible
    torch.backends.cudnn.benchmark = not reproducible


# Usage at start of training script
# Call this before constructing the model or data loaders, so that anything
# they draw from the random generators comes after seeding.
set_seed(42)

Common Gotchas

1. In-Place Operations Breaking Autograd

# BAD - In-place operation can break autograd graph
def bad_forward(x):
    # `+=` updates x in place, so for mutable inputs such as tensors the
    # object passed in by the caller is modified as well.
    x += 1  # In-place modification
    return x * 2

# GOOD - Out-of-place arithmetic leaves the caller's tensor untouched
def good_forward(x):
    """Return ``(x + 1) * 2`` without modifying ``x``."""
    shifted = x + 1  # builds a new object instead of mutating x
    return shifted * 2

2. Memory Leaks from Not Detaching

# BAD - Keeps computation graph in memory
losses = []
for batch in dataloader:
    loss = model(batch)
    # The tensor itself is stored, so it (and anything it references) stays
    # alive for as long as the list does.
    losses.append(loss)  # Holds entire graph!

# GOOD - Detach with .item() for scalars
# .item() turns each scalar loss into a plain Python number, so only the
# number is stored and no graph is kept.
losses = [model(batch).item() for batch in dataloader]

# GOOD - Detach for non-scalar tensors
# Each encoded batch is detached from the graph and moved to the CPU before
# being collected.
features = [model.encode(batch).detach().cpu() for batch in dataloader]

3. Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler


def train_with_amp(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Training with Automatic Mixed Precision."""
    model.train()
    scaler = GradScaler()
    total_loss = 0.0
    
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass with autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs.logits, targets)
        
        # Backward pass with scaled gradients
        scaler.scale(loss).backward()
        
        # Unscale before clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # Step with scaler
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

4. Forgetting model.eval() and model.train()

# BAD - Dropout/BatchNorm behave incorrectly
def evaluate_bad(model, dataloader):
    # Neither model.eval() nor torch.no_grad() is used here, so the model
    # runs in whatever mode the caller left it in.
    # model.train() is still active!
    for batch in dataloader:
        output = model(batch)

# GOOD - Always set mode explicitly
def evaluate_good(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode without gradients.

    The model's previous mode is restored afterwards, even if a batch
    raises: a model that was already in eval mode stays in eval mode, and a
    model that was training goes back to training.
    """
    was_training = model.training
    model.eval()  # Disable dropout, use running stats for BatchNorm
    try:
        with torch.no_grad():
            for batch in dataloader:
                model(batch)  # consume or record the prediction here
    finally:
        # Restore whatever mode the caller had, instead of forcing train().
        model.train(was_training)

5. Proper Tensor Creation on Device

# BAD - Creates on CPU then moves (slow)
# The tensor is first built without a device argument and only afterwards
# transferred with .to(device): two steps instead of one.
tensor = torch.zeros(100, 100).to(device)

# GOOD - Create directly on device
tensor = torch.zeros(100, 100, device=device)

# GOOD - Create with same device/dtype as reference
tensor = torch.zeros_like(reference_tensor)
# Requires the model to expose `device` and `dtype` attributes, e.g. through
# properties such as those defined on DeviceAwareModule.
tensor = torch.empty(100, 100, device=model.device, dtype=model.dtype)

Quick Reference Checklist

  • Use nn.Module with proper __init__ and forward
  • Initialize weights with _init_weights method
  • Use @property device for device inference
  • Always use .item() when logging scalar losses
  • Use @torch.no_grad() decorator for evaluation
  • Call model.train() and model.eval() explicitly
  • Use weights_only=True when loading checkpoints
  • Set seeds for reproducibility at script start
  • Avoid in-place operations in forward pass
  • Use mixed precision for faster training on GPU
  • Clip gradients to prevent exploding gradients
  • Use pin_memory=True with CUDA for faster data transfer