Noise vs. Clean Data: What Should a Diffusion Model Learn? (Part 1)
By Alexander (Hongyi) Huang
Overview: This blog post explains the empirical distinction between two seemingly equivalent training objectives. It is part one of a two-part series; here I cover experimental results and interpretations, and Part 2 dives into information-theoretic perspectives.
At high dimension, clean data prediction outperforms noise prediction. At low dimension, noise prediction outperforms clean data prediction.
Introduction
The landmark DDPM paper established ε-prediction as the standard training objective for diffusion models. Recently, the "Back to Basics" paper by Kaiming He revisited this choice and argued that directly predicting clean data (x₀-prediction) is a more efficient and natural alternative.
At first glance, this debate seems puzzling. Predicting noise and predicting clean data are equivalent in the sense that knowing one uniquely determines the other via the forward-process identity $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$.
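To make this equivalence concrete, here is a minimal NumPy sketch (the schedule value `alpha_bar` and the toy dimensions are arbitrary choices for illustration, not taken from the experiments below) showing that, given $x_t$, either quantity determines the other:

```python
import numpy as np

# Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
alpha_bar = 0.7                       # toy cumulative noise-schedule value at step t
x0 = np.random.randn(4)               # toy clean data point
eps = np.random.randn(4)              # Gaussian noise
x_t = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps

# Given x_t, knowing eps determines x0, and knowing x0 determines eps.
x0_from_eps = (x_t - np.sqrt(1.0 - alpha_bar) * eps) / np.sqrt(alpha_bar)
eps_from_x0 = (x_t - np.sqrt(alpha_bar) * x0) / np.sqrt(1.0 - alpha_bar)

assert np.allclose(x0_from_eps, x0) and np.allclose(eps_from_x0, eps)
```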
If both objectives contain the same information, why do they lead to such different training behavior in practice?
This post argues that the difference comes down to a trade-off between distributional simplicity and dimensional tractability:
- ε-prediction
  - Benefits from distributional simplicity (Gaussian targets are easy to regress, so it outperforms at low dimension).
  - Suffers from signal spreading across intractable high dimensions (Gaussian noise has the same dimensionality as the ambient space, so it underperforms at high dimension).
- x₀-prediction
  - Benefits from signal compactness (the target lies on a tractable low-dimensional manifold, so it outperforms at high dimension).
  - Suffers from manifold complexity (the target can be jagged or discontinuous, so minimizing the MSE can produce off-manifold predictions; it therefore underperforms at low dimension).
Setup: Two Equivalent Objectives
Diffusion models learn by first gradually adding noise to clean data (the forward process), then training a neural network to reverse this corruption and recover the original data (the backward process).
Diffusion model training process. The model learns to reverse the forward process that iteratively adds noise to clean data.
In x₀-prediction, the network is trained to predict the clean data:

$$\mathcal{L}_{x_0} = \mathbb{E}_{x_0,\,\epsilon,\,t}\big[\,\| x_\theta(x_t, t) - x_0 \|^2\,\big]$$

In ε-prediction, the network predicts the noise:

$$\mathcal{L}_{\epsilon} = \mathbb{E}_{x_0,\,\epsilon,\,t}\big[\,\| \epsilon_\theta(x_t, t) - \epsilon \|^2\,\big]$$

Although these losses are equivalent under a linear change of variables, they lead to very different optimization behavior in practice.
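As a concrete illustration of the two objectives, here is a minimal PyTorch sketch; `model`, the scalar noise level, and the omitted time conditioning are simplifications of my own, not the exact training code used for the experiments below:

```python
import torch

def diffusion_losses(model, x0, alpha_bar_t):
    """Compute both objectives for the same noisy input x_t.
    `model` is any network mapping x_t to a tensor of the same shape;
    time/noise-level conditioning is omitted for brevity."""
    eps = torch.randn_like(x0)                          # sample Gaussian noise
    x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps

    pred = model(x_t)                                   # network output

    loss_x0 = ((pred - x0) ** 2).mean()                 # x0-prediction: regress the clean data
    loss_eps = ((pred - eps) ** 2).mean()               # eps-prediction: regress the noise
    return loss_x0, loss_eps
```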
Why Does x₀-Prediction Outperform at High Dimensions?
Let's consider the simplest example: a single data point in a 32768-dimensional space. We train a 5-layer diffusion MLP with hidden dimension 256 on this one-point dataset, once predicting clean data and once predicting noise. The loss curve of the ε-prediction model does not decrease, while the loss curve of the x₀-prediction model drops almost instantly to nearly 0. In other words, the x₀-prediction model immediately learns the underlying distribution, while the ε-prediction model fails to learn anything.
Loss comparison between x₀-prediction and ε-prediction. The x₀-prediction model quickly learns the distribution, but the ε-prediction model fails to learn anything.
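For reference, a sketch of what such a one-point experiment might look like (layer sizes follow the description above; the optimizer, learning rate, and noise-level sampling are my assumptions, not the author's exact script). Swapping the regression target from `x0` to `eps` switches between the two objectives:

```python
import torch
import torch.nn as nn

DIM, HIDDEN = 32768, 256
x0 = torch.randn(1, DIM)                      # the single "dataset" point

# 5-layer MLP denoiser (time conditioning omitted for brevity)
mlp = nn.Sequential(
    nn.Linear(DIM, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, HIDDEN), nn.ReLU(),
    nn.Linear(HIDDEN, DIM),
)
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)

for step in range(1000):
    alpha_bar = torch.rand(1, 1)              # random noise level each step
    eps = torch.randn_like(x0)
    x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
    loss = ((mlp(x_t) - x0) ** 2).mean()      # x0-prediction; use eps here for eps-prediction
    opt.zero_grad()
    loss.backward()
    opt.step()
```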
When a point is sampled at a low noise level during training, the noisy point is very close to the clean data point. The x₀-prediction model can therefore use these low-noise samples to quickly learn the distribution of the clean data. The noise, unlike the clean data, is always Gaussian distributed across the full 32768-dimensional space, so the ε-prediction model has no trivial solution of consistently predicting the same single point, as the x₀-prediction model does here.
The figure below shows the results for a more general dataset of stars. Unsurprisingly, the ε-prediction model fails to learn anything, while the x₀-prediction model still performs well at high dimension.
x₀-prediction model reverse sampling trajectory. The x₀-prediction model learns the STAR dataset embedded in 512-dimensional space.
ε-prediction model reverse sampling trajectory. The ε-prediction model fails to learn the STAR dataset embedded in 512-dimensional space.
Interestingly, the loss curve of the x₀-prediction model decreases rapidly at the beginning and plateaus soon after. As with the one-point dataset, the x₀-prediction model very quickly learns to predict the clean data accurately from low-noise samples. However, predicting clean data from very noisy samples is an ill-conditioned problem, so the x₀-prediction model fails to optimize efficiently at later stages.
Loss comparison for the STAR dataset embedded in 512-dimensional space. The ε-prediction model fails to learn much; the x₀-prediction model learns quickly at first, then plateaus.
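A rough back-of-the-envelope way to see the ill-conditioning at high noise (my framing, not a claim made in the experiments): solving the forward process for the clean data gives

$$x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}},$$

so any residual error in separating signal from noise inside $x_t$ is amplified by a factor of $1/\sqrt{\bar{\alpha}_t}$, which grows without bound as the noise level increases. At low noise, by contrast, $x_t$ is already close to $x_0$ and the regression is easy.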
Taking a step back, what is the fundamental difference between noise and clean data? Clean data samples are points, but noise samples are directions. Points live on a low-dimensional surface, while directions range over the surrounding high-dimensional space. Because of this difference in dimensionality, points are constrained while directions vary freely in space. This intrinsic gap in degrees of freedom causes the performance gap between the two seemingly equivalent formulations.
Why Does x₀-Prediction Underperform at Low Dimensions?
At low dimensionality, x₀-prediction performs badly because the distribution of clean data is complex and non-convex. It can be shown mathematically (see the short derivation below) that the x₀-prediction model minimizing the MSE loss predicts the average position of the clean data conditioned on the current noisy input, i.e. $\mathbb{E}[x_0 \mid x_t]$. Therefore, the x₀-prediction model predicts clean data off the manifold at non-convex "V-shape" locations. As seen in the dinosaur example, the sampled points travel towards the "mean" of the clean data and get stuck in non-convex "V-shapes" such as the area between the jaw and hands.
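For completeness, here is the standard argument (not specific to this post) that the MSE-optimal predictor is the conditional mean. For any predictor $f$,

$$\mathbb{E}\,\|f(x_t)-x_0\|^2 \;=\; \mathbb{E}\,\|f(x_t)-\mathbb{E}[x_0\mid x_t]\|^2 \;+\; \mathbb{E}\,\|\mathbb{E}[x_0\mid x_t]-x_0\|^2 \;\ge\; \mathbb{E}\,\|\mathbb{E}[x_0\mid x_t]-x_0\|^2,$$

where the cross term vanishes by the tower property, so the minimizer is $f^{*}(x_t)=\mathbb{E}[x_0\mid x_t]$. When the posterior over $x_0$ is multimodal (e.g., the two sides of a "V"), this conditional mean falls between the modes and hence off the data manifold.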
x₀-prediction model reverse sampling trajectory. The x₀-prediction model fails to learn the DINO dataset embedded in 2-dimensional space accurately at concave locations (hands and jaw).
In contrast, ε-prediction performs well in low dimensions. The noise follows a simple, smooth Gaussian distribution that does not depend on the complexity of the data distribution. Instead of reconstructing intricate geometric structures, the model only needs to regress a well-behaved target with uniform variance in all directions.
ε-prediction model reverse sampling trajectory. The ε-prediction model learns the DINO dataset embedded in 2-dimensional space accurately.
Because the noise target does not suffer from the curse of dimensionality here, the diffusion model can also learn useful information at all noise levels. This is confirmed experimentally: the loss curve continues to decrease even after many epochs.
Loss curve of the ε-prediction model trained on the DINO dataset in 2-dimensional space. The loss keeps decreasing after 5000 epochs.
Conclusion
The choice between predicting noise and predicting clean data reflects a fundamental trade-off between distributional simplicity and dimensional tractability. Noise follows a regular Gaussian distribution but becomes intractable when the embedding dimension is large. Clean data remains tractable even in a high-dimensional embedding space, but its distribution can be irregular. As the embedding dimension increases, the learning bottleneck shifts from the complexity of the data distribution to the intractability of the high-dimensional space.
One interesting question we can ask is: how should we balance this trade-off when training models at high dimensions? Perhaps we can weight the clean-data and noise objectives so that the model learns useful information at both high and low noise levels, as in the sketch below.
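As a purely speculative sketch of that idea (the blending weight `w` and its schedule are my assumptions, not something validated in these experiments), one could combine the two targets with a noise-level-dependent weight:

```python
import torch

def blended_loss(pred_x0, pred_eps, x0, eps, alpha_bar_t):
    """Speculative blend of the two objectives.
    pred_x0 / pred_eps: the model's clean-data and noise predictions;
    w -> 1 at low noise (favor x0-prediction), w -> 0 at high noise (favor eps)."""
    w = alpha_bar_t
    loss_x0 = ((pred_x0 - x0) ** 2).mean()
    loss_eps = ((pred_eps - eps) ** 2).mean()
    return w * loss_x0 + (1 - w) * loss_eps
```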
Special thanks to Mason Wang, Gino Chiaranaipanich, Matthew Noto, and Iris Cai for reviewing the post.
Appendix: Experimental Details
We embed 2-D synthetic datasets of star and dinosaur patterns, each consisting of around 2000 points, into higher-dimensional spaces (2, 8, 16, 32, 128, and 512 dimensions) with a projection matrix P. Then we train an x₀-prediction and an ε-prediction diffusion model, each a 5-layer MLP with hidden dimension 256, using exactly the same computational budget and initialization.
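A minimal sketch of the embedding step (the exact construction of P in the original experiments is not specified, so the random-projection choice here is an assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
points_2d = rng.uniform(-1.0, 1.0, size=(2000, 2))   # stand-in for the STAR/DINO point clouds
D = 512                                               # target embedding dimension
P = rng.standard_normal((2, D)) / np.sqrt(D)          # projection matrix P (2 x D)
points_highdim = points_2d @ P                        # shape (2000, D): data still lies on a 2-D plane
```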
From these figures, it is apparent that x₀-prediction underperforms ε-prediction in low-dimensional settings. However, the opposite holds in high dimensions: x₀-prediction significantly outperforms ε-prediction under the same compute budget. In fact, ε-prediction cannot learn any meaningful structure at extremely high dimensions, even with maximal compute.
Reverse-sampled output with x₀-prediction and ε-prediction models. Comparison of x₀-prediction vs. ε-prediction reverse sampling in low-dimensional and high-dimensional embeddings of the DINO dataset. (Clean data is shown on the left.)
Reverse-sampled output with x₀-prediction and ε-prediction models. Comparison of x₀-prediction vs. ε-prediction reverse sampling in low-dimensional and high-dimensional embeddings of the STAR dataset. (Clean data is shown on the left.)
The reverse-sampled output of the ε-prediction model after 100k epochs. The ε-prediction model fails to learn anything.
The loss curve of the ε-prediction model trained on the STAR dataset embedded in 512-dimensional space. The training loss of ε-prediction plateaus at 0.8 after 100k epochs.