
Model Training & Optimization

Master batch size, learning rate scheduling, optimizers like Adam and SGD, regularization, and techniques to fix overfitting

The Training Loop Fundamentals

Training a neural network involves repeatedly showing it data in batches, computing the loss, calculating gradients via backpropagation, and updating weights. Getting the hyperparameters right is often the difference between a model that works and one that does not.

Training Terminology:

  • Epoch: one full pass through all training data
  • Batch: subset of data processed together
  • Iteration: one weight update (one batch)
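
The loop itself is only a few lines. Here is a minimal, runnable PyTorch sketch; the model, data, and hyperparameters are placeholders for illustration:

import torch
import torch.nn as nn

model = nn.Linear(100, 10)               # placeholder model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs = torch.randn(64, 100)            # one batch of 64 samples
targets = torch.randint(0, 10, (64,))    # integer class labels

for epoch in range(5):                   # epoch: pass over the data
    optimizer.zero_grad()                # clear gradients from last step
    outputs = model(inputs)              # forward pass
    loss = criterion(outputs, targets)   # compute the loss
    loss.backward()                      # backpropagation: compute gradients
    optimizer.step()                     # update weights (one iteration)

A real loop iterates over a DataLoader so each epoch sees many batches; this sketch reuses one batch for brevity.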

Batch Size: Impact and Trade-offs

Small Batch (8-32)

  • Noisy gradients (acts as regularization)
  • Better generalization
  • Lower GPU memory usage
  • Slower training (more updates per epoch)

Medium Batch (64-256)

  • Good balance of speed and quality
  • Most common in practice
  • Stable training dynamics
  • Standard starting point

Large Batch (512+)

  • More stable gradients
  • Faster per epoch (fewer updates)
  • May converge to sharper minima
  • Needs higher learning rate
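
In code, batch size is set on the DataLoader. When scaling the batch up, a common heuristic is to scale the learning rate linearly with it; this is a rule of thumb, not a guarantee. A minimal sketch with synthetic data:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic dataset for illustration
dataset = TensorDataset(torch.randn(1024, 100), torch.randint(0, 10, (1024,)))

base_batch, base_lr = 64, 1e-3
batch_size = 512                          # scaled up for throughput
lr = base_lr * (batch_size / base_batch)  # linear scaling rule (heuristic)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(f"batch_size={batch_size}, scaled lr={lr:.4g}")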

Optimizers Compared

import torch
import torch.optim as optim

model = torch.nn.Linear(100, 10)  # simple model for demo

# 1. SGD: Simple, needs careful LR tuning, good with momentum
optimizer_sgd = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,        # accelerates convergence
    weight_decay=1e-4     # L2 regularization
)

# 2. Adam: Adaptive learning rates, great default choice
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=0.001,            # typical: 1e-3 to 1e-4
    betas=(0.9, 0.999),  # momentum parameters
    weight_decay=1e-4
)

# 3. AdamW: Adam with decoupled weight decay (preferred)
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01     # applied directly to weights, not through the gradient
)

# 4. Different LR per parameter group (fine-tuning pattern; the demo
#    model's weight/bias stand in for pretrained vs. new layers)
optimizer_finetune = optim.AdamW([
    {'params': model.weight, 'lr': 1e-5},   # pretrained layers: low LR
    {'params': model.bias,   'lr': 1e-3},   # new layers: higher LR
])

# When to use what:
# Adam/AdamW: Default choice, works well out of the box
# SGD + Momentum: Often reaches better final accuracy (with tuning)
# AdamW: Best for transformers and fine-tuning pretrained models
print("AdamW with lr=1e-3 to 1e-4 is the safest starting point")

Learning Rate Scheduling

import torch
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np

model = torch.nn.Linear(100, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# (Demo: the schedulers below share one optimizer for brevity;
#  in practice, attach a single scheduler to an optimizer)

# 1. StepLR: Decay by factor every N epochs
scheduler1 = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# LR: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20)

# 2. CosineAnnealingLR: Smooth cosine decay (very popular)
scheduler2 = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
# Smoothly decays from initial LR to ~0 over 50 epochs

# 3. ReduceLROnPlateau: Reduce when metric stops improving
scheduler3 = lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)
# If val_loss doesn't improve for 5 epochs, halve the LR
# Note: unlike other schedulers, call scheduler3.step(val_loss) with the metric

# 4. OneCycleLR: Start low, ramp up, then decay (fast convergence)
scheduler4 = lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, total_steps=1000
)

# 5. Warmup + Cosine Decay (standard for transformers)
def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))  # cosine decay
    return lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage in training loop
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=100, total_steps=1000)
for step in range(1000):
    # ... training step ...
    optimizer.step()
    scheduler.step()

print("Cosine annealing with warmup is the standard for modern training")

Regularization Techniques

import torch
import torch.nn as nn

class RegularizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 1. Dropout: randomly zero out neurons during training
        self.layer1 = nn.Linear(784, 256)
        self.dropout1 = nn.Dropout(p=0.5)  # 50% dropout rate

        # 2. Batch Normalization: normalize activations
        self.bn1 = nn.BatchNorm1d(256)

        self.layer2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(p=0.3)  # lighter dropout
        self.bn2 = nn.BatchNorm1d(128)

        self.output = nn.Linear(128, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.bn1(x)       # batch norm before activation
        x = torch.relu(x)
        x = self.dropout1(x)  # dropout after activation

        x = self.layer2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.dropout2(x)

        return self.output(x)

# 3. L2 Regularization (weight decay in optimizer)
model = RegularizedModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
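
# Note: dropout and batch norm behave differently in training vs inference.
# Switch modes explicitly, otherwise dropout stays active at evaluation time:
model.train()   # dropout active; batch norm uses per-batch statistics
# ... training steps ...
model.eval()    # dropout disabled; batch norm uses running statistics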

# 4. Early Stopping (manual implementation)
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
                print(f"Early stopping! No improvement for {self.patience} epochs")
        else:
            self.best_loss = val_loss
            self.counter = 0

early_stopping = EarlyStopping(patience=10)
# In training loop: early_stopping(val_loss)
# if early_stopping.should_stop: break
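
Putting the pieces together, a validation-aware epoch loop might look like this sketch (train_loader, val_loader, and criterion are assumed to exist):

for epoch in range(100):
    model.train()                      # dropout/batch norm in training mode
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

    model.eval()                       # switch to inference behavior
    with torch.no_grad():              # no gradients needed for validation
        val_loss = sum(criterion(model(x), y).item() for x, y in val_loader)
        val_loss /= len(val_loader)

    early_stopping(val_loss)
    if early_stopping.should_stop:
        break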

Diagnosing Overfitting vs Underfitting

Overfitting (High Variance)

Train accuracy: 99%, Val accuracy: 70%

  • Fix: More data, data augmentation
  • Fix: Add dropout, increase weight decay
  • Fix: Reduce model size
  • Fix: Early stopping
  • Fix: Use pretrained models

Underfitting (High Bias)

Train accuracy: 60%, Val accuracy: 58%

  • Fix: Increase model capacity (more layers)
  • Fix: Train longer (more epochs)
  • Fix: Reduce regularization
  • Fix: Better features / data quality
  • Fix: Try a different architecture
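
A quick way to make the rule of thumb concrete is to compare train and validation accuracy directly. The thresholds below are illustrative defaults, not standard values:

def diagnose(train_acc, val_acc, gap=0.10, floor=0.70):
    """Toy heuristic; tune the thresholds for your task."""
    if train_acc - val_acc > gap:
        return "overfitting: regularize, augment, or get more data"
    if train_acc < floor:
        return "underfitting: increase capacity or train longer"
    return "healthy: train and val accuracy are close and acceptable"

print(diagnose(0.99, 0.70))  # overfitting example from above
print(diagnose(0.60, 0.58))  # underfitting example from above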

Key Takeaways

  • Start with AdamW (lr=1e-3), batch_size=64, and cosine annealing schedule
  • Use dropout (0.1-0.5) and weight decay (0.01) as default regularization
  • Cosine annealing with warmup is the standard LR schedule for transformers
  • If train_acc >> val_acc, you are overfitting; add regularization or more data
  • Early stopping prevents wasted compute and overfitting simultaneously
