Variational Autoencoder (VAE) Training Process

Overview

A Variational Autoencoder is a generative model that learns to encode data into a latent space and decode it back. Unlike standard autoencoders, VAEs learn a probability distribution over the latent space, enabling generation of new samples.

Architecture Components

  • Encoder \( q_\phi(z|x) \): Maps input \( x \) to latent distribution parameters \( \mu \) and \( \sigma \)
  • Decoder \( p_\theta(x|z) \): Maps latent vector \( z \) back to reconstruction \( \hat{x} \)
  • Prior \( p(z) \): Typically \( \mathcal{N}(0, I) \), the standard normal distribution
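
A minimal PyTorch sketch of the encoder and decoder is given below. The MLP layer widths, the latent dimension of 20, and the Sigmoid output (for inputs scaled to [0, 1], e.g. flattened 28x28 images) are illustrative assumptions, not requirements of the VAE framework.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z|x): maps x to the mean and log-variance of a diagonal Gaussian."""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # outputs mu
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # outputs log sigma^2

    def forward(self, x):
        h = self.hidden(x)
        return self.fc_mu(h), self.fc_logvar(h)

class Decoder(nn.Module):
    """p_theta(x|z): maps a latent vector z to a reconstruction x_hat."""
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim), nn.Sigmoid(),  # Sigmoid keeps outputs in [0, 1]
        )

    def forward(self, z):
        return self.net(z)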

Training Process

Step 1: Forward Pass — Encoding

  1. Input: Data sample \( x \) (e.g., image, sequence)
  2. Encode to distribution parameters:
    \[ \mu, \log\sigma^2 = \text{Encoder}_\phi(x) \]

    The encoder outputs mean \( \mu \) and log-variance \( \log\sigma^2 \) (log-variance for numerical stability).

  3. Reparameterization Trick: Sample \( z \) from \( q_\phi(z|x) \)
    \[ z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]

    This allows gradients to flow through the sampling operation. \( \epsilon \) is sampled from a standard normal, then scaled and shifted.
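
As a minimal sketch, assuming PyTorch and the log-variance parameterization above, the reparameterization step can be written as a small helper:

import torch

def reparameterize(mu, log_var):
    """Draw z ~ N(mu, sigma^2 I) while keeping gradients w.r.t. mu and log_var."""
    sigma = torch.exp(0.5 * log_var)  # sigma = exp(0.5 * log sigma^2)
    eps = torch.randn_like(sigma)     # eps ~ N(0, I); all randomness is isolated here
    return mu + sigma * eps           # z = mu + sigma * eps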

Step 2: Forward Pass — Decoding

  1. Decode latent vector:
    \[ \hat{x} = \text{Decoder}_\theta(z) \]

    The decoder reconstructs the input from the latent representation.

Step 3: Loss Computation (ELBO)

The loss function is the negative Evidence Lower Bound (ELBO), which consists of two terms:

\[ \mathcal{L}(\phi, \theta; x) = \underbrace{\text{Reconstruction Loss}}_{\text{forces } \hat{x} \approx x} + \underbrace{\beta \cdot \text{KL Divergence}}_{\text{regularizes latent space}} \]

Term 1: Reconstruction Loss

Measures how well the decoder reconstructs the input:

\[ \mathcal{L}_{\text{recon}} = -\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] \]

Common choices:

  • Binary data: Binary cross-entropy (BCE)
  • Continuous data: Mean squared error (MSE) or Gaussian negative log-likelihood
\[ \mathcal{L}_{\text{recon}} = \text{MSE}(x, \hat{x}) \quad \text{or} \quad \text{BCE}(x, \hat{x}) \]
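
A PyTorch sketch of these two options is given below; the sum-over-features, mean-over-batch reduction is an assumption, and whichever scaling is chosen effectively rescales \( \beta \) in the total loss.

import torch.nn.functional as F

def reconstruction_loss(x, x_hat, kind="mse"):
    """Per-batch reconstruction loss: summed over features, averaged over samples."""
    if kind == "bce":
        # Binary or [0, 1]-valued data; assumes x_hat are sigmoid outputs
        total = F.binary_cross_entropy(x_hat, x, reduction="sum")
    else:
        # Continuous data; MSE corresponds to a fixed-variance Gaussian likelihood
        total = F.mse_loss(x_hat, x, reduction="sum")
    return total / x.size(0)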

Term 2: KL Divergence

Regularizes the latent space to match the prior distribution:

\[ \mathcal{L}_{\text{KL}} = D_{KL}(q_\phi(z|x) \| p(z)) \]

For Gaussian encoder \( q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2 I) \) and standard normal prior \( p(z) = \mathcal{N}(0, I) \):

\[ D_{KL} = \frac{1}{2} \sum_{j=1}^{J} \left( \mu_j^2 + \sigma_j^2 - \log\sigma_j^2 - 1 \right) \]

where \( J \) is the dimensionality of the latent space.
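
This closed form translates directly into code; the sketch below assumes the same log-variance parameterization used in the encoding step.

import torch

def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2 I) || N(0, I)), summed over latent dims, averaged over the batch."""
    # 0.5 * sum_j (mu_j^2 + sigma_j^2 - log sigma_j^2 - 1)
    kl_per_sample = 0.5 * torch.sum(mu.pow(2) + log_var.exp() - log_var - 1, dim=1)
    return kl_per_sample.mean()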

Beta-VAE Weighting

The \( \beta \) parameter balances reconstruction quality vs. latent space structure:

  • \( \beta = 1 \): Standard VAE
  • \( \beta > 1 \): Stronger regularization, more disentangled representations
  • \( \beta < 1 \): Better reconstructions, less structured latent space
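
Putting the two terms together with the \( \beta \) weight, reusing the hypothetical reconstruction_loss and kl_divergence helpers sketched above:

def vae_loss(x, x_hat, mu, log_var, beta=1.0):
    """Negative ELBO: reconstruction term plus beta-weighted KL term."""
    recon = reconstruction_loss(x, x_hat)  # Term 1 (sketched above)
    kl = kl_divergence(mu, log_var)        # Term 2 (sketched above)
    return recon + beta * kl, recon, kl    # individual terms returned for logging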

Step 4: Backpropagation

  1. Compute gradients:
    \[ \frac{\partial \mathcal{L}}{\partial \theta}, \quad \frac{\partial \mathcal{L}}{\partial \phi} \]
  2. Update parameters:
    \[ \theta \leftarrow \theta - \alpha \cdot \frac{\partial \mathcal{L}}{\partial \theta}, \qquad \phi \leftarrow \phi - \alpha \cdot \frac{\partial \mathcal{L}}{\partial \phi} \]

    where \( \alpha \) is the learning rate.
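
In practice both updates are delegated to an optimizer over the combined parameter set. Adam with a learning rate around 1e-3 is a common choice rather than a requirement; the sketch assumes the hypothetical Encoder and Decoder modules from earlier.

import torch

encoder = Encoder()  # parameters phi
decoder = Decoder()  # parameters theta

# A single optimizer updates both parameter sets from the same loss
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=1e-3,  # the learning rate alpha
)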

Training Loop Summary

# Minimal runnable sketch in PyTorch; assumes encoder, decoder, dataloader,
# optimizer, num_epochs, and beta are defined (e.g. as sketched in earlier sections).
import torch
import torch.nn.functional as F

for epoch in range(num_epochs):
    for x in dataloader:
        # 1. Encode to distribution parameters
        mu, log_var = encoder(x)
        sigma = torch.exp(0.5 * log_var)

        # 2. Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(sigma)
        z = mu + sigma * eps

        # 3. Decode
        x_hat = decoder(z)

        # 4. Compute loss (negative ELBO with beta-weighted KL term)
        recon_loss = F.mse_loss(x_hat, x, reduction="sum") / x.size(0)  # or BCE for binary data
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
        loss = recon_loss + beta * kl_loss

        # 5. Backprop and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Key Insights

  • Reparameterization trick is crucial: Without it, you can't backpropagate through stochastic sampling
  • KL divergence acts as a regularizer: Prevents the encoder from learning arbitrary encodings
  • Trade-off between reconstruction and regularization: Controlled by \( \beta \)
  • Latent space becomes continuous and structured: Enabling smooth interpolation and sampling
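
Because the latent space is pulled toward the prior, new data can be generated by decoding draws from \( \mathcal{N}(0, I) \). A minimal sketch, assuming the hypothetical decoder above with a latent dimension of 20:

import torch

decoder.eval()
with torch.no_grad():
    z = torch.randn(16, 20)   # 16 draws from the standard normal prior
    samples = decoder(z)      # decoded samples, e.g. 16 x 784 "images"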

Common Issues & Solutions

  • Problem: KL divergence collapses to zero (posterior collapse)
    Solution: KL annealing, i.e. gradually increase \( \beta \) from 0 toward its target value during training (see the schedule sketch below)
  • Problem: Blurry reconstructions
    Solution: Use a more powerful decoder, reduce \( \beta \), or use perceptual losses
  • Problem: Mode collapse in generation
    Solution: Ensure sufficient latent dimensionality and tune \( \beta \)
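
A simple linear warm-up schedule for \( \beta \), as used in KL annealing, is sketched below; the warm-up length and final value are illustrative assumptions.

def kl_weight(step, warmup_steps=10_000, beta_max=1.0):
    """Linearly anneal the KL weight from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)

# Inside the training loop above:
#   beta = kl_weight(global_step)
#   loss = recon_loss + beta * kl_loss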

References

  • Kingma & Welling (2013): "Auto-Encoding Variational Bayes"
  • Higgins et al. (2017): "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"