PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS

https://readpaper.com/pdf-annotate/note?pdfId=4667185955594059777

What is the core problem?

Generating a sample with DDPM takes T sampling steps, and T is usually large, so generation is too slow.

Related work (Chapter 6)

Distillation of DDIM teacher models into one-step student models:

  • Eric Luhman and Troy Luhman. Knowledge distillation in iterative generative models for improved sampling speed. arXiv preprint arXiv:2101.02388, 2021.

Few-step sampling with DDIM, the probability flow sampler, and other fast samplers:

  • DDIM (Song et al., 2021a) was originally shown to be effective for few-step sampling, as was the probability flow sampler (Song et al., 2021c).
  • Jolicoeur-Martineau et al. (2021) study fast SDE integrators for reverse diffusion processes.
  • Tzen & Raginsky (2019b) study unbiased samplers which may be useful for fast, high quality sampling as well.

Other work on fast sampling can be viewed as manual or automated methods to adjust samplers or diffusion processes for fast generation.

  • Nichol & Dhariwal (2021); Kong & Ping (2021) describe methods to adjust a discrete time diffusion model trained on many timesteps into models that can sample in few timesteps.
  • Watson et al. (2021) describe a dynamic programming algorithm to reduce the number of timesteps for a diffusion model in a way that is optimal for log likelihood.
  • Chen et al. (2021); Saharia et al. (2021); Ho et al. (2021) train diffusion models over continuous noise levels and tune samplers post training by adjusting the noise levels of a few-step discrete time reverse diffusion process. Their method is effective in highly conditioned settings such as text-to-speech and image super-resolution.
  • San-Roman et al. (2021) train a new network to estimate the noise level of noisy data and show how to use this estimate to speed up sampling.

Alternative specifications of the diffusion model can also lend themselves to fast sampling:

  • modified forward and reverse processes (Nachmani et al., 2021; Lam et al., 2021)
  • training diffusion models in latent space (Vahdat et al., 2021).

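What are the core contributions?

Speeding up the DDPM generation process.

  1. Proposes new parameterizations of the diffusion model that are more stable when only a few sampling steps are used.
  2. Proposes a knowledge distillation method that converts a sampler needing many iterations into one needing fewer.

What is the rough method?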

d2a026851eebdc98f0b12efdc0b53a2c_1_Figure_1.png

  1. Distill a deterministic ODE sampler into a student with the same model architecture.
  2. At each stage, a “student” model learns to reproduce two adjacent sampling steps of the “teacher” model in a single sampling step.
  3. At the next stage, the “student” model from the previous stage serves as the new “teacher” model.

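✅ Suppose there is a solver that predicts $x_{t-1}$ from $x_t$.
✅ Calling the solver twice takes $x_t$ to $x_{t-2}$. Learning this mapping gives a solver that covers 2 steps in a single call.
✅ The former solver is called the teacher, the latter the student.
✅ The student then becomes the new teacher, and a new student is trained from it.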
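
The teacher/student loop can be sketched with a toy example. This is a minimal illustration, not the paper's implementation. It assumes each sampling step is a scalar affine map `z <- a * z + b`, and it replaces neural-network training with a least-squares fit. The names `run_sampler`, `distill_once`, and `progressive_distillation` are hypothetical.

```python
import numpy as np

def run_sampler(steps, z):
    """Apply a list of per-step affine maps (a, b) in order: z <- a * z + b."""
    for a, b in steps:
        z = a * z + b
    return z

def distill_once(teacher_steps, rng, n_samples=256):
    """Fit one student step to each pair of adjacent teacher steps."""
    student_steps = []
    for i in range(0, len(teacher_steps), 2):
        z = rng.normal(size=n_samples)                   # starting points for this pair
        target = run_sampler(teacher_steps[i:i + 2], z)  # result of two teacher steps
        # Least-squares fit of target ~ a * z + b (toy stand-in for gradient training).
        design = np.stack([z, np.ones_like(z)], axis=1)
        (a, b), *_ = np.linalg.lstsq(design, target, rcond=None)
        student_steps.append((a, b))
    return student_steps

def progressive_distillation(teacher_steps, final_num_steps, seed=0):
    """Halve the number of steps repeatedly until final_num_steps remain."""
    rng = np.random.default_rng(seed)
    steps = list(teacher_steps)
    while len(steps) > final_num_steps:
        steps = distill_once(steps, rng)  # the student becomes the next teacher
    return steps

# Toy 8-step teacher (assumed values), distilled 8 -> 4 -> 2 -> 1 steps.
teacher = [(0.9, 0.05 * k) for k in range(8)]
student = progressive_distillation(teacher, final_num_steps=1)

z0 = np.linspace(-2.0, 2.0, 5)
print(len(student), np.max(np.abs(run_sampler(student, z0) - run_sampler(teacher, z0))))
```

Because two affine steps compose into another affine step, the toy student can match the teacher essentially exactly. A real denoising network can only approximate the two-step target, so some error accumulates at every halving.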

Effectiveness

On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps.
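
As a quick piece of arithmetic, the sketch below counts how many distillation rounds separate the two step counts quoted above. It assumes every round exactly halves the number of sampling steps, and the helper name `halving_schedule` is hypothetical.

```python
def halving_schedule(start_steps, final_steps):
    """List the step counts visited when every round halves the sampler."""
    schedule = [start_steps]
    while schedule[-1] > final_steps:
        if schedule[-1] % 2 != 0:
            raise ValueError("step count is not divisible by 2")
        schedule.append(schedule[-1] // 2)
    if schedule[-1] != final_steps:
        raise ValueError("final_steps is not reachable by repeated halving")
    return schedule

schedule = halving_schedule(8192, 4)
rounds = len(schedule) - 1
print(schedule)  # [8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4]
print(rounds)    # 11
```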

Limitations

Limitations and possible improvements
In the current work we limited ourselves to setups where the student model has the same architecture and number of parameters as the teacher model. In future work we hope to relax this constraint and explore settings where the student model is smaller, potentially enabling further gains in test time computational requirements.
In addition, we hope to move past the generation of images and also explore progressive distillation of diffusion models for different data modalities such as audio (Chen et al., 2021).
In addition to the proposed distillation procedure, some of our progress was realized through different parameterizations of the diffusion model and its training loss. We expect to see more progress in this direction as the community further explores this model class.

Validation

Insights

  1. The resulting target value $\tilde x(z_t)$ is fully determined given the teacher model and starting point $z_t$, which allows the student model to make a sharp prediction when evaluated at $z_t$. In contrast, the original data point $x$ is not fully determined given $z_t$, since multiple different data points $x$ can produce the same noisy data $z_t$. This means that the original denoising model is predicting a weighted average of possible $x$ values, which produces a blurry prediction.
  2. An L2 loss on the noise can be viewed as a weighted L2 reconstruction loss on $x$; see Equation 9 of the paper for the derivation. During distillation, however, predicting the noise is unsuitable, and the model should predict the reconstruction instead.
  3. In practice, the choice of loss weighting also has to take into account how $\alpha_t$, $\sigma_t$ are sampled during training, as this sampling distribution strongly determines the weight the expected loss gives to each signal-to-noise ratio.
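
The link between the noise loss and the reconstruction loss in item 2 can be written out as follows. This is a reconstruction under the usual assumptions that $z_t = \alpha_t x + \sigma_t \epsilon$ and that the noise prediction $\hat\epsilon$ and the data prediction $\hat x$ are related through the same equation. The notation may differ slightly from the paper's.

$$
\begin{aligned}
z_t &= \alpha_t x + \sigma_t \epsilon, \qquad \hat\epsilon = \frac{z_t - \alpha_t \hat x}{\sigma_t}, \\
\lVert \epsilon - \hat\epsilon \rVert_2^2
&= \left\lVert \frac{z_t - \alpha_t x}{\sigma_t} - \frac{z_t - \alpha_t \hat x}{\sigma_t} \right\rVert_2^2
= \frac{\alpha_t^2}{\sigma_t^2} \, \lVert x - \hat x \rVert_2^2 .
\end{aligned}
$$

The weight $\alpha_t^2 / \sigma_t^2$ is the signal-to-noise ratio. It goes to zero as $\alpha_t \to 0$, so the noise loss gives almost no reconstruction signal at the noisiest timesteps. That is one way to see why noise prediction fits poorly with few-step distillation, where those timesteps matter.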

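Open questions

  1. Many details remain unclear to me, for example: the relationship between predicting $x$ and predicting the noise; how the weight is defined; the parameterizations of the denoising diffusion model; DDIM.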
  2. https://caterpillarstudygroup.github.io/ImportantArticles/diffusion-tutorial-part/diffusiontutorialpart1.html
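
As a partial pointer for item 1, the sketch below shows how an $x$-prediction and a noise prediction convert into each other, and how a deterministic DDIM-style update uses them. It assumes $z_t = \alpha_t x + \sigma_t \epsilon$ with a cosine schedule $\alpha_t = \cos(\pi t / 2)$, $\sigma_t = \sin(\pi t / 2)$. The function names are hypothetical, and the true data point stands in for a trained network's prediction.

```python
import numpy as np

def cosine_schedule(t):
    """Assumed schedule: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def eps_from_x(z_t, x_hat, alpha_t, sigma_t):
    """Noise prediction implied by a data prediction."""
    return (z_t - alpha_t * x_hat) / sigma_t

def x_from_eps(z_t, eps_hat, alpha_t, sigma_t):
    """Data prediction implied by a noise prediction (unstable as alpha_t -> 0)."""
    return (z_t - sigma_t * eps_hat) / alpha_t

def ddim_step(z_t, x_hat, t, s):
    """Deterministic update from time t to an earlier time s < t."""
    alpha_t, sigma_t = cosine_schedule(t)
    alpha_s, sigma_s = cosine_schedule(s)
    eps_hat = eps_from_x(z_t, x_hat, alpha_t, sigma_t)
    return alpha_s * x_hat + sigma_s * eps_hat

rng = np.random.default_rng(0)
x = rng.normal(size=4)      # toy "data"
eps = rng.normal(size=4)    # toy noise
t, s = 0.8, 0.4
alpha_t, sigma_t = cosine_schedule(t)
alpha_s, sigma_s = cosine_schedule(s)
z_t = alpha_t * x + sigma_t * eps

# With a perfect prediction x_hat = x, the step lands on the same noise path at time s.
z_s = ddim_step(z_t, x, t, s)
print(np.allclose(z_s, alpha_s * x + sigma_s * eps))
```

The division by `alpha_t` in `x_from_eps` shows why small errors in a noise prediction are amplified at high noise levels, where `alpha_t` is close to zero.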