Gradient Descent Explained
What is Gradient Descent?
Gradient Descent is the most fundamental optimization algorithm in machine learning and deep learning. The intuition is beautifully simple: imagine you are standing on a mountain, surrounded by thick fog so you cannot see the landscape. However, you can feel the slope of the ground beneath your feet. To reach the valley (the minimum of the loss function), you take a small step in the direction of the steepest downhill slope at your current position, then repeat. This is the essence of gradient descent -- iteratively computing the gradient (the vector of partial derivatives) of the loss function with respect to the parameters, and updating the parameters in the opposite direction to gradually minimize the loss.
Virtually all neural network training relies on gradient descent and its variants. Understanding the math behind gradients, the differences between optimizer variants, and learning rate scheduling strategies is essential for any deep learning practitioner. This guide covers everything from mathematical foundations to PyTorch implementation.
Mathematical Foundation
The Gradient
The gradient is the vector of partial derivatives of a multivariate function with respect to each of its variables. For a function f(x₁, x₂, ..., xₙ), the gradient is defined as:

∇f = (∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ)
The gradient points in the direction of steepest ascent. The negative gradient points in the direction of steepest descent -- and that is exactly the direction gradient descent follows to minimize the loss function.
The Update Rule
The core update formula of gradient descent is elegantly simple:

θ ← θ − α · ∇J(θ)
Here, θ represents the model parameters, α is the learning rate, J(θ) is the loss function, and ∇J(θ) is the gradient of the loss with respect to the parameters. Each iteration updates the parameters by subtracting the gradient scaled by the learning rate.
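To make the rule concrete, here is a minimal sketch on a toy one-dimensional loss J(θ) = (θ − 3)², whose minimum is at θ = 3 (the objective, step count, and learning rate are chosen purely for illustration):

```python
# Toy loss: J(theta) = (theta - 3)^2, so dJ/dtheta = 2 * (theta - 3)
def minimize(theta=0.0, alpha=0.1, steps=100):
    for _ in range(steps):
        grad = 2.0 * (theta - 3.0)
        theta = theta - alpha * grad   # theta <- theta - alpha * grad J(theta)
    return theta

theta = minimize()   # approaches the minimum at theta = 3
```

Each iteration moves θ a fraction α of the way along the negative gradient; with a reasonable α the error shrinks geometrically.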
Impact of Learning Rate α
The learning rate is the single most critical hyperparameter in gradient descent:

- Too small: convergence is extremely slow; training may need orders of magnitude more iterations.
- Too large: updates overshoot the minimum, so the loss oscillates or diverges.
- Well-chosen: the loss decreases quickly and settles close to the minimum.
Types of Gradient Descent
Batch Gradient Descent
Uses the entire training dataset to compute the gradient at each step. The gradient estimate is accurate and updates are stable, but the computational cost per step is very high for large datasets. It cannot perform online learning and may get stuck at saddle points.
Stochastic Gradient Descent (SGD)
Uses a single sample to compute the gradient at each step. Updates are extremely frequent, which helps escape local minima and supports online learning. However, gradient estimates are noisy, the loss curve oscillates heavily, and convergence is erratic.
Mini-batch Gradient Descent
The standard approach in practice. Uses a small batch (typically 32-256 samples) to compute the gradient. It combines the stability of batch methods with the efficiency of stochastic methods and takes full advantage of GPU parallelism. This is the default training mode in PyTorch and TensorFlow.
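As an illustration of the mini-batch loop, here is a NumPy sketch that reshuffles the data each epoch and updates on one batch at a time (the synthetic data, names, and hyperparameters are all illustrative):

```python
import numpy as np

# Synthetic linear-regression problem: y = X @ true_w + small noise
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w + 0.01 * rng.normal(size=1000)

w = np.zeros(3)
alpha, batch_size = 0.1, 64

for epoch in range(20):
    idx = rng.permutation(len(X))              # reshuffle every epoch
    for start in range(0, len(X), batch_size):
        b = idx[start:start + batch_size]      # indices of one mini-batch
        grad = 2 * X[b].T @ (X[b] @ w - y[b]) / len(b)  # MSE gradient on the batch
        w -= alpha * grad
```

Each epoch makes one pass over the data in random order, so every update uses a fresh gradient estimate from 64 samples rather than the full 1000.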
Comparison Table
| Method | Speed | Stability | Memory | Best For |
|---|---|---|---|---|
| Batch GD | Slow (high per-step cost) | High (accurate gradient) | High (all data in memory) | Small datasets, convex problems |
| Stochastic GD | Fast (low per-step cost) | Low (noisy) | Low (single sample) | Online learning, streaming data |
| Mini-batch GD | Optimal (GPU parallel) | Medium (balanced) | Medium (one batch) | Standard deep learning |
Optimizers Explained
Vanilla SGD suffers from slow convergence and oscillation issues. Researchers have developed improved optimizers that introduce momentum, adaptive learning rates, and other mechanisms to accelerate and stabilize training. Here are the most widely used optimizers.
SGD with Momentum
Momentum borrows from physics: the update depends not only on the current gradient but also accumulates previous update directions. This accelerates movement in consistent gradient directions and dampens oscillation, similar to a ball rolling downhill gaining speed. The momentum coefficient β is typically set to 0.9.
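A minimal sketch of the momentum update (variable names are illustrative; this mirrors the common "velocity" formulation, though libraries differ in small details):

```python
def sgd_momentum_step(w, v, grad, lr=0.01, beta=0.9):
    # v is an exponentially decaying accumulation of past gradients ("velocity")
    v = beta * v + grad
    return w - lr * v, v

# Usage: minimize J(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w, v = 0.0, 0.0
for _ in range(300):
    w, v = sgd_momentum_step(w, v, 2 * (w - 3))
```

Because v keeps a running sum of past gradients, steps in a consistent direction compound (roughly up to 1/(1 − β) = 10x here), while alternating gradients cancel out.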
RMSprop (Root Mean Square Propagation)
RMSprop maintains an exponential moving average of squared gradients for each parameter, adaptively scaling the learning rate. Parameters with large gradients get a smaller effective learning rate; parameters with small gradients get a larger one. This solves the monotonically decreasing learning rate problem of Adagrad.
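A sketch of the RMSprop update for a single parameter (names and hyperparameters are illustrative):

```python
import math

def rmsprop_step(w, s, grad, lr=0.01, beta=0.9, eps=1e-8):
    s = beta * s + (1 - beta) * grad ** 2          # EMA of squared gradients
    return w - lr * grad / (math.sqrt(s) + eps), s  # per-parameter scaled step

# Usage: minimize J(w) = (w - 3)^2
w, s = 0.0, 0.0
for _ in range(600):
    w, s = rmsprop_step(w, s, 2 * (w - 3))
```

Dividing by √s keeps the effective step near lr regardless of the raw gradient magnitude, which is exactly the adaptive scaling described above.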
Adam (Adaptive Moment Estimation)
Adam is the most popular optimizer today, combining the benefits of Momentum (first moment estimate) and RMSprop (second moment estimate). It maintains both the mean (momentum) and variance (adaptive rate) of gradients for each parameter, with bias correction to eliminate initialization bias. The defaults (β₁=0.9, β₂=0.999, ε=1e-8) work well in most cases.
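Putting the two moments and the bias correction together, a scalar sketch of one Adam step (illustrative names; defaults match those quoted above):

```python
import math

def adam_step(w, m, v, grad, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad          # first moment (mean of gradients)
    v = b2 * v + (1 - b2) * grad ** 2     # second moment (uncentered variance)
    m_hat = m / (1 - b1 ** t)             # bias correction (t starts at 1)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# Usage: minimize J(w) = (w - 3)^2
w, m, v = 0.0, 0.0, 0.0
for t in range(1, 2001):
    w, m, v = adam_step(w, m, v, 2 * (w - 3), t, lr=0.01)
```

The bias correction matters early: at t = 1, without it m would be only (1 − β₁) = 0.1 of the true gradient; dividing by (1 − β₁ᵗ) restores the full magnitude.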
AdamW (Decoupled Weight Decay)
AdamW fixes Adam's incorrect implementation of L2 regularization. In standard Adam, weight decay is added to the gradient before adaptive scaling, weakening the regularization effect. AdamW decouples weight decay from the gradient update, applying it directly to the parameters for better regularization. AdamW has become the standard choice for training Transformers and large language models.
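The decoupling can be shown in a small sketch: the decay acts on the weight directly and never passes through the adaptive scaling (names are illustrative; real implementations fold this into one step):

```python
import math

def adamw_step(w, m, v, grad, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    w = w - lr * wd * w                   # decoupled decay: applied to w directly,
                                          # NOT added to grad before scaling
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# With a zero gradient, only the decay acts: w shrinks by the factor (1 - lr * wd)
w1, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, 1)
```

In coupled L2 (plain Adam), the decay term would be divided by √v̂ like any other gradient component, so large-gradient parameters would be regularized less; decoupling removes that interaction.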
Optimizer Comparison
| Optimizer | Adaptive LR | Momentum | Weight Decay | Typical Use |
|---|---|---|---|---|
| SGD | No | No | L2 | Convex optimization baseline |
| SGD + Momentum | No | Yes | L2 | CNN training (ResNet, etc.) |
| RMSprop | Yes | No | L2 | RNN / non-stationary objectives |
| Adam | Yes | Yes | L2 (coupled) | General-purpose default |
| AdamW | Yes | Yes | Decoupled | Transformer / LLM training |
Learning Rate Strategies
A fixed learning rate is rarely optimal. Early in training, a larger learning rate enables fast exploration; later, a smaller learning rate allows fine-tuning. Here are the most common learning rate scheduling strategies.
Fixed Learning Rate (Constant)
The simplest strategy: use the same learning rate throughout training. Suitable for small models and simple tasks, but rarely optimal for complex problems.
Step Decay
Multiply the learning rate by a decay factor (e.g., 0.1) every fixed number of epochs. Simple and intuitive, widely used in CNN training (e.g., ResNet decays lr at epochs 30, 60, 90).
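The schedule is a simple function of the epoch; a sketch with the ResNet-style defaults mentioned above (parameter names are illustrative):

```python
def step_decay(base_lr, epoch, drop=0.1, every=30):
    # multiply lr by `drop` once every `every` epochs: 0.1 -> 0.01 -> 0.001 -> ...
    return base_lr * drop ** (epoch // every)
```

So with base_lr = 0.1, epochs 0-29 train at 0.1, epochs 30-59 at 0.01, and so on.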
Cosine Annealing
The learning rate follows a cosine curve from the initial value smoothly down to a minimum (near 0). The deceleration is gradual in the later stages of training. This is one of the most popular scheduling strategies today.
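The curve itself is a half cosine between the two endpoints; a sketch (names are illustrative):

```python
import math

def cosine_annealing(base_lr, epoch, total_epochs, min_lr=0.0):
    # half-cosine interpolation: base_lr at epoch 0, min_lr at the final epoch
    cos_term = (1 + math.cos(math.pi * epoch / total_epochs)) / 2
    return min_lr + (base_lr - min_lr) * cos_term
```

The derivative of the cosine is zero at both ends, which is why the decay is gentle at the start and flattens out again near the end of training.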
Warmup
Start with a very small learning rate and linearly increase it to the target learning rate over the first few epochs, then begin decaying. Warmup prevents unstable gradient updates from randomly initialized parameters early in training. Transformer training almost always uses warmup -- skipping it often causes training to completely fail.
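A sketch of linear warmup followed by cosine decay, a common (not universal) Transformer recipe; the function name and exact shape are illustrative:

```python
import math

def warmup_lr(base_lr, step, warmup_steps, total_steps):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps             # linear ramp up
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))  # then cosine decay
```

During the first warmup_steps updates the lr climbs linearly to base_lr, after which the usual decay takes over.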
OneCycleLR
A super-convergence strategy: the learning rate rises from a small value to a maximum, then decays back to a very small value, all within one training cycle. Proposed by Leslie Smith, it allows using learning rates up to 10x larger than conventional methods, significantly speeding up convergence.
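A simplified sketch of the one-cycle shape (cosine interpolation up, then down; the default divisors below mimic common practice but are assumptions, and PyTorch's OneCycleLR differs in details):

```python
import math

def one_cycle_lr(step, total_steps, max_lr, start_lr=None, final_lr=None,
                 pct_start=0.3):
    start_lr = max_lr / 25 if start_lr is None else start_lr    # illustrative default
    final_lr = max_lr / 1e4 if final_lr is None else final_lr   # anneal far below start
    up_steps = int(total_steps * pct_start)
    if step < up_steps:
        t, lo, hi = step / up_steps, start_lr, max_lr           # ramp-up phase
    else:
        t = (step - up_steps) / (total_steps - up_steps)        # anneal phase
        lo, hi = max_lr, final_lr
    return lo + (hi - lo) * (1 - math.cos(math.pi * t)) / 2     # cosine interpolation
```

The lr peaks at max_lr 30% of the way through and ends several orders of magnitude below where it started.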
Gradient Descent from Scratch (Linear Regression)
A complete gradient descent implementation for linear regression using pure NumPy, to help you understand the underlying mechanics:
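The sketch below fits y = 4 + 3x on synthetic data with full-batch gradient descent (the data-generating values and hyperparameters are chosen for illustration):

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic data: y = 4 + 3x + noise
X = 2 * rng.random((100, 1))
y = 4 + 3 * X[:, 0] + 0.1 * rng.normal(size=100)

Xb = np.c_[np.ones(len(X)), X]    # prepend a bias column -> shape (100, 2)
theta = np.zeros(2)               # [intercept, slope]
alpha, n_iters = 0.1, 1000

for _ in range(n_iters):
    grad = 2 / len(y) * Xb.T @ (Xb @ theta - y)  # gradient of mean squared error
    theta -= alpha * grad                         # the gradient descent update

# theta should now be close to [4, 3]
```

Every line maps onto the math: the gradient of the MSE loss J(θ) = (1/n)‖Xθ − y‖² is (2/n)·Xᵀ(Xθ − y), and each iteration applies θ ← θ − α·∇J(θ).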
Common Problems and Pitfalls
Vanishing Gradients
In deep networks, gradients are multiplied layer by layer during backpropagation via the chain rule. If gradients at each layer are less than 1, they shrink exponentially over dozens of layers, approaching zero. Early layers barely update, preventing the network from learning deep features. Common with sigmoid/tanh activations in deep networks. Solutions include: ReLU activation, BatchNorm, residual connections (ResNet), and proper weight initialization (He/Xavier).
Exploding Gradients
The opposite of vanishing gradients: if gradients at each layer are greater than 1, they grow exponentially during backpropagation, causing enormous parameter updates and NaN loss values. Common in RNNs processing long sequences. Solutions include: gradient clipping, using LSTM/GRU instead of vanilla RNN, proper weight initialization, and reducing the learning rate.
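Gradient clipping by global norm can be sketched in a few lines (the function name is illustrative; it mirrors what frameworks' clip-by-norm utilities do):

```python
import numpy as np

def clip_grad_norm(grads, max_norm=1.0):
    # Rescale ALL gradients by one common factor if their global L2 norm
    # exceeds max_norm, preserving the update direction
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        grads = [g * (max_norm / total_norm) for g in grads]
    return grads

clipped = clip_grad_norm([np.array([3.0, 4.0])])   # norm 5 -> rescaled to norm 1
```

Because every gradient is scaled by the same factor, the update direction is unchanged; only its magnitude is capped.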
Saddle Points
Saddle points have zero gradient but are neither minima nor maxima -- they are minima in some directions and maxima in others. In high-dimensional spaces, saddle points vastly outnumber local minima. The stochasticity of SGD and momentum mechanisms help escape saddle points, which is one reason stochastic methods outperform batch methods in practice.
Local Minima
Non-convex loss functions may have multiple local minima, and gradient descent may converge to a suboptimal one. However, recent research shows that in high-dimensional deep learning, most local minima have loss values very close to the global minimum, making local minima less of a practical concern than previously thought. Saddle points and flat regions are typically the bigger challenge.
Learning Rate Too High / Too Low
Too high: Loss oscillates wildly from the start or immediately shoots up to NaN. The parameter updates overshoot the optimum and may leave the reasonable region of the loss landscape entirely. When this happens, reduce the learning rate by 10x immediately.
Too low: Loss decreases extremely slowly, remaining high even after hundreds of epochs. The model crawls through the search space and may need thousands of epochs to converge. Increase the learning rate by 3-10x, or use a warmup strategy.
Practical Tuning Tips

- Start with Adam or AdamW at lr = 1e-3 as a baseline, then tune from there.
- Sanity-check the pipeline by overfitting a tiny subset of the data first.
- Use warmup for Transformers and large-batch training; use gradient clipping for RNNs.
- Monitor gradient norms: zero or NaN values signal vanishing or exploding gradients.
- If the loss diverges, cut the learning rate by 10x; if it barely moves, raise it 3-10x or add warmup.
FAQ
Which optimizer should I choose?
For fast convergence and easy tuning, choose Adam/AdamW. For maximum accuracy with patience for tuning, choose SGD + Momentum + lr scheduling. In NLP/Transformer tasks, AdamW is essentially the only choice. In CV/CNN tasks, SGD + Momentum remains the go-to for competition-winning solutions.
How important is learning rate warmup?
Very important for large models and large batch sizes. Randomly initialized parameters produce unstable gradients early in training -- using a large learning rate immediately can cause the model to diverge. Warmup lets the model "warm up" with a small lr until parameters reach a reasonable range. For Transformers, skipping warmup frequently causes complete training failure.
Is a larger batch size always better?
Not necessarily. Larger batch sizes better utilize GPU parallelism and process more data per unit time, but excessively large batches may hurt generalization (the sharp minima problem). Common range is 32-512; very large batch training requires special lr strategies (e.g., LARS, LAMB). When limited by GPU memory, use gradient accumulation to simulate larger batches.
My loss is not decreasing -- how should I debug?
Debug in this order: 1) Check if the learning rate is appropriate (try 1e-3 first); 2) Verify data and labels are correct (overfit on a tiny dataset as a sanity check); 3) Ensure the loss function matches the task (CrossEntropy for classification, MSE for regression); 4) Check that optimizer.zero_grad() is being called; 5) Verify gradient norms are normal (non-zero, non-NaN); 6) Simplify the model, confirm the basic training pipeline works, then add complexity.
Can gradient descent optimize a non-differentiable loss?
Standard gradient descent requires the loss to be differentiable with respect to the parameters. For non-differentiable operations (argmax, discrete sampling), alternatives exist: Straight-Through Estimator, Gumbel-Softmax reparameterization, REINFORCE policy gradient, etc. In practice, ReLU is not differentiable at 0, but PyTorch defaults its gradient to 0 there, which works fine in training.