Variational Autoencoder (VAE) Training Process
Overview
A Variational Autoencoder is a generative model that learns to encode data into a latent space and decode it back. Unlike standard autoencoders, VAEs learn a probability distribution over the latent space, enabling generation of new samples.
Architecture Components
- Encoder \( q_\phi(z|x) \): maps an input \( x \) to the parameters \( \mu \) and \( \log\sigma^2 \) of a Gaussian distribution over the latent variable \( z \) (parameters \( \phi \))
- Decoder \( p_\theta(x|z) \): maps a latent vector \( z \) back to a reconstruction \( \hat{x} \) (parameters \( \theta \))
- Prior \( p(z) \): typically a standard normal \( \mathcal{N}(0, I) \), toward which the encoder's distribution is regularized
Training Process
Step 1: Forward Pass — Encoding
- Input: Data sample \( x \) (e.g., an image or a sequence)
- Encode to distribution parameters:
  \[ \mu, \log\sigma^2 = \text{Encoder}_\phi(x) \]
  The encoder outputs the mean \( \mu \) and the log-variance \( \log\sigma^2 \) (log-variance is used for numerical stability).
- Reparameterization Trick: Sample \( z \) from \( q_\phi(z|x) \):
  \[ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
  This allows gradients to flow through the sampling operation: \( \epsilon \) is drawn from a standard normal, then scaled by \( \sigma \) and shifted by \( \mu \).
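A minimal sketch of this step in PyTorch (the function name `reparameterize` is illustrative; it assumes the encoder has already produced `mu` and `log_var`):

```python
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I), differentiably w.r.t. mu and log_var."""
    sigma = torch.exp(0.5 * log_var)   # recover sigma from the log-variance
    eps = torch.randn_like(sigma)      # eps ~ N(0, I); no gradient flows into the noise
    return mu + sigma * eps

# Usage: z = reparameterize(mu, log_var)   # z has the same shape as mu
```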
Step 2: Forward Pass — Decoding
- Decode the latent vector:
  \[ \hat{x} = \text{Decoder}_\theta(z) \]
  The decoder reconstructs the input from the latent representation.
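To make Steps 1 and 2 concrete, here is a minimal sketch of an encoder/decoder pair as a single PyTorch module; it assumes flattened inputs in \([0, 1]\), and the class name `VAE` and the layer sizes are illustrative rather than prescribed by the text above:

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, input_dim: int = 784, hidden_dim: int = 400, latent_dim: int = 20):
        super().__init__()
        # Encoder: x -> (mu, log_var)
        self.enc = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder: z -> x_hat (sigmoid output for data scaled to [0, 1])
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.enc(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z):
        return self.dec(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        sigma = torch.exp(0.5 * log_var)
        z = mu + sigma * torch.randn_like(sigma)   # reparameterization trick
        return self.decode(z), mu, log_var
```

The sigmoid output pairs naturally with the binary cross-entropy reconstruction loss discussed in Step 3.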
Step 3: Loss Computation (ELBO)
The loss function is the negative Evidence Lower Bound (ELBO), which consists of two terms, weighted by the \( \beta \) parameter discussed below:
\[ \mathcal{L}(\theta, \phi; x) = \mathcal{L}_{\text{recon}}(x, \hat{x}) + \beta \, D_{KL}\big(q_\phi(z|x) \,\|\, p(z)\big) \]
Term 1: Reconstruction Loss
Measures how well the decoder reconstructs the input; for a single sampled \( z \), it is the negative log-likelihood of \( x \) under the decoder:
\[ \mathcal{L}_{\text{recon}}(x, \hat{x}) = -\log p_\theta(x \mid z) \]
Common choices:
- Binary data: Binary cross-entropy (BCE)
- Continuous data: Mean squared error (MSE) or Gaussian negative log-likelihood
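As a rough sketch of those two choices in PyTorch (dummy tensors stand in for a real batch; summed reductions keep the scale comparable to the summed KL term below):

```python
import torch
import torch.nn.functional as F

x = torch.rand(8, 784)      # dummy batch of inputs in [0, 1]
x_hat = torch.rand(8, 784)  # dummy reconstructions from the decoder

# Binary data (Bernoulli likelihood, sigmoid decoder output): binary cross-entropy
recon_bce = F.binary_cross_entropy(x_hat, x, reduction="sum")

# Continuous data: mean squared error, i.e. a fixed-variance Gaussian
# negative log-likelihood up to constants
recon_mse = F.mse_loss(x_hat, x, reduction="sum")
```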
Term 2: KL Divergence
Regularizes the latent space to match the prior distribution:
For a Gaussian encoder \( q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I) \) and a standard normal prior \( p(z) = \mathcal{N}(0, I) \), the KL divergence has a closed form:
\[ D_{KL}\big(q_\phi(z|x) \,\|\, p(z)\big) = -\frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) \]
where \( J \) is the dimensionality of the latent space.
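This closed form maps directly to code; a short sketch (the helper name `gaussian_kl` is illustrative), with a sanity check against `torch.distributions`:

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions and the batch."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Sanity check against torch.distributions on random parameters
mu = torch.randn(4, 20)
log_var = torch.randn(4, 20)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
assert torch.allclose(gaussian_kl(mu, log_var), kl_divergence(q, p).sum(), atol=1e-4)
```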
Beta-VAE Weighting
The \( \beta \) parameter balances reconstruction quality vs. latent space structure:
- \( \beta = 1 \): Standard VAE
- \( \beta > 1 \): Stronger regularization, more disentangled representations
- \( \beta < 1 \): Better reconstructions, less structured latent space
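Combining the two terms with the \( \beta \) weight gives the full training objective; a sketch assuming binary data and summed reductions (the function name `vae_loss` is illustrative):

```python
import torch
import torch.nn.functional as F

def vae_loss(x: torch.Tensor, x_hat: torch.Tensor,
             mu: torch.Tensor, log_var: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Negative ELBO with beta weighting: reconstruction + beta * KL."""
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")       # Term 1
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())  # Term 2
    return recon + beta * kl
```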
Step 4: Backpropagation
- Compute gradients:
  \[ \frac{\partial \mathcal{L}}{\partial \theta}, \quad \frac{\partial \mathcal{L}}{\partial \phi} \]
- Update parameters:
  \[ \theta \leftarrow \theta - \alpha \cdot \frac{\partial \mathcal{L}}{\partial \theta} \]
  \[ \phi \leftarrow \phi - \alpha \cdot \frac{\partial \mathcal{L}}{\partial \phi} \]
  where \( \alpha \) is the learning rate.
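In practice the update is delegated to an optimizer (e.g., Adam or SGD), but the raw rule above can be written out directly; a hedged sketch with a hypothetical helper `sgd_step`:

```python
import torch

def sgd_step(parameters, alpha: float) -> None:
    """Apply theta <- theta - alpha * dL/dtheta to each parameter, in place."""
    with torch.no_grad():
        for p in parameters:
            if p.grad is not None:
                p -= alpha * p.grad

# Typical usage after loss.backward(), e.g. with the encoder/decoder parameters:
# sgd_step(list(encoder.parameters()) + list(decoder.parameters()), alpha=1e-3)
```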
Training Loop Summary
# Assumes torch is imported and that encoder, decoder, dataloader,
# reconstruction_loss, optimizer, beta, and num_epochs are defined.
for epoch in range(num_epochs):
    for x in dataloader:
        # 1. Encode to distribution parameters
        mu, log_var = encoder(x)
        sigma = torch.exp(0.5 * log_var)

        # 2. Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(sigma)
        z = mu + sigma * eps

        # 3. Decode
        x_hat = decoder(z)

        # 4. Compute loss: reconstruction + beta * KL
        recon_loss = reconstruction_loss(x, x_hat)
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = recon_loss + beta * kl_loss

        # 5. Backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Key Insights
- Reparameterization trick is crucial: Without it, you can't backpropagate through stochastic sampling
- KL divergence acts as a regularizer: Prevents the encoder from learning arbitrary encodings
- Trade-off between reconstruction and regularization: Controlled by \( \beta \)
- Latent space becomes continuous and structured: Enables smooth interpolation between latent codes and sampling of new data points (see the sketch below)
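As an illustration of that last point, a small sketch of sampling and interpolation in the latent space; the untrained stand-in decoder and all sizes here are assumptions so the snippet runs on its own:

```python
import torch
import torch.nn as nn

latent_dim, data_dim = 20, 784
# Stand-in decoder (untrained) just so the snippet runs; in practice use the trained decoder
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                        nn.Linear(400, data_dim), nn.Sigmoid())

with torch.no_grad():
    # Generation: sample z from the prior N(0, I) and decode
    z = torch.randn(16, latent_dim)
    samples = decoder(z)                                      # (16, data_dim)

    # Smooth interpolation: walk linearly between two latent codes and decode each point
    z0, z1 = torch.randn(1, latent_dim), torch.randn(1, latent_dim)
    alphas = torch.linspace(0.0, 1.0, steps=8).unsqueeze(1)   # (8, 1)
    z_path = (1 - alphas) * z0 + alphas * z1                  # (8, latent_dim)
    interpolations = decoder(z_path)                          # (8, data_dim)
```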
Common Issues & Solutions
- Issue: Posterior collapse (the KL term shrinks toward zero and the decoder ignores \( z \)). Solution: KL annealing: gradually increase \( \beta \) during training.
- Issue: Blurry reconstructions. Solution: Use a more powerful decoder, reduce \( \beta \), or use perceptual losses.
- Issue: The latent code cannot capture enough information about the data. Solution: Ensure sufficient latent dimensionality and tune \( \beta \).
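A common way to implement KL annealing is a linear warm-up of \( \beta \); a minimal sketch, with the schedule shape and the `warmup_epochs` value as assumptions:

```python
def beta_schedule(epoch: int, warmup_epochs: int = 10, beta_max: float = 1.0) -> float:
    """Linearly increase beta from 0 to beta_max over the first warmup_epochs epochs."""
    return beta_max * min(1.0, epoch / warmup_epochs)

# Inside the training loop:
# beta = beta_schedule(epoch)
# loss = recon_loss + beta * kl_loss
```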
References
- Kingma & Welling (2013): "Auto-Encoding Variational Bayes"
- Higgins et al. (2017): "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"