Model Training & Optimization
Master batch size, learning rate scheduling, optimizers like Adam and SGD, regularization, and techniques to fix overfitting
Training Loop Fundamentals
Training a neural network involves repeatedly showing it data in batches, computing the loss, calculating gradients via backpropagation, and updating weights. Getting the hyperparameters right is often the difference between a model that works and one that does not.
Training Terminology:
- Epoch: one full pass through all training data
- Batch: subset of data processed together in one forward/backward pass
- Iteration: one weight update (one batch)
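To make these terms concrete, here is a minimal sketch of one training epoch. The model, the random "dataset", and the hyperparameters are placeholders for illustration, not recommendations:
import torch
import torch.nn as nn
model = nn.Linear(100, 10)                     # placeholder model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Placeholder "dataset": 5 batches of 32 random examples
train_loader = [(torch.randn(32, 100), torch.randint(0, 10, (32,))) for _ in range(5)]
for inputs, targets in train_loader:    # each batch = one iteration
    optimizer.zero_grad()               # clear gradients from the previous step
    outputs = model(inputs)             # forward pass
    loss = criterion(outputs, targets)  # compute the loss
    loss.backward()                     # backpropagation: compute gradients
    optimizer.step()                    # update the weights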
Batch Size: Impact and Trade-offs
Small Batch (8-32)
- Noisy gradients (acts as regularization)
- Better generalization
- Lower GPU memory usage
- Slower training (more updates per epoch)
Medium Batch (64-256)
- Good balance of speed and quality
- Most common in practice
- Stable training dynamics
- Standard starting point
Large Batch (512+)
- More stable gradients
- Faster per epoch (fewer updates)
- May converge to sharper minima
- Needs higher learning rate
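In code, batch size is usually just the batch_size argument of the DataLoader. The sketch below uses a placeholder dataset and also shows the common linear-scaling heuristic for adjusting the learning rate when batch size grows; treat it as a starting point to tune, not a guarantee:
import torch
from torch.utils.data import DataLoader, TensorDataset
# Placeholder dataset of 1024 random examples
dataset = TensorDataset(torch.randn(1024, 100), torch.randint(0, 10, (1024,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True)
# Linear scaling rule (heuristic): scale LR roughly in proportion to batch size
base_lr, base_batch_size = 1e-3, 64
new_batch_size = 256
new_lr = base_lr * new_batch_size / base_batch_size   # 4e-3 for batch size 256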
Optimizers Compared
import torch
import torch.optim as optim
model = torch.nn.Linear(100, 10) # simple model for demo
# 1. SGD: Simple, needs careful LR tuning, good with momentum
optimizer_sgd = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,        # accelerates convergence
    weight_decay=1e-4    # L2 regularization
)
# 2. Adam: Adaptive learning rates, great default choice
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=0.001,             # typical: 1e-3 to 1e-4
    betas=(0.9, 0.999),   # momentum parameters
    weight_decay=1e-4
)
# 3. AdamW: Adam with decoupled weight decay (preferred)
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01     # decoupled from the gradient update
)
# 4. Different LR for different layers (fine-tuning)
optimizer_finetune = optim.AdamW([
    {'params': model.weight, 'lr': 1e-5},  # pretrained layers: low LR
    {'params': model.bias, 'lr': 1e-3},    # new layers: higher LR
])
# When to use what:
# Adam/AdamW: Default choice, works well out of the box
# SGD + Momentum: Often reaches better final accuracy (with tuning)
# AdamW: Best for transformers and fine-tuning pretrained models
print("AdamW with lr=1e-3 to 1e-4 is the safest starting point")
Learning Rate Scheduling
import torch.optim.lr_scheduler as lr_scheduler
model = torch.nn.Linear(100, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# 1. StepLR: Decay by factor every N epochs
scheduler1 = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# LR: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20)
# 2. CosineAnnealingLR: Smooth cosine decay (very popular)
scheduler2 = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
# Smoothly decays from initial LR to ~0 over 50 epochs
# 3. ReduceLROnPlateau: Reduce when metric stops improving
scheduler3 = lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
# If val_loss doesn't improve for 5 epochs, halve the LR
# (call as scheduler3.step(val_loss) after each validation pass)
# 4. OneCycleLR: Start low, ramp up, then decay (fast convergence)
scheduler4 = lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, total_steps=1000
)
import numpy as np

# 5. Warmup + Cosine Decay (standard for transformers)
def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))  # cosine decay
    return lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage in training loop
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=100, total_steps=1000)
for step in range(1000):
    # ... training step ...
    optimizer.step()
    scheduler.step()
print("Cosine annealing with warmup is the standard for modern training")
Regularization Techniques
import torch
import torch.nn as nn
class RegularizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 1. Dropout: randomly zero out neurons during training
        self.layer1 = nn.Linear(784, 256)
        self.dropout1 = nn.Dropout(p=0.5)   # 50% dropout rate
        # 2. Batch Normalization: normalize activations
        self.bn1 = nn.BatchNorm1d(256)
        self.layer2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(p=0.3)   # lighter dropout
        self.bn2 = nn.BatchNorm1d(128)
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.bn1(x)        # batch norm before activation
        x = torch.relu(x)
        x = self.dropout1(x)   # dropout after activation
        x = self.layer2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        return self.output(x)
# 3. L2 Regularization (weight decay in optimizer)
model = RegularizedModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# 4. Early Stopping (manual implementation)
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
                print(f"Early stopping! No improvement for {self.patience} epochs")
        else:
            self.best_loss = val_loss
            self.counter = 0
early_stopping = EarlyStopping(patience=10)
# In training loop: early_stopping(val_loss)
# if early_stopping.should_stop: break
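One practical detail that follows from the dropout and batch norm layers above: both behave differently at training and evaluation time, so the mode must be switched explicitly. A minimal sketch using the model defined above:
model = RegularizedModel()

model.train()   # dropout active, batch norm uses per-batch statistics
# ... run training iterations here ...

model.eval()    # dropout disabled, batch norm uses running statistics
with torch.no_grad():                            # also skip gradient tracking during evaluation
    val_outputs = model(torch.randn(32, 784))    # dummy validation batch for illustration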
Diagnosing Overfitting vs Underfitting
Overfitting (High Variance)
Train accuracy: 99%, Val accuracy: 70%
- Fix: More data, data augmentation (see the sketch after these lists)
- Fix: Add dropout, increase weight decay
- Fix: Reduce model size
- Fix: Early stopping
- Fix: Use pretrained models
Underfitting (High Bias)
Train accuracy: 60%, Val accuracy: 58%
- Fix: Increase model capacity (more layers)
- Fix: Train longer (more epochs)
- Fix: Reduce regularization
- Fix: Better features / data quality
- Fix: Try a different architecture
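As one concrete example of the "more data / data augmentation" fix for overfitting, an image-augmentation pipeline might look like the following; the specific transforms and parameters are illustrative, not a universal recipe:
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),                      # random crop + resize
    transforms.RandomHorizontalFlip(),                      # flip half the images
    transforms.ColorJitter(brightness=0.2, contrast=0.2),   # mild color jitter
    transforms.ToTensor(),
])
# Pass as the `transform` argument of an image dataset, e.g.
# torchvision.datasets.ImageFolder(root, transform=train_transforms)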
Key Takeaways
- Start with AdamW (lr=1e-3), batch_size=64, and a cosine annealing schedule
- Use dropout (0.1-0.5) and weight decay (0.01) as default regularization
- Cosine annealing with warmup is the standard LR schedule for transformers
- If train_acc >> val_acc, you are overfitting; add regularization or more data
- Early stopping prevents wasted compute and overfitting simultaneously