成像和视觉扩散模型教程

陈士丹利111普渡大学电气与计算机工程学院,西拉斐特,IN 47907。 邮箱: stanchan@purdue.edu

摘要 近年来,生成工具的惊人增长为文本到图像生成和文本到视频生成领域的许多令人兴奋的应用提供了支持。 这些生成工具背后的基本原理是扩散的概念,这是一种特殊的采样机制,克服了先前方法中被认为难以解决的某些缺点。 本教程的目标是讨论扩散模型的基本思想。 本教程的目标受众包括对扩散模型研究或应用这些模型解决其他问题感兴趣的本科生和研究生。

1 基础知识:变分自动编码器 (VAE)

1.1 VAE设置

很久以前,在一个遥远的星系中,我们想要构建一个生成器,从潜在代码生成图像。 最简单(也许是最经典的方法之一)的方法是考虑如下所示的编码器-解码器对。 这称为变分自动编码器 (VAE) [1, 2, 3]

[Uncaptioned image]

自动编码器有一个输入变量 𝐱 和一个潜在变量 𝐳 为了理解这个主题,我们将 𝐱 视为美丽的图像,将 𝐳 视为存在于某些高维空间中的某种向量。

示例 获得图像的潜在表示并不是一件陌生的事情。 回到 JPEG 压缩时代(可以说是一种恐龙),我们使用离散余弦变换 (DCT) 基础 𝝋n 来对图像的底层图像/补丁进行编码。 系数向量𝐳=[z1,,zN]T是通过将面片𝐱投影到基zn=𝝋n,𝐱所跨越的空间上来获得的。 因此,如果您给我们一个图像 𝐱,我们将返回一个系数向量 𝐳 𝐳我们可以进行逆变换来恢复(即解码)图像。 因此,系数向量𝐳是潜在代码。 编码器是DCT变换,解码器是DCT逆变换。[Uncaptioned image]

“变分”这个名字来源于我们使用概率分布来描述𝐱𝐳的因素。 我们并不想采用将𝐱转换为𝐳的确定性程序,而是更感兴趣的是确保分布p(𝐱)可以映射到所需分布p(𝐳),并反向回到p(𝐱) 由于分布设置,我们需要考虑一些分布。

  • p(𝐱)𝐱的分布。 这是永远不知道的。 如果我们知道这一点,我们就会成为亿万富翁。 整个扩散模型家族都是为了找到从p(𝐱)中抽取样本的方法。

  • p(𝐳): 潜变量的分布。 因为我们都很懒,所以让我们把它做成一个零均值单位方差高斯 p(𝐳)=𝒩(0,𝐈)

  • p(𝐳|𝐱): 与编码器相关的条件分布,它告诉我们在给定𝐱𝐳的可能性。 我们无法访问它。 p(𝐳|𝐱)本身不是编码器,但编码器必须做一些事情,使其行为与p(𝐳|𝐱)一致。

  • p(𝐱|𝐳): 与解码器相关的条件分布,它告诉我们在给定𝐳时获得𝐱的后验概率。 同样,我们无法访问它。

上面的四个分布并不算太神秘。 这是一个有点琐碎但有教育意义的例子,可以说明这个想法。 示例 考虑一个随机变量𝐗,它根据高斯混合模型分布,其中潜变量z{1,,K}表示聚类标识,使得pZ(k)=[Z=k]=πk对于k=1,,K 我们假设k=1Kπk=1 然后,如果我们被告知我们只需要查看第 k 个簇,则给定 Z𝐗 的条件分布为 p𝐗|Z(𝐱|k)=𝒩(𝐱|𝝁k,σk2𝐈). 𝐱 的边际分布可以使用总概率定律找到,给我们 p𝐗(𝐱)=k=1Kp𝐗|Z(𝐱|k)pZ(k)=k=1Kπk𝒩(𝐱|𝝁k,σk2𝐈). (1) 因此,如果我们从p𝐗(𝐱)开始,编码器的设计问题是构建一个神奇的编码器,使得对于每个样本𝐱p𝐗(𝐱),潜码将是z{1,,K},其分布为zpZ(k) 为了说明编码器和解码器的工作原理,我们假设均值和方差已知并且是固定的。 否则,我们需要通过 EM 算法来估计均值和方差。 这是可行的,但繁琐的方程将违背本例的目的。 Encoder:我们如何从𝐱获取z 这很简单,因为在编码器中,我们知道 p𝐗(𝐱)pZ(k) 想象一下,您只有两个类 z{1,2} 实际上,您只是对样本 𝐱 应该属于哪里做出二元决定。 有很多方法可以做出二元决策。 如果你喜欢最大后验,你可以检查 pZ|𝐗(1|𝐱)class 2class 1pZ|𝐗(2|𝐱), 这将返回一个简单的决策规则。 您给我们𝐱,我们告诉您z{1,2} 解码器:在解码器端,如果我们得到一个潜在代码 z{1,,K},神奇的解码器只需要返回我们一个从 p𝐗|Z(𝐱|k)=𝒩(𝐱|𝝁k,σk2𝐈) 中抽取的样本 𝐱 不同的 z 将为我们提供 K 混合组件之一。 如果我们有足够的样本,总体分布将遵循高斯混合分布。

像你这样聪明的读者肯定会抱怨:“你的例子太不真实了。”不用担心。 我们明白。 当然,生活比具有已知均值和已知方差的高斯混合模型要困难得多。 但我们意识到的一件事是,如果我们想找到神奇的编码器和解码器,我们必须有一种方法来找到两个条件分布。 不过,他们都是高维生物。 因此,为了让我们说一些更有意义的事情,我们需要强加额外的结构,以便我们可以将概念推广到更困难的问题。

在 VAE 的文献中,人们提出了考虑以下两个代理分布的想法:

  • qϕ(𝐳|𝐱)p(𝐳|𝐱) 的代理。 我们将使其成为高斯分布。 为什么是高斯? 没有特别充分的理由。 也许我们只是普通(又名懒惰)人。

  • p𝜽(𝐱|𝐳)p(𝐱|𝐳) 的代理。 不管你信不信,我们也会把它变成高斯分布。 但是,这个高斯分布的作用与高斯分布 qϕ(𝐳|𝐱) 略有不同。 虽然我们需要 估计 高斯分布 qϕ(𝐳|𝐱) 的均值和方差,但我们不需要为高斯分布 p𝜽(𝐱|𝐳) 估计任何东西。 相反,我们需要一个解码器神经网络将 𝐳 转换为 𝐱 高斯分布 p𝜽(𝐱|𝐳) 将用来告知我们生成的图像 𝐱 的好坏程度。

输入 𝐱 与潜在变量 𝐳 之间的关系,以及条件分布,总结在图 1 中。 有两个节点𝐱𝐳 “正向” 关系由 p(𝐳|𝐱) 指定(并由 qϕ(𝐳|𝐱) 近似),而 “反向” 关系由 p(𝐱|𝐳) 指定(并由 p𝜽(𝐱|𝐳) 近似)。

Refer to caption
图1: 在变分自编码器中,变量 𝐱𝐳 由条件分布 p(𝐱|𝐳)p(𝐳|𝐱) 连接。 为了使事情正常运作,我们分别引入了两个代理分布 p𝜽(𝐱|𝐳)qϕ(𝐳|𝐱)
示例 现在是时候考虑另一个简单的例子了。 假设我们有一个随机变量 𝐱 和一个潜在变量 𝐳 ,这样 𝐱 𝒩(𝐱|μ,σ2), 𝐳 𝒩(𝐳| 0,1). 我们的目标是构建一个 VAE。 (什么?! 这个问题有一个微不足道的解决方案,其中 𝐳=(𝐱μ)/σ𝐱=𝝁+σ𝐳 你是绝对正确的。 但请按照我们的推导来看看VAE框架是否有意义。)[Uncaptioned image] 通过构建 VAE,我们的意思是我们想要构建两个映射“编码”和“解码”。 为了简单起见,我们假设这两个映射都是仿射变换: 𝐳 =encode(𝐱)=a𝐱+b,so thatϕ=[a,b], 𝐱 =decode(𝐳)=c𝐳+d,so that𝜽=[c,d]. 我们太懒了,不想找出联合分布 p(𝐱,𝐳),也不想找出条件分布 p(𝐱|𝐳)p(𝐳|𝐱) 但是我们可以构造代理分布 qϕ(𝐳|𝐱)p𝜽(𝐱|𝐳) 既然我们可以自由地选择qϕp𝜽应该是什么样子,我们考虑以下两个高斯怎么样 qϕ(𝐳|𝐱) =𝒩(𝐳|a𝐱+b,1), p𝜽(𝐱|𝐳) =𝒩(𝐱|c𝐳+d,c). 这两个高斯的选择并不神秘。 For qϕ(𝐳|𝐱): if we are given 𝐱, of course we want the encoder to encode the distribution according to the structure we have chosen. Since the encoder structure is a𝐱+b, the natural choice for qϕ(𝐳|𝐱) is to have the mean a𝐱+b. The variance is chosen as 1 because we know that the encoded sample 𝐳 should be unit-variance. Similarly, for p𝜽(𝐱|𝐳): if we are given 𝐳, the decoder must take the form of c𝐳+d because this is how we setup the decoder. The variance is c which is a parameter we need to figure out. 在继续这个例子之前我们将暂停一会儿。 我们想介绍一种数学工具。

1.2 证据下界

我们如何使用这两个代理分布来实现我们确定编码器和解码器的目标? If we treat ϕ and 𝜽 as optimization variables, then we need an objective function (or the loss function) so that we can optimize ϕ and 𝜽 through training samples. To this end, we need to set up a loss function in terms of ϕ and 𝜽. 我们在这里使用的损失函数称为证据下限 (ELBO) [1] ELBO(𝐱)=def𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)qϕ(𝐳|𝐱)]. (2) 你一定很疑惑地球人怎么能想出这个损失函数!? 让我们看看 ELBO 是什么意思以及它是如何衍生的。

In a nutshell, ELBO is a lower bound for the prior distribution logp(𝐱) because we can show that

logp(𝐱)=some magical steps =𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)qϕ(𝐳|𝐱)]+𝔻KL(qϕ(𝐳|𝐱)p(𝐳|𝐱)) (3)
𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)qϕ(𝐳|𝐱)]
=defELBO(𝐱),

其中不等式源于 KL 散度始终为非负这一事实。 Therefore, ELBO is a valid lower bound for logp(𝐱). Since we never have access to logp(𝐱), if we somehow have access to ELBO and if ELBO is a good lower bound, then we can effectively maximize ELBO to achieve the goal of maximizing logp(𝐱) which is the gold standard. 现在的问题是下限有多好。 As you can see from the equation and also Figure 2, the inequality will become an equality when our proxy qϕ(𝐳|𝐱) can match the true distribution p(𝐳|𝐱) exactly. So, part of the game is to ensure qϕ(𝐳|𝐱) is close to p(𝐳|𝐱).

Refer to caption
图2: Visualization of logp(𝐱). The gap between the two is determined by the KL divergence 𝔻KL(qϕ(𝐳|𝐱)p(𝐳|𝐱)).
方程证明 (3) 这里的全部诀窍是利用我们神奇的代理 qϕ(𝐳|𝐱)p(𝐱) 中四处探查,并推导出界限。 logp(𝐱) =logp(𝐱)×qϕ(𝐳|𝐱)𝑑𝐳=1 multiply 1 =logp(𝐱)some constant wrt 𝐳×qϕ(𝐳|𝐱)distribution in 𝐳𝑑𝐳 move logp(𝐱) into integral =𝔼qϕ(𝐳|𝐱)[logp(𝐱)], (4) 其中最后一个等式是一个有趣的结论,即对于任何随机变量 Z 和一个标量 aa×pZ(z)𝑑z=𝔼[a] 成立。 当然,𝔼[a]=a 看,我们已经获得了 𝔼qϕ(𝐳|𝐱)[] 只需再执行几步即可。 让我们使用贝叶斯定理,它指出 p(𝐱,𝐳)=p(𝐳|𝐱)p(𝐱) 𝔼qϕ(𝐳|𝐱)[logp(𝐱)] =𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)p(𝐳|𝐱)] Bayes Theorem =𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)p(𝐳|𝐱)×qϕ(𝐳|𝐱)qϕ(𝐳|𝐱)] Multiply and divide qϕ(𝐳|𝐱) =𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)qϕ(𝐳|𝐱)]ELBO+𝔼qϕ(𝐳|𝐱)[logqϕ(𝐳|𝐱)p(𝐳|𝐱)]𝔻KL(qϕ(𝐳|𝐱)p(𝐳|𝐱)), (5) 我们认识到第一项正是 ELBO,而第二项正是 KL 散度。 将方程 (5) 与方程 (3) 进行比较,我们知道生活是美好的。

我们现在有ELBO。 但这个 ELBO 仍然不太有用,因为它涉及 p(𝐱,𝐳),而我们无法访问它。 所以,我们还需要做一些事情。 让我们仔细看看 ELBO

ELBO(𝐱) =def𝔼qϕ(𝐳|𝐱)[logp(𝐱,𝐳)qϕ(𝐳|𝐱)] definition
=𝔼qϕ(𝐳|𝐱)[logp(𝐱|𝐳)p(𝐳)qϕ(𝐳|𝐱)] p(𝐱,𝐳)=p(𝐱|𝐳)p(𝐳)
=𝔼qϕ(𝐳|𝐱)[logp(𝐱|𝐳)]+𝔼qϕ(𝐳|𝐱)[logp(𝐳)qϕ(𝐳|𝐱)] split expectation
=𝔼qϕ(𝐳|𝐱)[logp𝜽(𝐱|𝐳)]𝔻KL(qϕ(𝐳|𝐱)p(𝐳)), definition of KL

其中我们暗中用其代理 p𝜽(𝐱|𝐳) 替换了不可访问的 p(𝐱|𝐳) 这是一个漂亮的结果。 我们刚刚展示了一些非常容易理解的东西。 ELBO(𝐱)=𝔼qϕ(𝐳|𝐱)[logp𝜽(𝐱|𝐳)a Gaussian]how good your decoder is𝔻KL(qϕ(𝐳|𝐱)a Gaussianp(𝐳)a Gaussian)how good your encoder is. (6) 方程 (6) 中有两项:

  • 重建 第一项是关于解码器的。 如果我们将潜在的𝐳输入解码器(当然!!),我们希望解码器能够生成良好的图像𝐱 所以,我们想要 最大化 logp𝜽(𝐱|𝐳) 它类似于最大似然,我们想要找到模型参数以最大化观察图像的可能性。 这里的期望是针对样本 𝐳 得出的(以 𝐱 为条件)。 这不足为奇,因为样本 𝐳 用于评估解码器的质量。 它不能是任意的噪声向量,而是有意义的潜在向量。 所以,𝐳 需要从 qϕ(𝐳|𝐱) 中采样。

  • 先前匹配 第二项是编码器的 KL 散度。 我们希望编码器将 𝐱 转换为一个潜在向量 𝐳,使得该潜在向量遵循我们选择的(懒惰)分布 𝒩(0,𝐈) 为了更一般化,我们将 p(𝐳) 写为目标分布。 因为KL是一个距离(当两个分布变得更加不相似时,它会增加),所以我们需要在前面加上一个负号,以便当两个分布变得更加相似时,它会增加。

示例 让我们继续我们的简单高斯示例。 从之前的推导我们知道 qϕ(𝐳|𝐱) =𝒩(𝐳|a𝐱+b,1), p𝜽(𝐱|𝐳) =𝒩(𝐱|c𝐳+d,c). 为了确定 𝜽ϕ,我们需要最小化先验匹配误差并最大化重建项。 对于先验匹配,我们知道 𝔻KL(qϕ(𝐳|𝐱)p(𝐳))=𝔻KL(𝒩(𝐳|a𝐱+b,1)𝒩(𝐳| 0,1)). 由于 𝔼[𝐱]=μVar[𝐱]=σ2,当 a=1σb=μσ 时,KL 散度被最小化,以便 a𝐱+b=𝐱μσ 因此,𝔼[a𝐱+b]=0Var[a𝐱+b]=1 对于重建项,我们知道 𝔼qϕ(𝐳|𝐱)[logp𝜽(𝐱|𝐳)]=𝔼qϕ(𝐳|𝐱)[(c𝐳+dμ)22c2]. 由于 𝔼[𝐳]=0Var[𝐳]=1,因此当 c=σd=μ 时,该项被最大化。 总而言之,编码器和解码器参数是 𝐳 =encode(𝐱)=𝐱μσ, 𝐱 =decode(𝐳)=σ𝐳+μ, 这很容易理解。

重建项和先验匹配项在图 3 中进行了说明。 在这两种情况下,以及在训练过程中,我们假设我们都可以访问 𝐳𝐱,其中 𝐳 需要从 qϕ(𝐳|𝐱) 中采样。 然后,为了重建,我们估计 𝜽 以最大化 p𝜽(𝐱|𝐳) 为了先验匹配,我们找到 ϕ 以最小化 KL 散度。 优化可能具有挑战性,因为如果您更新 ϕ,则分布 qϕ(𝐳|𝐱) 将会改变。

Refer to caption
图3: 解释变分自动编码器的 ELBO 中的重建项和先验匹配项。

1.3训练VAE

现在我们了解了 ELBO 的含义,我们可以讨论如何训练 VAE。 为了训练 VAE,我们需要地面实况对 (𝐱,𝐳) 我们知道如何得到𝐱;它只是数据集中的图像。 但相应地 𝐳 应该是什么?

我们来谈谈编码器 我们知道 𝐳 是从分布 qϕ(𝐳|𝐱) 生成的。 我们也知道 qϕ(𝐳|𝐱) 是一个高斯分布。 假设此高斯分布具有均值 𝝁 和协方差矩阵 σ2𝐈(哈哈! 我们的懒惰又来了! 我们不使用一般的协方差矩阵,而是假设方差相等)。

棘手的部分是如何从输入图像𝐱中确定𝝁σ2 好吧,如果你没有线索,别担心。 欢迎来到原力的黑暗面。 我们构建一个深度神经网络,使得

𝝁 =𝝁ϕneural network(𝐱)
σ2 =σϕ2neural network(𝐱),

因此,样本𝐳()(其中表示训练集中的第个训练样本)可以从高斯分布中采样

𝐳()𝒩(𝐳|𝝁ϕ(𝐱()),σϕ2(𝐱())𝐈)qϕ(𝐳|𝐱()),where 𝝁ϕ,σϕ2 are functions of 𝐱. (7)

该想法总结在图 4 中,我们使用神经网络来估计高斯参数,并从高斯分布中抽取样本。 请注意,𝝁ϕ(𝐱())σϕ2(𝐱())𝐱() 的函数。 因此,对于不同的 𝐱() 我们将有不同的高斯。

Refer to caption
图4: VAE 编码器的实现。 我们使用神经网络来获取图像 𝐱 并估计高斯分布的均值 𝝁ϕ 和方差 σϕ2
备注 对于任何高维高斯分布 𝐱𝒩(𝐱|𝝁,𝚺),采样过程可以通过白噪声的变换来完成 𝐱=𝝁+𝚺12𝐰, (8) 其中 𝐰𝒩(0,𝐈) 半矩阵𝚺12可以通过特征分解或Cholesky分解得到。 对于对角矩阵 𝚺=σ2𝐈,以上公式简化为 𝐱=𝝁+σ𝐰,where𝐰𝒩(0,𝐈). (9)

我们来谈谈解码器 解码器是通过神经网络实现的。 为了符号简单起见,我们将其定义为 decode𝜽,其中 𝜽 表示网络参数。 解码器网络的工作是获取潜在变量𝐳并生成图像𝐱^

𝐱^=decode𝜽(𝐳). (10)

现在让我们再做一个(疯狂的)假设,解码图像 𝐱^ 和地面真实图像 𝐱 之间的误差是高斯的。 (等等,又是高斯?!) 我们假设

(𝐱^𝐱)𝒩(0,σdec2),for some σdec2.

然后,可以得出分布 p𝜽(𝐱|𝐳)

logp𝜽(𝐱|𝐳) =log𝒩(𝐱|decode𝜽(𝐳),σdec2𝐈)
=log1(2πσdec2)Dexp{𝐱decode𝜽(𝐳)22σdec2}
=𝐱decode𝜽(𝐳)22σdec2log(2πσdec2)Dyou can ignore this term, (11)

其中 D𝐱 的尺寸。 该方程表明 ELBO 中似然项的最大化实际上就是解码图像和地面实况之间的 2 损失。 该想法如图 5 所示。

Refer to caption
图5: VAE解码器的实现。 我们使用神经网络获取潜在向量 𝐳 并生成图像 𝐱^ 如果我们假设高斯分布,对数似然将为我们提供一个二次方程。

1.4损失函数

一旦理解了编码器和解码器的结构,损失函数就很容易理解了。 我们通过蒙特卡罗模拟来近似期望:

𝔼qϕ(𝐳|𝐱)[logp𝜽(𝐱|𝐳)]1L=1Llogp𝜽(𝐱|𝐳()),𝐳()qϕ(𝐳|𝐱()),

其中 𝐱() 是训练集中第 个样本, 𝐳()𝐳()qϕ(𝐳|𝐱()) 中采样。 分布 q𝜽qϕ(𝐳|𝐱())=𝒩(𝐳|𝝁ϕ(𝐱()),σϕ2(𝐱())𝐈)

VAE 训练损失 argmaxϕ,𝜽{1L=1Llogp𝜽(𝐱()|𝐳())𝔻KL(qϕ(𝐳|𝐱())p(𝐳))}, (12) 其中 {𝐱()}=1L 是训练数据集中真实的图像, 𝐳() 从公式 (7) 中采样。

KL 散度项中的 𝐳 不依赖于 ,因为我们正在测量两个分布之间的 KL 散度。 这里的变量 𝐳 是一个虚拟变量。

我们需要澄清的最后一件事是 KL 散度。 由于 qϕ(𝐳|𝐱())=𝒩(𝐳|𝝁ϕ(𝐱()),σϕ2(𝐱())𝐈)p(𝐳)=𝒩(0,𝐈),我们实际上正在处理两个高斯分布。 如果你访问维基百科,你会发现两个 d 维高斯分布 𝒩(𝝁0,𝚺0)𝒩(𝝁1,𝚺1) 的 KL 散度为

𝔻KL(𝒩(𝝁0,𝚺0),𝒩(𝝁1,𝚺1))=12(Tr(𝚺11𝚺0)d+(𝝁1𝝁0)T𝚺11(𝝁1𝝁0)+logdet𝚺1det𝚺0). (13)

通过考虑 𝝁0=𝝁ϕ(𝐱())𝚺0=σϕ2(𝐱())𝐈𝝁1=0𝚺1=𝐈,将我们的分布代入公式,我们可以证明 KL 散度具有解析表达式

𝔻KL(qϕ(𝐳|𝐱())p(𝐳))=12((σϕ2(𝐱()))d+𝝁ϕ(𝐱())T𝝁ϕ(𝐱())dlog(σϕ2(𝐱()))), (14)

其中 d 是向量 𝐳 的维度。 因此,整体损失函数公式 (12) 是可微的。 因此,我们可以通过反向传播梯度来端到端训练编码器和解码器。

1.5 使用 VAE 进行推理

对于推理,我们可以简单地将一个潜在向量 𝐳(从 p(𝐳)=𝒩(0,𝐈) 中采样)输入解码器 decode𝜽 并获得图像 𝐱。 就是这样;参见图 6

Refer to caption
图6: 使用 VAE 生成图像就像通过解码器发送潜在噪声代码 𝐳 一样简单。

恭喜! 我们完了。 这就是 VAE 的全部内容。

如果您想阅读更多内容,我们强烈推荐 Kingma 和 Welling [1] 编写的教程。 可以在 [2] 找到较短的教程。 如果您在 Google 中输入 VAE 教程 PyTorch,您将能够找到数百甚至数千个编程教程和视频。

2去噪扩散概率模型(DDPM)

在本节中,我们将讨论 Ho 等人[4]的 DDPM。 如果您对网上成千上万的教程感到困惑,请放心,DDPM 并没有那么复杂。 您只需要了解以下摘要即可:

扩散模型是增量更新,其中整体的组装为我们提供了编码器-解码器结构。 从一种状态到另一种状态的转变是通过降噪器实现的。

为什么要增量? 就像巨轮改变方向一样。 你需要慢慢地将船转向你想要的方向,否则你将失去控制。 同样的原则也适用于你的生活、你的公司人力资源、你的大学管理、你的配偶、你的孩子以及你生活中的任何事情。 “一次弯曲一英寸!” (图片来源:Sergio Goma,他在 Electronic Imaging 2023 上发表了此评论。)

扩散模型的结构如下所示。 它称为变分扩散模型[5] 变分扩散模型具有一系列状态𝐱0,𝐱1,,𝐱T

  • 𝐱0:为原始图像,与VAE中的𝐱相同。

  • 𝐱T:是潜在变量,与VAE中的𝐳相同。 由于我们都很懒,所以我们想要 𝐱T𝒩(0,𝐈)

  • 𝐱1,,𝐱T1:它们是中间状态。 它们也是潜在变量,但它们不是白高斯变量。

变分扩散模型的结构如图 7 所示。 前向和反向路径类似于单步变分自动编码器的路径。 不同之处在于编码器和解码器具有相同的输入输出维度。 所有正向构建块的组装将为我们提供编码器,所有反向构建块的组装将为我们提供解码器。

Refer to caption
图7: 变分扩散模型。 在该模型中,输入图像为𝐱0,白噪声为𝐱T 中间变量(或状态)𝐱1,,𝐱T1是潜在变量。 𝐱t1𝐱t的转换类似于VAE中的前向步骤(编码器),而从𝐱t𝐱t1的转换是类似于 VAE 中的反向步骤(解码器)。 但请注意,这里编码器/解码器的输入维度和输出维度是相同的。

2.1构建块

转换块t转换块由三个状态𝐱t1𝐱t𝐱t+1组成。 如图 8 所示,有两种可能路径到达状态 𝐱t

  • 𝐱t1𝐱t 的前向过渡。 相关的转移分布为 p(𝐱t|𝐱t1) 简单的说,如果你告诉我们 𝐱t1,我们可以根据 p(𝐱t|𝐱t1) 告诉你 𝐱t 然而,就像 VAE 一样,转移分布 p(𝐱t|𝐱t1) 永远无法访问。 但这没关系。 像我们这样懒惰的人只会用高斯分布 qϕ(𝐱t|𝐱t1) 来近似它。 我们将在后面讨论 qϕ 的精确形式,但它只是某个高斯分布。

  • 反向转换从 𝐱t+1𝐱t 再说一遍,我们永远无法知道 p(𝐱t+1|𝐱t),但没关系。 我们只是用另一个高斯分布 p𝜽(𝐱t+1|𝐱t) 来近似真实分布,但它的均值需要由神经网络估计。

Refer to caption
图8: 变分扩散模型的过渡块由三个节点组成。 转移分布 p(𝐱t|𝐱t+1)p(𝐱t|𝐱t1) 无法访问,但我们可以用高斯分布来近似它们。

初始块 变分扩散模型的初始块关注状态𝐱0 由于我们研究的所有问题都是从𝐱0开始的,所以只有从𝐱1𝐱0的反向过渡,而没有从𝐱1开始的过程到𝐱0 因此,我们只需要担心 p(𝐱0|𝐱1) 由于 p(𝐱0|𝐱1) 永远无法访问,我们用高斯分布 p𝜽(𝐱0|𝐱1) 来近似它,其中均值通过神经网络计算。 9 说明了这一点。

Refer to caption
图9: 变分扩散模型的初始块集中于节点𝐱0 由于时间 t=0 之前没有状态,因此我们只有从 𝐱1𝐱0 的反向转换。

最后一个区块 最后一个块重点关注状态𝐱T 请记住, 𝐱T 应该是我们的最终潜在变量,它是高斯白噪声向量。 因为它是最后一个块,所以只有从 𝐱T1𝐱T 的前向过渡,没有诸如 𝐱T+1𝐱T 之类的内容。 前向转移由 qϕ(𝐱T|𝐱T1) 近似,它是一个高斯分布。 10 说明了这一点。

Refer to caption
图10: 变分扩散模型的最后一个块集中于节点𝐱T 由于时间 t=T 之后没有状态,因此我们只有从 𝐱T1𝐱T 的前向转换。

了解转换分布 在我们继续之前,我们需要稍微绕道一下,谈谈转移分布 qϕ(𝐱t|𝐱t1) 我们知道它是高斯分布的。 但我们仍然需要知道它的正式定义,以及这个定义的起源。

转移分布 qϕ(𝐱t|𝐱t1). 在去噪扩散概率模型中,转移分布 qϕ(𝐱t|𝐱t1) 定义为 qϕ(𝐱t|𝐱t1)=def𝒩(𝐱t|αt𝐱t1,(1αt)𝐈). (15)

换句话说,均值为 αt𝐱t1,方差为 1αt 缩放因子αt的选择是为了确保方差大小被保留,使其不会在多次迭代后爆炸和消失。

示例 让我们考虑高斯混合模型 𝐱0p0(𝐱)=π1𝒩(𝐱|μ1,σ12)+π2𝒩(𝐱|μ2,σ22). 给定转移概率,我们知道 𝐱t=αt𝐱t1+(1αt)ϵ,whereϵ𝒩(0,𝐈). 对于混合模型,不难看出𝐱t的概率分布可以通过t=1,2,,T的算法递归计算: pt(𝐱)= π1𝒩(𝐱|αtμ1,t1,αtσ1,t12+(1αt)) + π2𝒩(𝐱|αtμ2,t1,αtσ2,t12+(1αt)), (16) 其中 μ1,t1t1 处的平均值,μ1,0=μ1 是初始平均值。 同样,σ1,t12t1 处的方差,σ1,02=σ12 是初始方差。 在下图中,我们显示了示例,其中 π1=0.3π2=0.7μ1=2μ2=2σ1=0.2、和σ2=1 所有 t 的速率定义为 αt=0.97。我们绘制不同 t 的概率分布函数。 [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image]
备注 对于那些希望了解我们如何推导出公式 (16) 中混合模型的概率密度的人,我们可以展示一个简单的推导过程。 考虑混合模型 p(𝐱)=k=1Kπk𝒩(𝐱|μk,σk2𝐈)p(𝐱|k). 如果我们考虑一个新的变量 𝐲=α𝐱+1αϵ,其中 ϵ𝒩(0,𝐈),那么 𝐲 的分布可以通过使用全概率定律来推导: p(𝐲) =k=1Kp(𝐲|k)p(k)=k=1Kπkp(𝐲|k). 由于 𝐲|k 是高斯随机变量 𝐱 和另一个高斯随机变量 ϵ 的线性组合,因此和 𝐲 将保持为高斯分布。 平均值是 𝔼[𝐲|k] =α𝔼[𝐱|k]+1α𝔼[ϵ]=αμk Var[𝐲|k] =αVar[𝐱|k]+(1α)Var[ϵ]=ασk2+(1α). 所以,p(𝐲|k)=𝒩(𝐲|αμk,ασk2+(1α)) 这样就完成了推导。

2.2 神奇的标量αt1αt

您可能想知道精灵(去噪扩散的作者)是如何为上述转移概率想出神奇的标量 αt(1αt) 的。 为了揭秘这一点,让我们从两个不相关的标量 ab 开始,我们将转换分布定义为

qϕ(𝐱t|𝐱t1)=𝒩(𝐱t|a𝐱t1,b2𝐈). (17)

这是经验法则: 为什么 αt1αt 我们想要选择 ab 使得当 t 足够大时,𝐱t 的分布将变为 𝒩(0,𝐈) 结果发现答案是a=αb=1α 证明 我们想要显示 a=αb=1α 对于公式 (17) 中所示的分布,等效采样步骤为: 𝐱t=a𝐱t1+bϵt1,whereϵt1𝒩(0,𝐈). (18) 思考一下:如果存在一个随机变量 X𝒩(μ,σ2),从该高斯分布中抽取 X 可以通过定义 X=μ+ση 等效地实现,其中 η𝒩(0,1) 我们可以进行递归来证明 𝐱t =a𝐱t1+bϵt1 =a(a𝐱t2+bϵt2)+bϵt1 (substitute 𝐱t1=a𝐱t2+bϵt2) =a2𝐱t2+abϵt2+bϵt1 (regroup terms ) = =at𝐱0+b[ϵt1+aϵt2+a2ϵt3++at1ϵ0]=def𝐰t. (19) 上面的有限和是独立高斯随机变量的和。 平均向量 𝔼[𝐰t] 仍然为零,因为每个人都有零均值。 协方差矩阵(对于零均值向量)是 Cov[𝐰t]=def𝔼[𝐰t𝐰tT] =b2(Cov(ϵt1)+a2Cov(ϵt2)++(at1)2Cov(ϵ0)) =b2(1+a2+a4++a2(t1))𝐈 =b21a2t11a2𝐈. 正如t,at0对于任何0<a<1 因此,在t=时的极限, limtCov[𝐰t]=b21a2𝐈. 所以,如果我们想要 limtCov[𝐰t]=𝐈(以便 𝐱t 的分布将接近 𝒩(0,𝐈)),那么 b=1a2 现在,如果我们让a=α,那么b=1α 这会给我们 𝐱t=α𝐱t1+1αϵt1. (20) 或等效地,qϕ(𝐱|𝐱t1)=𝒩(𝐱t|α𝐱t1,(1α)𝐈) 如果您更喜欢调度程序,可以将 α 替换为 αt

2.3 分布 qϕ(𝐱t|𝐱0)

通过对神奇标量的理解,我们可以讨论 qϕ(𝐱t|𝐱0) 的分布。 也就是说,如果给定 𝐱0,我们想知道 𝐱t 将如何分配。

条件分布 qϕ(𝐱t|𝐱0) 条件分布 qϕ(𝐱t|𝐱0) 由下式给出: qϕ(𝐱t|𝐱0)=𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈), (21) 其中α¯t=i=1tαi
证明 为了弄清楚为什么会这样,我们可以重新进行递归,但这次我们使用 αt𝐱t1(1αt)𝐈 作为均值和协方差。 这会给我们 𝐱t =αt𝐱t1+1αtϵt1 =αt(αt1𝐱t2+1αt1ϵt2)+1αtϵt1 =αtαt1𝐱t2+αt1αt1ϵt2+1αtϵt1𝐰1. (22) 因此,我们有两个高斯的和。 但由于两个高斯函数的和仍然是高斯函数,我们可以计算它的新协方差(因为平均值仍然为零)。 新的协方差是 𝔼[𝐰1𝐰1T] =[(αt1αt1)2+(1αt)2]𝐈 =[αt(1αt1)+1αt]𝐈=[1αtαt1]𝐈. 回到等式 (22),我们可以证明递归被更新为 𝐱t2 的线性组合和一个噪声向量 ϵt2 𝐱t =αtαt1𝐱t2+1αtαt1ϵt2 =αtαt1αt2𝐱t3+1αtαt1αt2ϵt3 = =i=1tαi𝐱0+1i=1tαiϵ0. (23) 因此,如果我们定义 α¯t=i=1tαi,我们可以证明 𝐱t=α¯t𝐱0+1α¯tϵ0. (24) 换句话说,分布 qϕ(𝐱t|𝐱0) 𝐱tqϕ(𝐱t|𝐱0)=𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈). (25)

新分布 qϕ(𝐱t|𝐱0) 的效用在于它与链 𝐱0𝐱1𝐱T1𝐱T 相比,只有一步正向扩散步骤。 在正向扩散模型的每一步,由于我们已经知道 𝐱0,并且假设所有后续转换都是高斯分布,所以对于任何 t,我们都会立即知道 𝐱t。从图 11 中可以理解这种情况。

Refer to caption
图11: qϕ(𝐱t|𝐱t1)qϕ(𝐱t|𝐱0) 之间的差异。
示例 对于一个高斯混合模型,例如 𝐱p0(𝐱)=k=1Kπk𝒩(𝐱|𝝁k,σk2𝐈),我们可以证明时间 t 处的分布为 pt(𝐱) =k=1Kπk𝒩(𝐱|α¯t𝝁k,(1α¯t)𝐈+α¯tσk2𝐈) (26) =k=1Kπk𝒩(𝐱|αt𝝁k,(1αt)𝐈+αtσk2𝐈),if αt=αso that α¯t=i=1tα=αt.

如果你好奇概率分布 pt 如何随着时间 t 的推移而演变,我们在图 12 中展示了分布的轨迹。 您可以看到,当我们处于 t=0 时,初始分布是两个高斯分布的混合。 当我们按照公式 (26) 中定义的转换进行时,我们可以看到分布逐渐变成单个高斯分布 𝒩(0,1)

Refer to caption
图 12: 高斯混合体的轨迹图,当我们进行转换以将概率分布转换为 𝒩(0,1) 时。

在同一张图中,我们叠加并显示了随机样本 𝐱t 的一些瞬时轨迹作为时间 t 的函数。我们用来生成样本的方程是

𝐱t=αt𝐱t1+1αtϵt1,ϵ𝒩(0,𝐈).

正如你所见,𝐱t 的轨迹或多或少遵循分布 pt(𝐱)

2.4 证据下界

现在我们了解了变分扩散模型的结构,我们可以写下 ELBO 并训练模型。 变分扩散模型的 ELBO 为 ELBOϕ,𝜽(𝐱) =𝔼qϕ(𝐱1|𝐱0)[logp𝜽(𝐱0|𝐱1)how good the initial block is] 𝔼qϕ(𝐱T1|𝐱0)[𝔻KL(qϕ(𝐱T|𝐱T1)p(𝐱T))how good the final block is] t=1T1𝔼qϕ(𝐱t1,𝐱t+1|𝐱0)[𝔻KL(qϕ(𝐱t|𝐱t1)p𝜽(𝐱t|𝐱t+1))how good the transition blocks are]. (27) 我们可以解读一下这个ELBO的含义。 这里的 ELBO 由三个部分组成:

  • 重建 重建项基于初始块。 我们使用对数似然 p𝜽(𝐱0|𝐱1) 来衡量与 p𝜽 相关的深度神经网络从潜变量 𝐱1 中恢复图像 𝐱0 的好坏程度。 期望是针对从 qϕ(𝐱1|𝐱0) 中抽取的样本得出的,它是生成 𝐱1 的分布。 如果你想知道我们为什么要从 qϕ(𝐱1|𝐱0) 中抽取样本,只要想想样本 𝐱1 应该来自哪里。 样本𝐱1并非来自天空。 由于它们是中间潜在变量,因此它们是由前向转换qϕ(𝐱1|𝐱0)创建的。 所以,我们应该从 qϕ(𝐱1|𝐱0) 中生成样本。

  • 先前匹配 先前的匹配项基于最终块。 我们使用 KL 散度来衡量 qϕ(𝐱T|𝐱T1)p(𝐱T) 之间的差异。 第一个分布 qϕ(𝐱T|𝐱T1) 是从 𝐱T1𝐱T 的正向转换。 这就是 𝐱T 的生成方式。 第二个分布是 p(𝐱T) 由于我们的懒惰,p(𝐱T)𝒩(0,𝐈) 我们希望 qϕ(𝐱T|𝐱T1) 尽可能接近 𝒩(0,𝐈) 这里的样本是 𝐱T1,它们是从 qϕ(𝐱T1|𝐱0) 中抽取的,因为 qϕ(𝐱T1|𝐱0) 提供了正向样本生成过程。

  • 一致性 一致性项基于转换块。 有两个方向。 正向转换由分布 qϕ(𝐱t|𝐱t1) 决定,而反向转换由神经网络 p𝜽(𝐱t|𝐱t+1) 决定。 一致性项使用KL散度来衡量偏差。 期望值是相对于从联合分布 qϕ(𝐱t1,𝐱t+1|𝐱0) 中抽取的样本 (𝐱t1,𝐱t+1) 而言的。 哦,qϕ(𝐱t1,𝐱t+1|𝐱0) 是什么呢? 不用担心。 我们很快就会摆脱它。

此时,我们将跳过训练和推理,因为该公式尚未准备好实施。 我们将讨论更多的技巧,然后我们将讨论实现。

方程 (27) 的证明 让我们定义以下符号: 𝐱0:T={𝐱0,,𝐱T} 表示从 t=0t=T 的所有状态变量的集合。 我们还记得先验分布 p(𝐱) 是图像 𝐱0 的分布。 所以它等效于 p(𝐱0) 考虑到这些,我们可以证明 logp(𝐱) =logp(𝐱0) =logp(𝐱0:T)𝑑𝐱1:T Marginalize by integrating over 𝐱1:T =logp(𝐱0:T)qϕ(𝐱1:T|𝐱0)qϕ(𝐱1:T|𝐱0)𝑑𝐱1:T Multiply and divide qϕ(𝐱1:T|𝐱0) =logqϕ(𝐱1:T|𝐱0)[p(𝐱0:T)qϕ(𝐱1:T|𝐱0)]𝑑𝐱1:T Rearrange terms =log𝔼qϕ(𝐱1:T|𝐱0)[p(𝐱0:T)qϕ(𝐱1:T|𝐱0)] Definition of expectation. 现在,我们需要使用 Jensen 不等式,该不等式指出,对于任何随机变量 X 和任何凹函数 f,都有 f(𝔼[X])𝔼[f(X)] 通过识别 f()=log(),我们可以证明 logp(𝐱) =log𝔼qϕ(𝐱1:T|𝐱0)[p(𝐱0:T)qϕ(𝐱1:T|𝐱0)] 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0:T)qϕ(𝐱1:T|𝐱0)] (28) 让我们仔细看看 p(𝐱0:T) 检查图 8,我们注意到如果我们想解耦 p(𝐱0:T),我们应该对 𝐱t1|𝐱t 进行条件化。 这导致: p(𝐱0:T)=p(𝐱T)t=1Tp(𝐱t1|𝐱t)=p(𝐱T)p(𝐱0|𝐱1)t=2Tp(𝐱t1|𝐱t). (29) 至于 qϕ(𝐱1:T|𝐱0),图 8 表明我们需要对 𝐱t|𝐱t1 进行条件化。 但是,由于顺序关系,我们可以写 qϕ(𝐱1:T|𝐱0) =t=1Tqϕ(𝐱t|𝐱t1)=qϕ(𝐱T|𝐱T1)t=1T1qϕ(𝐱t|𝐱t1). (30) 将公式 (29) 和公式 (30) 代回公式 (28),我们可以证明 logp(𝐱) 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0:T)qϕ(𝐱1:T|𝐱0)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)t=2Tp(𝐱t1|𝐱t)qϕ(𝐱T|𝐱T1)t=1T1qϕ(𝐱t|𝐱t1)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)t=1T1p(𝐱t|𝐱t+1)qϕ(𝐱T|𝐱T1)t=1T1qϕ(𝐱t|𝐱t1)] shift t to t+1 =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱T|𝐱T1)]+𝔼qϕ(𝐱1:T|𝐱0)[logt=1T1p(𝐱t|𝐱t+1)qϕ(𝐱t|𝐱t1)] split expectation 上面的第一项可以进一步分解为两个期望 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱T|𝐱T1)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0|𝐱1)]Reconstruction+𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)qϕ(𝐱T|𝐱T1)]Prior Matching. 重建项可以简化为 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0|𝐱1)] =𝔼qϕ(𝐱1|𝐱0)[logp(𝐱0|𝐱1)], 我们使用条件 𝐱1:T|𝐱0 相当于 𝐱1|𝐱0 的事实。 先验匹配项是 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)qϕ(𝐱T|𝐱T1)] =𝔼qϕ(𝐱T,𝐱T1|𝐱0)[logp(𝐱T)qϕ(𝐱T|𝐱T1)] =𝔼qϕ(𝐱T1,𝐱T|𝐱0)[𝔻KL(qϕ(𝐱T|𝐱T1)p(𝐱T))], 其中,我们注意到条件期望可以简化为仅对样本 𝐱T𝐱T1 进行采样,因为 logp(𝐱T)qϕ(𝐱T|𝐱T1) 仅取决于 𝐱T𝐱T1 最后,我们看一下产品术语。 我们可以证明 𝔼qϕ(𝐱1:T|𝐱0)[logt=1T1p(𝐱t|𝐱t+1)qϕ(𝐱t|𝐱t1)] =t=1T1𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱t|𝐱t+1)qϕ(𝐱t|𝐱t1)] =t=1T1𝔼qϕ(𝐱t1,𝐱t,𝐱t+1|𝐱0)[logp(𝐱t|𝐱t+1)qϕ(𝐱t|𝐱t1)] =t=1T1𝔼qϕ(𝐱t1,𝐱t+1|𝐱0)[𝔻KL(qϕ(𝐱t|𝐱t1)p(𝐱t|𝐱t+1))]consistency. 通过将 p(𝐱0|𝐱1) 替换为 p𝜽(𝐱0|𝐱1) 以及 p(𝐱t|𝐱t+1) 替换为 p𝜽(𝐱t|𝐱t+1),我们就完成了。

2.5 重写一致性术语

上述变分扩散模型的噩梦在于我们需要从联合分布 qϕ(𝐱t1,𝐱t+1|𝐱0) 中抽取样本 (𝐱t1,𝐱t+1) 我们不知道 qϕ(𝐱t1,𝐱t+1|𝐱0) 是什么! 嗯,当然,它是高斯分布,但我们仍然需要使用未来的样本𝐱t+1来绘制当前的样本𝐱t 这很奇怪,而且一点也不有趣。

检查一致性项,我们注意到 qϕ(𝐱t|𝐱t1)p𝜽(𝐱t|𝐱t+1) 沿着两个相反的方向移动。 因此,我们不可避免地需要使用𝐱t1𝐱t+1 我们需要问的问题是:我们能否想出一些办法,以便在能够检查一致性的同时不需要处理两个相反的方向?

所以,这是一个称为贝叶斯定理的简单技巧。

q(𝐱t|𝐱t1)=q(𝐱t1|𝐱t)q(𝐱t)q(𝐱t1)condition on 𝐱0q(𝐱t|𝐱t1,𝐱0)=q(𝐱t1|𝐱t,𝐱0)q(𝐱t|𝐱0)q(𝐱t1|𝐱0). (31)

通过改变条件顺序,我们可以通过添加一个额外的条件变量 𝐱0q(𝐱t|𝐱t1,𝐱0) 切换为 q(𝐱t1|𝐱t,𝐱0) 方向 q(𝐱t1|𝐱t,𝐱0) 现在与 p𝜽(𝐱t1|𝐱t) 平行,如图 13 所示。 所以,如果我们想重写一致性项,一个自然的选择是计算 qϕ(𝐱t1|𝐱t,𝐱0)p𝜽(𝐱t1|𝐱t) 之间的 KL 散度。

Refer to caption
图 13: 如果我们考虑公式 (31) 中的贝叶斯定理,我们可以定义一个方向与 p𝜽(𝐱t1|𝐱t) 平行的分布 qϕ(𝐱t1|𝐱t,𝐱0)

如果我们设法进行一些(无聊的)代数推导,我们可以证明 ELBO 现在是: 变分扩散模型的 ELBO 为 ELBOϕ,𝜽(𝐱) =𝔼qϕ(𝐱1|𝐱0)[logp𝜽(𝐱0|𝐱1)same as before]𝔻KL(qϕ(𝐱T|𝐱0)p(𝐱T))new prior matching t=2T𝔼qϕ(𝐱t|𝐱0)[𝔻KL(qϕ(𝐱t1|𝐱t,𝐱0)p𝜽(𝐱t1|𝐱t))new consistency]. (32) 让我们快速做出三种解释:

  • 重建 新的重建期限与之前相同。 我们仍在最大化对数似然。

  • 先前匹配 新的先验匹配简化为 qϕ(𝐱T|𝐱0)p(𝐱T) 之间的 KL 散度。 更改是由于我们现在以 𝐱0 为条件。 因此,无需从 qϕ(𝐱T1|𝐱0) 中抽取样本并进行期望。

  • 一致性 新的一致性术语与之前的一致性术语有两个不同之处。 首先,运行索引tt=2开始,到t=T结束。 以前是从 t=1t=T1 伴随着这一点的是分布匹配,它现在在 qϕ(𝐱t1|𝐱t,𝐱0)p𝜽(𝐱t1|𝐱t) 之间。 因此,与其寻找匹配逆向转换的前向转换,我们使用 qϕ 来构建逆向转换,并用它来匹配 p𝜽

公式 (32) 的证明 我们从公式 (28) 开始,通过证明 logp(𝐱) 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0:T)qϕ(𝐱1:T|𝐱0)] By Eqn (28) =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)t=2Tp(𝐱t1|𝐱t)qϕ(𝐱1|𝐱0)t=2Tqϕ(𝐱t|𝐱t1,𝐱0)] split the chain =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱1|𝐱0)]+𝔼qϕ(𝐱1:T|𝐱0)[logt=2Tp(𝐱t1|𝐱t)qϕ(𝐱t|𝐱t1,𝐱0)] (33) 让我们考虑第二项: t=2Tp(𝐱t1|𝐱t)qϕ(𝐱t|𝐱t1,𝐱0) =t=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)qϕ(𝐱t|𝐱0)qϕ(𝐱t1|𝐱0) Bayes rule, Eqn (31) =t=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)×t=2Tqϕ(𝐱t1|𝐱0)qϕ(𝐱t|𝐱0) Rearrange denominator =t=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)×qϕ(𝐱1|𝐱0)qϕ(𝐱T|𝐱0), Recursion cancels terms 其中最后一个等式使用了对于任何序列 a1,,aT,我们有 t=2Tat1at=a1a2×a2a3××aT1aT=a1aT 的事实。 回到公式 (33),我们可以看到 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱1|𝐱0)]+𝔼qϕ(𝐱1:T|𝐱0)[logt=2Tp(𝐱t1|𝐱t)qϕ(𝐱t|𝐱t1,𝐱0)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱1|𝐱0)+logqϕ(𝐱1|𝐱0)qϕ(𝐱T|𝐱0)]+𝔼qϕ(𝐱1:T|𝐱0)[logt=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱T|𝐱0)]+𝔼qϕ(𝐱1:T|𝐱0)[logt=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)], 其中我们消去了分子和分母中的 qϕ(𝐱1|𝐱0),因为对于任何正常数 abclogab+logbc=logac。 这将给我们 𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)p(𝐱0|𝐱1)qϕ(𝐱T|𝐱0)] =𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱0|𝐱1)]+𝔼qϕ(𝐱1:T|𝐱0)[logp(𝐱T)qϕ(𝐱T|𝐱0)] =𝔼qϕ(𝐱1|𝐱0)[logp(𝐱0|𝐱1)]reconstruction𝔻KL(qϕ(𝐱T|𝐱0)p(𝐱T))prior matching. 最后一项是 𝔼qϕ(𝐱1:T|𝐱0)[logt=2Tp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0)] =t=2𝔼qϕ(𝐱t,𝐱t1|𝐱0)logp(𝐱t1|𝐱t)qϕ(𝐱t1|𝐱t,𝐱0) =t=2𝔼qϕ(𝐱t,𝐱t1|𝐱0)𝔻KL(qϕ(𝐱t1|𝐱t,𝐱0)p(𝐱t1|𝐱t))consistency. 最后,用 p𝜽(𝐱t1|𝐱t) 替换 p(𝐱t1|𝐱t),用 p𝜽(𝐱0|𝐱1) 替换 p(𝐱0|𝐱1) 完毕!

2.6 qϕ(𝐱t1|𝐱t,𝐱0) 的推导

现在我们知道了变分扩散模型的新 ELBO,我们应该花一些时间讨论它的核心组件,即 qϕ(𝐱t1|𝐱t,𝐱0) 简而言之,我们想要展示的是

  • qϕ(𝐱t1|𝐱t,𝐱0) 并不像你想象的那么疯狂。 它仍然是高斯分布。

  • 由于它是高斯分布,因此它完全由均值和协方差来表征。 事实证明

    qϕ(𝐱t1|𝐱t,𝐱0)=𝒩(𝐱t1|𝐱t+𝐱0,𝐈), (34)

    对于下面定义的一些神奇标量

分布 qϕ(𝐱t1|𝐱t,𝐱0) 采用 qϕ(𝐱t1|𝐱t,𝐱0)=𝒩(𝐱t1|𝝁q(𝐱t,𝐱0),𝚺q(t)), (35) 在哪里 𝝁q(𝐱t,𝐱0) =(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱0 (36) 𝚺q(t) =(1αt)(1α¯t1)1α¯t𝐈=defσq2(t)𝐈. (37)

公式 (35) 的有趣之处在于 qϕ(𝐱t1|𝐱t,𝐱0)𝐱t𝐱0 完全刻画 不需要神经网络来估计均值和方差! (您可以将其与需要网络的 VAE 进行比较。) 由于不需要网络,所以实际上没有什么可“学习”的。 如果我们知道 𝐱t𝐱0,则 qϕ(𝐱t1|𝐱t,𝐱0) 会自动确定。 没有猜测,没有估计,什么都没有。

这里的认识很重要。 如果我们看一下一致性项,它是许多 KL 散度项的总和,其中第 t 项是

𝔻KL(qϕ(𝐱t1|𝐱t,𝐱0)nothing to learnp𝜽(𝐱t1|𝐱t)need to do something). (38)

正如我们所说,qϕ(𝐱t1|𝐱t,𝐱0) 与任何事情都无关。 但是我们需要对 p𝜽(𝐱t1|𝐱t) 做些什么,以便我们可以计算 KL 散度。

那么,我们应该做什么呢? 我们知道 qϕ(𝐱t1|𝐱t,𝐱0) 是高斯分布的。 如果我们想快速计算 KL 散度,那么显然我们需要 假设 p𝜽(𝐱t1|𝐱t) 也是高斯分布。 是的,不是开玩笑。 我们没有理由为什么它是高斯分布。 但由于p𝜽是我们可以选择的发行版,我们当然应该选择更容易的发行版。 为此,我们选择

p𝜽(𝐱t1|𝐱t)=𝒩(𝐱t1|𝝁𝜽(𝐱t)neural network,σq2(t)𝐈), (39)

我们假设可以使用神经网络来确定平均向量。 关于方差,我们 选择 方差为 σq2(t) 这与公式 (37) 完全相同 因此,如果我们将公式 (35) 与 p𝜽(𝐱t1|𝐱t) 并排放置,我们会注意到两者之间存在平行关系:

qϕ(𝐱t1|𝐱t,𝐱0) =𝒩(𝐱t1|𝝁q(𝐱t,𝐱0)known,σq2(t)𝐈known), (40)
p𝜽(𝐱t1|𝐱t) =𝒩(𝐱t1|𝝁𝜽(𝐱t)neural network,σq2(t)𝐈known). (41)

因此,KL 散度简化为

𝔻KL(qϕ(𝐱t1|𝐱t,𝐱0)p𝜽(𝐱t1|𝐱t))
=𝔻KL(𝒩(𝐱t1|𝝁q(𝐱t,𝐱0),σq2(t)𝐈)𝒩(𝐱t1|𝝁𝜽(𝐱t),σq2(t)𝐈))
=12σq2(t)𝝁q(𝐱t,𝐱0)𝝁𝜽(𝐱t)2, (42)

其中我们使用了两个同方差高斯函数之间的 KL 散度只是两个均值向量之间的欧几里得距离平方这一事实。

如果我们回到公式 (32) 中 ELBO 的定义,我们可以将其改写为

ELBO𝜽(𝐱) =𝔼q(𝐱1|𝐱0)[logp𝜽(𝐱0|𝐱1)]𝔻KL(q(𝐱T|𝐱0)p(𝐱T))nothing to train
t=2T𝔼q(𝐱t|𝐱0)[12σq2(t)𝝁q(𝐱t,𝐱0)𝝁𝜽(𝐱t)2]. (43)

有一些观察很有趣:

  • 我们删除了所有下标 ϕ,因为只要我们知道 𝐱0q 就完全描述了。 我们只是向每个 𝐱1,,𝐱T 添加(不同级别的)白噪声。 这将为我们提供一个 ELBO,只需要我们对 𝜽 进行优化。

  • 参数 𝜽 是通过网络 𝝁𝜽(𝐱t) 实现的。 它是 𝝁𝜽(𝐱t) 的网络权重。

  • q(𝐱t|𝐱0) 中采样是根据公式 (21) 进行的,该公式指出 q(𝐱t|𝐱0)=𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈)

  • 给定 𝐱tq(𝐱t|𝐱0),我们可以计算 logp𝜽(𝐱0|𝐱1),它只是 log𝒩(𝐱0|𝝁𝜽(𝐱1),σq2(1)𝐈) 因此,只要我们知道 𝐱1,我们就可以将其发送到网络 𝝁𝜽(𝐱1),以返回我们的均值估计。 然后,平均估计将用于计算可能性。

在我们继续之前,让我们通过讨论公式 (35) 是如何确定的来完成这个故事。

方程式 (35) 的证明. 使用方程式 (31) 中陈述的贝叶斯定理,q(𝐱t1|𝐱t,𝐱0) 可以通过评估以下高斯函数的乘积来确定 q(𝐱t1|𝐱t,𝐱0)=𝒩(𝐱t|αt𝐱t1,(1αt)𝐈)𝒩(𝐱t1|α¯t1,(1α¯t1𝐈))𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈). (44) 为简单起见,我们将向量视为标量。 那么上面的高斯乘积将变成 q(𝐱t1|𝐱t,𝐱0)exp{(𝐱tαt𝐱t1)22(1αt)+(𝐱t1α¯t1z)22(1α¯t1)(𝐱tα¯t𝐱0)22(1α¯t)}. (45) 我们考虑以下映射: x =𝐱t, a=αt y =𝐱t1, b=α¯t1 z =𝐱0, c=α¯t. 考虑二次函数 f(y)=(xay)22(1a)+(ybz)22(1b)(xcz)22(1c). (46) 我们知道,无论我们如何重新排列各项,得到的函数仍然是一个二次方程。 f(y) 的最小化器是所得高斯的均值。 因此,我们可以计算 f 的导数并证明 f(y)=1ab(1a)(1b)y(a1ax+b1bz). 设置 f(y)=0 产生 y=(1b)a1abx+(1a)b1abz. (47) 我们注意到 ab=αtα¯t1=α¯t 所以, 𝝁q(𝐱t,𝐱0)=(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱0. (48) 同样,对于方差,我们可以检查曲率 f′′(y) 我们可以很容易地证明 f′′(y)=1ab(1a)(1b)=1α¯t(1αt)(1α¯t1). 取倒数可以得到 𝚺q(t)=(1αt)(1α¯t1)1α¯t𝐈. (49)

2.7训练和推理

方程式 (43) 中的 ELBO 表明,我们需要找到一个网络 𝝁𝜽,它能够以某种方式最小化这种损失:

12σq2(t)𝝁q(𝐱t,𝐱0)known𝝁𝜽(𝐱t)network2. (50)

但“去噪”的概念从何而来?

为了看到这一点,我们从方程式 (36) 中回忆起

𝝁q(𝐱t,𝐱0)=(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱0. (51)

既然𝝁𝜽是我们的设计,我们没有理由不能将它定义为更方便的东西。 所以这里有一个选择:

𝝁𝜽a network(𝐱t)=def(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱^𝜽(𝐱t)another network. (52)

将方程式 (51) 和方程式 (52) 代入方程式 (50) 将得到

12σq2(t)𝝁q(𝐱t,𝐱0)𝝁𝜽(𝐱t)2 =12σq2(t)(1αt)α¯t11α¯t(𝐱^𝜽(𝐱t)𝐱0)2
=12σq2(t)(1αt)2α¯t1(1α¯t)2𝐱^𝜽(𝐱t)𝐱02

因此ELBO可以简化为

ELBO𝜽 =𝔼q(𝐱1|𝐱0)[logp𝜽(𝐱0|𝐱1)]t=2T𝔼q(𝐱t|𝐱0)[12σq2(t)𝝁q(𝐱t,𝐱0)𝝁𝜽(𝐱t)2]
=𝔼q(𝐱1|𝐱0)[logp𝜽(𝐱0|𝐱1)]t=2T𝔼q(𝐱t|𝐱0)[12σq2(t)(1αt)2α¯t1(1α¯t)2𝐱^𝜽(𝐱t)𝐱02]. (53)

第一项是

logp𝜽(𝐱0|𝐱1) =log𝒩(𝐱0|𝝁𝜽(𝐱1),σq2(1)𝐈)12σq2(1)𝝁𝜽(𝐱1)𝐱02 definition
=12σq2(1)(1α¯0)α11α¯1𝐱1+(1α1)α¯01α¯1𝐱^𝜽(𝐱1)𝐱02 recall α0=1
=12σq2(1)(1α1)1α¯1𝐱^𝜽(𝐱1)𝐱02=12σq2(1)𝐱^𝜽(𝐱1)𝐱02 recall α¯1=α1 (54)

将方程 (54) 代入方程 (53) 将简化 ELBO 为

ELBO𝜽=t=1T𝔼q(𝐱t|𝐱0)[12σq2(t)(1αt)2α¯t1(1α¯t)2𝐱^𝜽(𝐱t)𝐱02].

因此,神经网络的训练可以归结为一个简单的损失函数: 去噪扩散概率模型的损失函数 𝜽=argmin𝜽t=1T12σq2(t)(1αt)2α¯t1(1α¯t)2𝔼q(𝐱t|𝐱0)[𝐱^𝜽(𝐱t)𝐱02]. (55)

方程 (55) 中定义的损失函数非常直观。 忽略常量和期望,对于特定的𝐱t,主要感兴趣的主题是

argmin𝜽𝐱^𝜽(𝐱t)𝐱02.

这不过是一个去噪问题,因为我们需要找到一个网络 𝐱^𝜽,使得去噪后的图像 𝐱^𝜽(𝐱t) 将接近于真实值 𝐱0 它不是典型的降噪器的原因是

  • 𝔼q(𝐱t|𝐱0): 我们不是试图对任何随机噪声图像进行去噪。 相反,我们仔细选择噪声图像

    𝐱tq(𝐱t|𝐱0) =𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈)
    =α¯t𝐱0+(1α¯t)𝐳,𝐳𝒩(0,𝐈).

    在这里,“小心”是指仔细控制注入图像的噪声量。

    Refer to caption
    图 14: 正向采样过程。 前向采样过程原本是一个操作链。 然而,如果我们假设高斯分布,那么我们可以将采样过程简化为一步数据生成。
  • 12σq2(t)(1αt)2α¯t1(1α¯t)2: 我们不会对所有步骤的去噪损失进行等权重。 相反,有一个调度程序来控制每个去噪损失的相对重点。 然而,为了简单起见,我们可以放弃这些。 其影响较小。

  • t=1T: 求和可以用均匀分布 tUniform[1,T] 代替。

训练拒绝扩散概率模型 (版本:预测图像)对于训练数据集中的每个图像 𝐱0 重复以下步骤直至收敛。 选择一个随机时间戳 tUniform[1,T] 从样本 𝐱t𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈) 中抽取样本,即 𝐱t=α¯t𝐱0+(1α¯t)𝐳,𝐳𝒩(0,𝐈). 采取梯度下降步骤 𝜽𝐱^𝜽(𝐱t)𝐱02 您可以批量执行此操作,就像训练任何其他神经网络一样。 请注意,在这里,您正在为所有噪声条件训练一个去噪网络𝐱^𝜽
Refer to caption
图 15: 去噪扩散概率模型的训练。 对于同一个神经网络 𝐱^𝜽,我们将噪声输入 𝐱t 发送到网络。 损失的梯度被反向传播以更新网络。 请注意,噪声图像不是任意的。 它们是根据前向采样过程生成的。

一旦降噪器 𝐱^𝜽 训练完毕,我们就可以应用它来进行推理。 推理是关于从状态序列 𝐱T,𝐱T1,,𝐱1 上的分布 p𝜽(𝐱t1|𝐱t) 中采样图像。 由于这是反向扩散过程,我们需要通过以下方式递归地进行:

𝐱t1p𝜽(𝐱t1|𝐱t) =𝒩(𝐱t1|𝝁𝜽(𝐱t),σq2(t)𝐈)
=𝝁𝜽(𝐱t)+σq2(t)𝐳,where𝐳𝒩(0,𝐈)
=(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱^𝜽(𝐱t)+σq(t)𝐳.

这导致了以下推理算法。

拒绝扩散概率模型的推断 (版本:预测图像) 您给我们一个白噪声向量 𝐱T𝒩(0,𝐈) t=T,T1,,1 重复以下操作。 我们使用训练过的去噪器计算 𝐱^𝜽(𝐱t) 更新根据 𝐱t1 =(1α¯t1)αt1α¯t𝐱t+(1αt)α¯t11α¯t𝐱^𝜽(𝐱t)+σq(t)𝐳,𝐳𝒩(0,𝐈). (56)
Refer to caption
图 16: 去噪扩散概率模型的推断。

2.8基于噪声向量的推导

如果您熟悉去噪文献,您可能知道预测噪声而不是信号的残差类型算法。 同样的精神也适用于去噪扩散,我们可以学习预测噪声。 为了了解为什么会出现这种情况,我们考虑方程 (24)。 如果我们重新安排条款,我们将获得

𝐱t=α¯t𝐱0+1α¯tϵ0
𝐱0=𝐱t1α¯tϵ0α¯t.

将此代入 𝝁q(𝐱t,𝐱0),我们可以证明

𝝁q(𝐱t,𝐱0) =αt(1α¯t1)𝐱t+α¯t1(1αt)𝐱01α¯t
=αt(1α¯t1)𝐱t+α¯t1(1αt)𝐱t1α¯tϵ0α¯t1α¯t
=a few more algebraic steps
=1αt𝐱t1αt1α¯tαtϵ0. (57)

因此,如果我们可以设计我们的均值估计器𝝁𝜽,我们就可以自由选择它来匹配以下形式:

𝝁𝜽(𝐱t)=1αt𝐱t1αt1α¯tαtϵ^𝜽(𝐱t). (58)

将公式 (57) 和公式 (58) 代入公式 (50) 将得到一个新的 ELBO

ELBO𝜽=t=1T𝔼q(𝐱t|𝐱0)[12σq2(t)(1αt)2α¯t1(1α¯t)2ϵ^𝜽(𝐱t)ϵ02].

因此,如果你给我们 𝐱t,我们会返回一个预测的噪声 ϵ^𝜽(𝐱t) 这将为我们提供替代的训练方案 训练拒绝扩散概率模型(版本预测噪声)。 对于训练数据集中的每个图像 𝐱0 重复以下步骤直至收敛。 随机选择一个时间戳 tUniform[1,T] 抽取一个样本 𝐱t𝒩(𝐱t|α¯t𝐱0,(1α¯t)𝐈),即 𝐱t=α¯t𝐱0+(1α¯t)𝐳,𝐳𝒩(0,𝐈). 采取梯度下降步骤 𝜽ϵ^𝜽(𝐱t)ϵ02 因此,推理步骤可以通过

𝐱t1p𝜽(𝐱t1|𝐱t) =𝒩(𝐱t1|𝝁𝜽(𝐱t),σq2(t)𝐈)
=𝝁𝜽(𝐱t)+σq2(t)𝐳
=1αt𝐱t1αt1α¯tαtϵ^𝜽(𝐱t)+σq(t)𝐳
=1αt(𝐱t1αt1α¯tϵ^𝜽(𝐱t))+σq(t)𝐳

总结到这里,我们有 拒绝扩散概率模型的推断 (版本预测噪声) 你给我们一个白噪声向量 𝐱T𝒩(0,𝐈) t=T,T1,,1 重复以下操作。 我们使用训练好的去噪器计算 𝐱^𝜽(𝐱t) 更新根据 𝐱t1 =1αt(𝐱t1αt1α¯tϵ^𝜽(𝐱t))+σq(t)𝐳,𝐳𝒩(0,𝐈).

2.9直接去噪 (InDI) 反演

如果我们查看 DDPM 公式,我们会看到更新公式 (56) 采用以下形式:

𝐱t1=(something)𝐱t+(something else)denoise(𝐱t)+noise. (59)

换句话说,(t1) 次估计是三项的线性组合:当前估计 𝐱t、去噪版本 denoise(𝐱t) 和噪声项。 当前的估计和噪声项很容易理解。 但什么是“降噪”? Delbracio 和 Milanfar 发表的一篇有趣的论文[6]从纯去噪的角度研究了生成扩散模型。 事实证明,这种令人惊讶的简单观点在某些方面与其他更先进的扩散模型是一致的。

什么是 denoise(𝐱t) 去噪是一种从噪声图像中去除噪声的通用过程。 在统计信号处理的美好时光中,标准教科书问题是导出白噪声的最佳降噪器。 给定观察模型

𝐲=𝐱+ϵ,whereϵ𝒩(0,𝐈),

你能构建一个估计器 g() 使得均方误差最小化吗?

我们将跳过这个经典问题解的推导,因为你可以在任何概率教科书中找到它,例如[7,第8章] 解决办法是

denoise(𝐲) =argmin𝑔𝔼𝐱,𝐲[g(𝐲)𝐱2]
=some magical step
=𝔼[𝐱|𝐲]. (60)

那么,回到我们的问题:如果我们假设

𝐱t=𝐱t1+ϵt1,whereϵt1𝒩(0,𝐈),

那么显然,降噪器是后验分布的条件期望:

denoise(𝐱t)=𝔼[𝐱t1|𝐱t]. (61)

因此,如果给定分布 p𝜽(𝐱t1|𝐱t),则最佳去噪器只是该分布的条件期望。 这种降噪器称为最小均方误差 (MMSE) 降噪器。 MMSE 降噪器不是“最佳”降噪器;它只是相对于均方误差而言的最佳降噪器。 由于均方误差从来都不是衡量图像质量的良好指标,因此最小化 MSE 并不一定会给我们带来更好的图像。 尽管如此,人们还是喜欢 MMSE 降噪器,因为它们很容易推导。

增量去噪步骤 如果您了解 MMSE 降噪器相当于后验分布的条件期望,您就会欣赏增量降噪。 下面是它的工作原理。 假设我们有一个干净的图像 𝐱0 和一个噪声图像 𝐲 我们的目标是通过一个简单的方程形成 𝐱0𝐲 的线性组合

𝐱t=(1t)𝐱0+t𝐲,0t1. (62)

现在,考虑时间 t 之前的一个小步骤 τ[6] 显示的以下结果提供了一些有用的实用程序: 0τ<t1,并假设 𝐱t=(1t)𝐱0+t𝐲,则成立 𝔼[𝐱tτ|𝐱t]=(1τt)𝐱tcurrent estimate+τt𝔼[𝐱0|𝐱t]denoised. (63) 如果我们将 𝐱^tτ 定义为左侧,用 𝐱^t 替换 𝐱t,并将 𝔼[𝐱0|𝐱t] 写成 denoise(𝐱^t),则上面的等式将变为

𝐱^tτ=(1τt)𝐱^t+τtdenoise(𝐱^t), (64)

其中 τ 是时间的一小步。

等式 (64) 给我们一个 推断步骤。 如果你告诉我们去噪器,并假设你从一个噪声图像 𝐲 开始,那么我们可以迭代地应用等式 (64) 来检索图像 𝐱^t1𝐱^t2,…,𝐱^0

训练 迭代方案的训练需要一个生成 denoise(𝐱t) 的去噪器。 为此,我们可以训练一个神经网络denoise𝜽(其中𝜽表示网络权重):

minimize𝜽𝔼𝐱,𝐲𝔼tuniform[denoise𝜽(𝐱t)𝐱2]. (65)

这里,分布“tuniform”指定时间步t是从给定分布中均匀绘制的。 因此,我们为所有时间步t训练一个降噪器。当您使用数据集中的一对有噪声且干净的图像时,通常会满足期望 (𝐱,𝐲) 训练。 训练后,我们可以通过等式 (64) 进行增量更新。

与去噪分数匹配的连接 尽管我们还没有讨论分数匹配(将在下一节中介绍),但关于上述迭代去噪过程的一个有趣的事实是它与去噪分数匹配有关。 在高层,我们可以将迭代重写为

𝐱tτ =(1τt)𝐱t+τtdenoise(𝐱t)
𝐱tτ𝐱t =τt𝐱t+τtdenoise(𝐱t)
𝐱t𝐱tττ =𝐱tdenoise(𝐱t)t
d𝐱tdt=limτ0𝐱t𝐱tττ =𝐱tdenoise(𝐱t)t

这是一个常微分方程 (ODE)。 如果我们让 𝐱t=𝐱+tϵ 使得 𝐱t 中的噪声水平为 σt2=t2σ2,那么我们可以使用文献中的几个结果来证明

d𝐱tdt =12d(σt2)dt𝐱tlogpt(𝐱t) (ODE defined by Song et al. [8])
=tσ2𝐱tlogpt(𝐱t) (σt=tσ)
tσ2𝐱denoise(𝐱t)t2σ2 (Approximation proposed by Vincent [9])
=𝐱tdenoise(𝐱t)t.

因此,增量去噪迭代相当于去噪分数匹配,至少在 ODE 确定的极限情况下是这样。

添加随机步骤 上述增量去噪迭代可以配备随机扰动。 对于推理步骤,我们可以定义一系列噪声级别{σt| 0t1},并定义

𝐱^tτ=(1τt)𝐱^t+τtdenoise(𝐱^t)+(tτ)σtτ2σt2ϵ,ϵ𝒩(0,𝐈). (66)

作为或训练,人们可以通过以下方式训练降噪器

minimize𝜽𝔼(𝐱,𝐲)𝔼tuniform𝔼ϵ[denoise(𝐱t)𝐱2], (67)

其中 𝐱t=(1t)𝐱+t𝐲+tσtϵ

恭喜! 我们完了。 这就是 DDPM 的全部内容。

DDPM 的文献正在迅速爆炸式增长。 Sohl-Dickstein 等人 [10] 和 Ho 等人 [4] 的原始论文是理解该主题的必读文章。 对于更“用户友好”的版本,我们发现 Luo 的教程非常有用[11] 一些后续工作被高度引用,包括宋等人[12]的去噪扩散隐式模型。 在应用方面,人们已经将DDPM用于各种图像合成应用,例如[13, 14]

3分数匹配 Langevin Dynamics (SMLD)

基于分数的生成模型[8]是根据所需分布生成数据的替代方法。 有几个核心要素:朗之万动力学、(Stein) 评分函数和评分匹配损失。 在本节中,我们将一一探讨这些主题。

3.1朗之万动力

我们讨论的一个有趣的起点是朗之万动力学。 这是一个非常物理学的话题,似乎与生成模型无关。 但请不要担心。 事实上,它们以一种很好的方式相关。

我们不以正确的方式告诉您物理原理,而是讨论如何使用朗之万动力学从分布中抽取样本。 想象一下,我们给定一个分布 p(𝐱),并假设我们想要从 p(𝐱) 中抽取样本。 朗之万动力学是一个迭代过程,允许我们根据以下方程抽取样本。 从已知分布 p(𝐱) 中采样的 朗之万动力学 是一个用于 t=1,,T 的迭代过程: 𝐱t+1=𝐱t+τ𝐱logp(𝐱t)+2τ𝐳,𝐳𝒩(0,𝐈), (68) 其中τ是用户可以控制的步长,𝐱0是白噪声。

你可能想知道,这个神秘的方程式到底是关于什么的? 这是简短而快速的答案。 如果你忽略了末尾的噪声项 2τ𝐳,则等式 (68) 中的朗之万动力学方程实际上是 梯度下降 仔细选择下降方向 𝐱logp(𝐱),使得 𝐱t 会收敛到分布 p(𝐱) 如果您观看任何 YouTube 视频,长达 10 分钟地咕哝朗之万动力学方程,但没有解释它是什么,您可以温和地告诉他们以下内容: 如果没有噪声项,朗之万动力学梯度下降

考虑一个分布 p(𝐱) 一旦定义了模型参数,该分布的形状就被定义并固定。 例如,如果您选择高斯分布,则一旦指定均值和方差,高斯分布的形状和位置就会固定。 p(𝐱) 不过是在数据点 𝐱 处评估的概率密度。 因此,从一个 𝐱 到另一个 𝐱,我们只是从一个值 p(𝐱) 移动到一个不同的值 p(𝐱) 高斯的基本形状没有改变。

假设我们从 d 中的某个任意位置开始。 我们希望将其移至分布的(其中一个)峰值。 峰值是一个特殊的地方,因为它是概率最高的地方。 所以,如果我们说样本 𝐱 是从分布 p(𝐱) 中抽取的,那么 𝐱 的“最佳”位置一定是 p(𝐱) 最大化的位置。 如果 p(𝐱) 有多个局部最小值,任何一个都可以。 所以,很自然地,采样的目标就相当于解决优化问题

𝐱=argmax𝐱logp(𝐱).

我们再次强调,这不是最大似然估计。 在最大似然情况下,数据点𝐱是固定的,但模型参数正在变化。 这里,模型参数是固定的,但数据点是变化的。 下表总结了差异。

Problem Sampling Maximum Likelihood
Optimization target A sample 𝐱 Model parameter 𝜽
Formulation 𝐱=argmax𝐱logp(𝐱;𝜽) 𝜽=argmax𝜽logp(𝐱;𝜽)

优化可以通过多种方式解决。 最便宜的方法当然是梯度下降。 对于 logp(𝐱),我们看到梯度下降步长是

𝐱t+1=𝐱t+τ𝐱logp(𝐱t),

其中 𝐱logp(𝐱t) 表示在 𝐱t 处计算的 logp(𝐱) 的梯度,τ 是步长。 这里我们使用“+”而不是典型的“”,因为我们正在解决最大化问题。

示例 考虑一个高斯分布 p(x)=𝒩(x|μ,σ2),我们可以很容易地证明朗之万动力学方程是 xt+1 =xt+τxlog{12πσ2e(xμ)22σ2}+2τz =xtτxtμσ2+2τz, z𝒩(0,1)
示例 考虑一个高斯混合模型 p(x)=π1𝒩(x|μ1,σ12)+π2𝒩(x|μ2,σ22) 我们可以数值计算 xlogp(x) 为了演示,我们选择π1=0.6 μ1=2σ1=0.5π2=0.4μ2=2σ2=0.2 我们初始化x0=0 我们选择τ=0.05 我们对 T=500 次运行上述梯度下降迭代,并绘制 t=1,,T 的值 p(xt) 的轨迹。 如下图所示,序列 {x1,x2,,xT} 简单地遵循高斯形状并爬到其中一个峰值。 更有趣的是当我们添加噪声项时。 序列 xt 不是在峰值处着陆,而是围绕峰值移动并在峰值附近的某个位置结束。 我们越接近峰值,我们停在那里的可能性就越大。 [Uncaptioned image] [Uncaptioned image] 𝐱t+1=𝐱t+τ𝐱logp(𝐱t) 𝐱t+1=𝐱t+τ𝐱logp(𝐱t)+2τ𝐳

17 显示了样本轨迹的一个有趣的描述。 从任意位置开始,数据点𝐱t将根据朗之万动力学方程进行随机游走。 随机游走的方向并不是完全任意的。 存在一定量的预定义漂移,而每一步都存在一定程度的随机性。 漂移由 𝐱logp(𝐱) 决定,而随机性来自 𝐳

Refer to caption
图 17: 使用朗之万动力学的样品演化轨迹。 我们用不同的颜色对高斯混合的两种模式进行了着色,以便更好地可视化。 这里的设置与上例相同,只是步长为τ=0.001

从上面的例子中我们可以看出,噪声项的加入实际上将梯度下降变成了随机梯度下降 随机梯度下降不是追求确定性最优,而是随机爬上山。 由于我们使用一个常数步长 2τ,最终的解将在峰值附近振荡。 因此,我们可以将朗之万动力学方程总结为 朗之万动力学随机梯度下降 但为什么我们要进行随机梯度下降而不是梯度下降呢? 关键是我们对解决优化问题不感兴趣。 相反,我们更感兴趣的是从分布中采样 通过在梯度下降步骤中引入随机噪声,我们随机选择一个遵循目标函数轨迹但不停留在原处的样本。 如果接近山顶,我们会稍微左右移动。 如果我们远离峰值,梯度方向会将我们拉向峰值。 如果峰值周围的曲率很陡,我们将把大部分稳态点 𝐱T 集中在那里。 如果峰周围的曲率是平坦的,我们就会向四周扩散。 因此,通过在均匀分布的位置重复初始化随机梯度下降算法,我们最终将收集遵循我们指定分布的样本。

示例 考虑一个高斯混合模型 p(x)=π1𝒩(x|μ1,σ12)+π2𝒩(x|μ2,σ22) 我们可以数值计算 xlogp(x) 为了演示,我们选择π1=0.6 μ1=2σ1=0.5π2=0.4μ2=2σ2=0.2 假设我们初始化 M=10000 为均匀分布的样本 x0Uniform[3,3] 我们运行 t=100 步骤的 Langevin 更新。 生成样本的直方图如下图所示。 [Uncaptioned image] [Uncaptioned image] [Uncaptioned image] [Uncaptioned image]
备注:Langevin Dynamics 的起源 朗之万动力学这个名字当然不是源于我们的“黑客”观点。 这要从物理学开始。 考虑基本牛顿方程,它将力 𝐅 与质量 m 和速度 𝐯(t) 联系起来。 牛顿第二定律说的是 𝐅force=mmassd𝐯(t)dtacceleration. (69) 给定力 𝐅,我们也知道它与势能 U(𝐱) 之间的关系为 𝐅force=𝐱U(𝐱)energy. (70) 朗之万动力学的随机性来自于布朗运动。 想象一下,我们有一袋分子在移动。 它们的运动可以根据布朗运动模型来描述: d𝐯(t)dt=λm𝐯(t)+1m𝜼,where𝜼𝒩(0,σ2𝐈). (71) 因此,将公式 (71) 代入公式 (69),并将其与公式 (70) 等同,我们有 𝐱U(𝐱)=λ𝐯(t)+𝜼𝐯(t)=1λ𝐱U(𝐱)+1λ𝜼. 这可以等效地写为 d𝐱dt=1λ𝐱U(𝐱)+σλ𝐳,where𝐳𝒩(0,𝐈). (72) 如果我们令 τ=dtλ 并对上述微分方程进行离散化,我们将得到 𝐱t+1=𝐱tτ𝐱U(𝐱t)+στ𝐳t. (73) 因此,仍然需要确定能源潜力。 对于我们的概率分布函数 p(𝐱),一个非常合理(且懒惰)的选择是具有以下形式的玻尔兹曼分布 p(𝐱)=1Zexp{U(𝐱)}. 因此,立即得出结论 𝐱logp(𝐱)=𝐱{U(𝐱)logZ}=𝐱U(𝐱). (74) 将公式 (74) 代入公式 (73) 会得到 𝐱t+1=𝐱t+τ𝐱logp(𝐱)+στ𝐳 最后,如果我们选择 σ=2/τ(没有特殊原因),我们将获得 𝐱t+1=𝐱t+τ𝐱logp(𝐱t)+2τ𝐳t. (75)

3.2 (Stein 的)评分函数

朗之万动力学方程的第二个部分是梯度 𝐱logp(𝐱) 它有一个正式名称为斯坦因评分函数,表示为

𝐬𝜽(𝐱)=def𝐱logp𝜽(𝐱). (76)

我们应该小心,不要将 Stein 的得分函数与普通得分函数混淆,后者定义为

𝐬𝐱(𝜽)=def𝜽logp𝜽(𝐱). (77)

普通的得分函数是对数似然的梯度(wrt 𝜽)。 相反,Stein 的得分函数是数据点 𝐱 的梯度。 最大似然估计使用普通得分函数,而朗之万​​动力学使用斯坦因得分函数。 然而,由于扩散文献中的大多数人将斯坦因的得分函数称为得分函数,因此我们遵循这种文化。 朗之万动力学中的“得分函数”更准确地称为斯坦因得分函数。

理解分数函数的方法是记住它是相对于数据𝐱的梯度。 对于任何高维分布 p(𝐱),梯度将为我们提供矢量场

𝐱logp(𝐱)=a vector field=[logp(𝐱)x,logp(𝐱)y]T (78)

让我们考虑两个例子。 示例 如果 p(x) 是一个均值为 p(x)=12πσ2e(xμ)22σ2 的高斯分布,那么 s(x)=xlogp(x)=(xμ)σ2.

示例 如果 p(x) 是一个均值为 p(x)=i=1Nπi12πσi2e(xμi)22σi2 的高斯混合分布,那么 s(x)=xlogp(x)=j=1Nπj12πσj2e(xμj)22σj2(xμj)σj2i=1Nπi12πσi2e(xμi)22σi2.

上述两个例子的概率密度函数和相应的评分函数如图 18 所示。

Refer to caption Refer to caption
(a) 𝒩(1,1) (b) 0.6𝒩(2,0.52)+0.4𝒩(2,0.22)
图 18: 评分函数示例

得分函数的几何解释

  • 向量的幅度在 logp(𝐱) 变化最大的地方最强。 因此,在 logp(𝐱) 接近峰值的区域,梯度将非常弱。

  • 矢量场表示数据点在等高线图中的移动方式 19 显示了高斯混合(包含两个高斯分布)的等高线图。 我们画箭头来表示向量场。 现在,如果我们考虑存在于空间中的数据点,朗之万动力学方程基本上会将数据点沿着矢量场指向的方向移动到盆地。

  • 在物理学中,得分函数相当于“漂移”。 这个名字表明扩散粒子应该如何流向最低能量状态。

Refer to caption Refer to caption
(a) vector field of 𝐱logp(𝐱) (b) 𝐱t trajectory
图 19: 得分函数的等高线图,以及两个样本对应的轨迹。

3.3 分数匹配技术

朗之万动力学中最困难的问题是如何获得 𝐱p(𝐱),因为我们无法访问 p(𝐱) 让我们回顾一下(斯坦因的)评分函数的定义

𝐬𝜽(𝐱)=def𝐱p(𝐱), (79)

我们在其中添加下标 𝜽 来表示 𝐬𝜽 将通过网络实现。 由于上式的右边未知,我们需要一些廉价而肮脏的方法来近似它。 在本节中,我们简要讨论两种近似。

显式分数匹配 假设我们有一个数据集𝒳={𝐱1,,𝐱M} 人们提出的解决方案是通过定义分布来考虑经典的核密度估计

q(𝐱)=1Mm=1M1hK(𝐱𝐱mh), (80)

其中 h 只是核函数 K() 的某个超参数,而 𝐱m 是训练集中第 m 个样本。 20 说明了核密度估计的概念。 在左侧的卡通图中,我们展示了以不同数据点 𝐱m 为中心的多个核 K() 所有这些单个核的总和为我们提供了总的核密度估计 q(𝐱) 在右侧,我们显示了真实的直方图和相应的核密度估计。 我们注意到,q(𝐱) 充其量只是对真实数据分布 p(𝐱) 的近似,而真实数据分布永远不会被知道。

Refer to caption Refer to caption
图 20: 核密度估计的图示。

由于 q(𝐱) 是对永远无法访问的 p(𝐱) 的近似,我们可以根据 q(𝐱) 学习 𝐬𝜽(𝐱) 这导致了以下可用于训练网络的损失函数的定义。 显式分数匹配损失是 JESM(𝜽)=def𝔼q(𝐱)𝐬𝜽(𝐱)𝐱logq(𝐱)2 (81) 通过代入核密度估计,我们可以证明损失为

JESM(𝜽) =def𝔼q(𝐱)𝐬𝜽(𝐱)𝐱logq(𝐱)2
=𝐬𝜽(𝐱)𝐱logq(𝐱)2[1Mm=1M1hK(𝐱𝐱mh)]𝑑𝐱
=1Mm=1M𝐬𝜽(𝐱)𝐱logq(𝐱)21hK(𝐱𝐱mh)𝑑𝐱. (82)

因此,我们推导出了一个可用于训练网络的损失函数。 一旦我们训练了网络𝐬𝜽,我们就可以将其替换到朗之万动力学方程中以获得递归:

𝐱t+1=𝐱t+τ𝐬𝜽(𝐱t)+2τ𝐳. (83)

显式分数匹配的问题在于,核密度估计是真实分布的相当差的非参数估计。 特别是当我们的样本数量有限并且样本位于高维空间中时,核密度估计性能可能很差。

去噪分数匹配 考虑到显式分数匹配的潜在缺点,我们现在引入一种更流行的分数匹配,称为去噪分数匹配(DSM)。 在DSM中,损失函数定义如下。

JDSM(𝜽)=def𝔼q(𝐱,𝐱)[12𝐬𝜽(𝐱)𝐱q(𝐱|𝐱)2] (84)

这里的关键区别在于我们将分布 q(𝐱) 替换为条件分布 q(𝐱|𝐱) 前者需要近似值,例如通过核密度估计,而后者则不需要。 这是一个例子。

q(𝐱|𝐱)=𝒩(𝐱|𝐱,σ2) 的特殊情况下,我们可以令 𝐱=𝐱+σ𝐳 这会给我们

𝐱logq(𝐱|𝐱) =𝐱log1(2πσ2)dexp{𝐱𝐱22σ2}
=𝐱{𝐱𝐱22σ2log(2πσ2)d}
=𝐱𝐱σ2=𝐳σ2.

因此,去噪分数匹配的损失函数变为

JDSM(𝜽) =def𝔼q(𝐱,𝐱)[12𝐬𝜽(𝐱)𝐱q(𝐱|𝐱)2]
=𝔼q(𝐱)[12𝐬𝜽(𝐱+σ𝐳)+𝐳σ22].

如果我们将虚拟变量 𝐱 替换为 𝐱,并且注意到当给出训练数据集时,从 q(𝐱) 中采样可以替换为从 p(𝐱) 中采样,我们可以得出以下结论。 去噪分数匹配的损失函数定义为 JDSM(𝜽)=𝔼p(𝐱)[12𝐬𝜽(𝐱+σ𝐳)+𝐳σ22] (85)

等式 (85) 的优点在于它非常易于解释。 𝐱+σ𝐳 实际上是在干净图像 𝐱 上添加噪声 σ𝐳 评分函数 𝐬𝜽 应该获取该噪声图像并预测噪声 𝐳σ2 预测噪声相当于去噪,因为任何去噪图像加上预测噪声都会给我们带来噪声观测结果。 因此,等式 (85) 是一个 去噪 步骤。 21 说明了得分函数 𝐬𝜽(𝐱) 的训练过程。

Refer to caption
图 21: 用于去噪分数匹配的 𝐬𝜽 训练。 网络𝐬𝜽经过训练来估计噪声。

训练步骤可以简单地描述如下:你给我们一个训练数据集{𝐱()}=1L,我们训练一个网络𝜽,目标是

𝜽=argmin𝜽1L=1L12𝐬𝜽(𝐱()+σ𝐳())+𝐳()σ22,where𝐳()𝒩(0,𝐈). (86)

这里更大的问题是为什么等式 (84) 从一开始就说得通。 这需要通过显式分数匹配损失和去噪分数匹配损失之间的等价来回答。

定理 [Vincent [9]] 对于直到与变量𝜽无关的常数C,它成立 JDSM(𝜽)=JESM(𝜽)+C. (87)

显式分数匹配和去噪分数匹配之间的等价性是一个重大发现。 下面的证明基于 Vincent 2011 的原作。

等式 (87) 的证明 我们从显式得分匹配损失函数开始,它由 JESM(𝜽) =𝔼q(𝐱)[12𝐬𝜽(𝐱)𝐱logq(𝐱)2] =𝔼q(𝐱)[12𝐬𝜽(𝐱)2𝐬𝜽(𝐱)T𝐱logq(𝐱)+12𝐱logq(𝐱)2=defC1,independent of 𝜽]. 让我们放大到第二项。 我们可以证明 𝔼q(𝐱)[𝐬𝜽(𝐱)T𝐱logq(𝐱)] =(𝐬𝜽(𝐱)T𝐱logq(𝐱))q(𝐱)𝑑𝐱, (expectation) =(𝐬𝜽(𝐱)T𝐱q(𝐱)q(𝐱))q(𝐱)𝑑𝐱, (gradient) =𝐬𝜽(𝐱)T𝐱q(𝐱)𝑑𝐱. 接下来,我们考虑通过回忆 q(𝐱)=q(𝐱)q(𝐱|𝐱)𝑑𝐱 来进行条件化。 这会给我们 𝐬𝜽(𝐱)T𝐱q(𝐱)𝑑𝐱 =𝐬𝜽(𝐱)T𝐱(q(𝐱)q(𝐱|𝐱)𝑑𝐱)=q(𝐱)d𝐱 (conditional) =𝐬𝜽(𝐱)T(q(𝐱)𝐱q(𝐱|𝐱)𝑑𝐱)𝑑𝐱 (move gradient) =𝐬𝜽(𝐱)T(q(𝐱)𝐱q(𝐱|𝐱)×q(𝐱|𝐱)q(𝐱|𝐱)𝑑𝐱)𝑑𝐱 (multiple and divide) =𝐬𝜽(𝐱)Tq(𝐱)(𝐱q(𝐱|𝐱)q(𝐱|𝐱))=𝐱logq(𝐱|𝐱)q(𝐱|𝐱)𝑑𝐱𝑑𝐱 (rearrange terms) =𝐬𝜽(𝐱)T(q(𝐱)(𝐱logq(𝐱|𝐱))q(𝐱|𝐱)𝑑𝐱)𝑑𝐱 =q(𝐱|𝐱)q(𝐱)=q(𝐱,𝐱)(𝐬𝜽(𝐱)T𝐱logq(𝐱|𝐱))𝑑𝐱𝑑𝐱 (move integration) =𝔼q(𝐱,𝐱)[𝐬𝜽(𝐱)T𝐱logq(𝐱|𝐱)]. 因此,如果我们将这个结果代入 ESM 的定义,我们可以证明 JESM(𝜽)=𝔼q(𝐱)[12𝐬𝜽(𝐱)2]𝔼q(𝐱,𝐱)[𝐬𝜽(𝐱)T𝐱logq(𝐱|𝐱)]+C1. 与 DSM 的定义进行比较,我们可以观察到 JDSM(𝜽) =def𝔼q(𝐱,𝐱)[12𝐬𝜽(𝐱)𝐱q(𝐱|𝐱)2] =𝔼q(𝐱,𝐱)[12𝐬𝜽(𝐱)2𝐬𝜽(𝐱)T𝐱logq(𝐱|𝐱)+12𝐱logq(𝐱|𝐱)2=defC2,independent of 𝜽] =𝔼q(𝐱)[12𝐬𝜽(𝐱)2]𝔼q(𝐱,𝐱)[𝐬𝜽(𝐱)T𝐱logq(𝐱|𝐱)]+C2. 因此,我们得出结论: JDSM(𝜽)=JESM(𝜽)C1+C2.

对于推理,我们假设我们已经训练了分数估计器𝐬𝜽 为了生成图像,我们对 t=1,,T 执行以下过程:

𝐱t+1=𝐱t+τ𝐬𝜽(𝐱t)+2τ𝐳t,where𝐳t𝒩(0,𝐈). (88)

恭喜! 我们完了。 这都是关于基于分数的生成模型。

有关分数匹配的其他阅读材料应从 Vincent 的技术报告 [9] 开始。 最近文献中非常流行的论文是 Song 和 Ermon [15],他们的后续工作 [16][8] 在实践中,训练评分函数需要通过考虑一系列噪声水平来制定噪声表。 当我们在下一节解释方差爆炸 SDE 时,我们将简要讨论这一点。

4 随机微分方程 (SDE)

到目前为止,我们已经通过 DDPM 和 SMLD 视角导出了扩散迭代。 在本节中,我们将从微分方程的角度介绍第三种视角。 为什么我们的迭代方案突然变成复杂的微分方程可能并不明显。 因此,在推导任何方程之前,我们应该简要讨论微分方程与我们有何关系。

4.1 激励示例

示例 1。 简单一阶常微分方程. 想象一下,我们有一个离散时间算法,其迭代由递归定义: 𝐱i=(1βΔt2)𝐱i1,fori=1,2,,N, (89) 给出,其中 β 是超参数,Δt 是步长参数。 这个递归并不复杂:您给我们𝐱i1,我们更新并返回您𝐱i 如果我们假设一个连续时间函数 𝐱(t) 的离散化方案,通过令 𝐱i=𝐱(iN)Δt=1Nt{0,1N,,N1N},那么我们可以将递归重写为 𝐱(t+Δt)=(1βΔt2)𝐱(t). 重新排列条款将给我们 𝐱(t+Δt)𝐱(t)Δt=β2𝐱(t), 其中,当 Δt0 趋于极限时,我们可以将离散方程写成常微分方程 (ODE) d𝐱(t)dt=β2𝐱(t). (90) 不仅如此,我们还可以求解 ODE 的解析解,其解由下式给出 𝐱(t)=eβ2t. (91) 如果你不相信我们,只需将等式 (91) 代入等式 (90),你就可以证明等式成立。 ODE 的强大之处在于它为我们提供了解析解决方案。 解析解不采用迭代方案(这将需要数百到数千次迭代),而是准确地告诉我们解在任何时间t的行为。为了说明这一事实,我们在下图中显示了算法定义的解 𝐱1,𝐱2,,𝐱i,,𝐱N 的轨迹。 这里,我们选择 Δt=0.1 在同一个图中,我们直接绘制任意 t 的连续时间解 𝐱(t)=exp{βt/2}。 如你所见,解析解与迭代方案预测的轨迹完全相同。 [Uncaptioned image]

我们在这个激励人心的例子中观察到两个有趣的事实:

  • 离散时间迭代方案可以写成连续时间常微分方程。 事实证明,对于任何有限差分方程,我们都可以将递归转化为 ODE。

  • 对于简单的 ODE,我们可以写出封闭形式的解析解。 更复杂的 ODE 将很难编写解析解。 但我们仍然可以使用 ODE 工具来分析解的行为。 我们还可以推导出极限解t0

示例 2:梯度下降 回想一下,(表现良好的)凸函数 f 的梯度下降算法是以下递归。 对于 i=1,2,,N,执行 𝐱i=𝐱i1βi1f(𝐱i1), (92) 对于步长参数βi 使用与之前示例相同的离散化方法,我们可以证明(通过令 βi1=β(t)Δt): 𝐱i=𝐱i1βi1f(𝐱i1) 𝐱(t+Δt)=𝐱(t)β(t)Δtf(𝐱(t)) 𝐱(t+Δt)𝐱(t)Δt=β(t)f(𝐱(t)) d𝐱(t)dt=β(t)f(𝐱(t)). (93) 右边所示的常微分方程有一个解轨迹 𝐱(t) 这个 𝐱(t) 被称为函数 f梯度流 为简单起见,我们可以使所有 tβ(t)=β 相同。然后关于这个 ODE 有两个简单的结论。 首先,我们可以证明 ddtf(𝐱(t)) =f(𝐱(t))Td𝐱(t)dt (chain rule) =f(𝐱(t))T[βf(𝐱(t))] (Eqn (93)) =βf(𝐱(t))Tf(𝐱(t)) =βf(𝐱(t))20 (norm-squares). 因此,当我们从 𝐱i1 移动到 𝐱i 时,目标值 f(𝐱(t)) 必须下降。 这与我们的预期一致,因为梯度下降算法应该随着迭代的进行而降低成本。 其次,当 t 趋于极限时,我们知道 d𝐱(t)dt0 因此,d𝐱(t)dt=f(𝐱(t)) 将意味着 f(𝐱(t))0,as t. (94) 因此,解轨迹 𝐱(t) 将逼近函数 f 的最小化点。

向前和向后更新

让我们使用梯度下降示例来说明 ODE 的另一个方面。 回到方程 (92),我们认识到递归可以等效地写成(假设 β(t)=β)):

𝐱i𝐱i1Δ𝐱=βi1βΔtf(𝐱i1)d𝐱=βf(𝐱)dt, (95)

其中连续方程在我们将 Δt0Δ𝐱0 设置为时成立。 关于这个等式有趣的点是它通过用 dt 表示来为我们提供更新 Δ𝐱 的摘要。 它表明,如果我们沿着时间轴移动 dt,那么解 𝐱 将更新为 d𝐱

等式 (95) 定义了 变化 之间的关系。 如果我们考虑一系列迭代 i=1,2,,N,并且如果我们被告知迭代的进程遵循等式 (95),那么我们可以写出

(forward)𝐱i=𝐱i1+Δ𝐱i1 𝐱i1+d𝐱
=𝐱i1f(𝐱i1)βdt
𝐱i1βi1f(𝐱i1).

我们称之为 正向 方程,因为我们通过 𝐱+Δ𝐱 更新 𝐱,假设 tt+Δt

现在,考虑一个迭代序列i=N,N1,,2,1 如果我们被告知迭代的进程遵循等式 (95),那么时间反转迭代将是

(reverse)𝐱i1=𝐱iΔ𝐱i 𝐱i+d𝐱
=𝐱i+βf(𝐱i)dt
𝐱i+βif(𝐱i).

注意反转前进方向时符号的变化。 我们称之为逆向方程。

4.2 SDE 中的前向和后向迭代

扩散微分方程的概念与上面的梯度下降算法相差不远。 如果我们在梯度下降算法中引入噪声项 𝐳t𝒩(0,𝐈),那么 ODE 将变为随机微分方程 (SDE)。 为了看到这一点,我们只需遵循相同的离散化方案,将 𝐱(t) 定义为 0t1 的连续函数。 假设区间内有N个步,则区间[0,1]可以分为序列{iN|i=0,,N1} 离散化将给出我们 𝐱i=𝐱(iN)𝐱i1=𝐱(i1N) 区间步长为 Δt=1N,所有 t 的集合为 t{0,1N,,N1N} 使用这些定义,我们可以写

𝐱i =𝐱i1τf(𝐱i1)+𝐳i1
𝐱(t+Δt) =𝐱(t)τf(𝐱(t))+𝐳(t).

现在,让我们定义一个随机过程 𝐰(t),使得对于非常小的 Δt𝐳(t)=𝐰(t+Δt)𝐰(t)d𝐰(t)dtΔt 在计算中,我们可以通过积分 𝐳(t)(这是一个维纳过程)来生成这样的 𝐰(t) 通过定义 𝐰(t),我们可以写出

𝐱(t+Δt) =𝐱(t)τf(𝐱(t))+𝐳(t)
𝐱(t+Δt)𝐱(t) =τf(𝐱(t))+𝐰(t+Δt)𝐰(t)
d𝐱 =τf(𝐱)dt+d𝐰.

上面的等式揭示了 SDE 的通用形式。 我们总结如下。 前向扩散 d𝐱=𝐟(𝐱,t)driftdt+g(t)diffusiond𝐰. (96)

两项 𝐟(𝐱,t)g(t) 具有物理意义。 阻尼系数是一个向量值函数 𝐟(𝐱,t),定义了在没有随机效应的情况下封闭系统中的分子如何移动。 对于梯度下降算法,漂移由目标函数的负梯度定义。 也就是说,我们希望解轨迹遵循目标的梯度。

扩散系数 g(t) 是一个标量函数,描述了分子如何从一个位置随机走到另一个位置。 函数 g(t) 决定了随机运动的强度。

示例 考虑方程 d𝐱=ad𝐰, 其中a=0.05 迭代方案可以写为 𝐱i𝐱i1=a(𝐰i𝐰i1)=def𝐳i1𝒩(0,𝐈)𝐱i=𝐱i1+a𝐳i. 我们可以如下绘制函数𝐱i 初始点𝐱0=0标记为红色,表示该过程在时间上向前推进。[Uncaptioned image]

备注 如你所见,微分 d𝐰=𝐰i𝐰i1 被定义为维纳过程,它是一个高斯白噪声向量。 个体𝐰i不是高斯分布,但差值𝐰i𝐰i1是高斯分布。

示例 考虑方程 d𝐱=α2𝐱dt+βd𝐰, 其中 α=1β=0.1 这个方程可以写成 𝐱i𝐱i1=α2𝐱i1+β(𝐰i𝐰i1)=def𝐳i1𝒩(0,𝐈)𝐱i=(1α2)𝐱i1+β𝐳i1. 我们可以如下绘制函数𝐱i[Uncaptioned image]

扩散方程的反方向是时间向后移动。 根据Anderson[17],逆时SDE 如下所示。 反向SDE d𝐱=[𝐟(𝐱,t)driftg(t)2𝐱logpt(𝐱)score function]dt+g(t)d𝐰¯reverse-time diffusion, (97) 其中 pt(𝐱)𝐱 在时间 t 的概率分布,而 𝐰¯ 是时间反向流动时的维纳过程。

示例 考虑反向扩散方程 d𝐱=ad𝐰¯. (98) 我们可以将离散时间递归写成如下。 对于 i=N,N1,,1,执行 𝐱i1=𝐱i+a(𝐰i1𝐰i)=𝐳i=𝐱i+a𝐳i,𝐳i𝒩(0,𝐈). 下图中我们展示了这个逆时过程的轨迹。 请注意,红色标记的初始点位于𝐱N 该过程向后追踪到𝐱0[Uncaptioned image]

4.3 DDPM 的随机微分方程

为了绘制 DDPM 和 SDE 之间的联系,我们考虑离散时间 DDPM 迭代。 对于i=1,2,,N

𝐱i=1βi𝐱i1+βi𝐳i1,𝐳i1𝒩(0,𝐈). (99)

我们可以证明这个方程可以从下面的正向 SDE 方程导出。 DDPM 的前向采样方程可以写成 SDE: d𝐱=β(t)2𝐱=𝐟(𝐱,t)dt+β(t)=g(t)d𝐰. (100)

为了说明为什么是这样,我们定义一个步长 Δt=1N,并考虑一个辅助噪声级别 {β¯i}i=1N,其中 βi=β¯iN 然后

βi=β(iN)β¯i1N=β(t+Δt)Δt,

其中我们假设在N中,β¯i=β(t)0t1的连续时间函数。 同样,我们定义

𝐱i=𝐱(iN)=𝐱(t+Δt),𝐳i=𝐳(iN)=𝐳(t+Δt).

因此,我们有

𝐱i =1βi𝐱i1+βi𝐳i1
𝐱i =1β¯iN𝐱i1+β¯iN𝐳i1
𝐱(t+Δt) =1β(t+Δt)Δt𝐱(t)+β(t+Δt)Δt𝐳(t)
𝐱(t+Δt) (112β(t+Δt)Δt)𝐱(t)+β(t+Δt)Δt𝐳(t)
𝐱(t+Δt) 𝐱(t)12β(t)Δt𝐱(t)+β(t)Δt𝐳(t).

因此,当 Δt0 时,我们有

d𝐱=12β(t)𝐱dt+β(t)d𝐰. (101)

因此,我们证明了 DDPM 前向更新迭代可以等效地写为 SDE。

能够将 DDPM 前向更新迭代编写为 SDE 意味着 DDPM 估计可以通过求解 SDE 来确定。 换句话说,对于适当定义的 SDE 求解器,我们可以将 SDE 放入求解器中。 适当选择的求解器返回的解将是 DDPM 估计。 当然,我们不需要使用 SDE 求解器,因为 DDPM 迭代本身正在求解 SDE。 它可能不是最好的 SDE 求解器,因为 DDPM 迭代只是一阶方法。 尽管如此,如果我们对使用 SDE 求解器不感兴趣,我们仍然可以使用 DDPM 迭代来获得解。 这是一个例子。

示例 考虑对于所有 i=0,,N1 具有 βi=0.05 的 DDPM 前向方程。 我们通过从高斯混合中提取样本来初始化样本𝐱0,使得 𝐱0k=1Kπk𝒩(𝐱0|𝝁k,σk2𝐈), 其中 π1=π2=0.5σ1=σ2=1𝝁1=3𝝁2=3 然后,使用方程 𝐱i=1βi𝐱i1+βi𝐳i1,𝐳i1𝒩(0,𝐈), 我们可以绘制轨迹和分布如下。[Uncaptioned image]

通过代入适当的量:𝐟(𝐱,t)=β(t)2g(t)=β(t),可从方程(97) 得出反向扩散方程。 这会给我们

d𝐱 =[𝐟(𝐱,t)g(t)2𝐱logpt(𝐱)]dt+g(t)d𝐰¯
=[β(t)2𝐱β(t)𝐱logpt(𝐱)]dt+β(t)d𝐰¯,

这将为我们提供以下等式: DDPM 的逆采样方程可以写成 SDE: d𝐱=β(t)[𝐱2+𝐱logpt(𝐱)]dt+β(t)d𝐰¯. (102)

通过考虑 d𝐱=𝐱(t)𝐱(tΔt)d𝐰¯=𝐰(tΔt)𝐰(t)=𝐳(t) 可以写出迭代更新方案。 然后,令 dt=Δt,我们可以证明

𝐱(t)𝐱(tΔt) =β(t)Δt[𝐱(t)2+𝐱logpt(𝐱(t))]β(t)Δt𝐳(t)
𝐱(tΔt) =𝐱(t)+β(t)Δt[𝐱(t)2+𝐱logpt(𝐱(t))]+β(t)Δt𝐳(t).

通过将这些项分组,并假设 β(t)Δt1,我们认识到

𝐱(tΔt) =𝐱(t)[1+β(t)Δt2]+β(t)Δt𝐱logpt(𝐱(t))+β(t)Δt𝐳(t)
𝐱(t)[1+β(t)Δt2]+β(t)Δt𝐱logpt(𝐱(t))+(β(t)Δt)22𝐱logpt(𝐱(t))+β(t)Δt𝐳(t)
=[1+β(t)Δt2](𝐱(t)+β(t)Δt𝐱logpt(𝐱(t)))+β(t)Δt𝐳(t).

然后,根据离散化方案,令 t{0,,N1N}Δt=1/N𝐱(tΔt)=𝐱i1𝐱(t)=𝐱iβ(t)Δt=βi,我们可以证明

𝐱i1 =(1+βi2)[𝐱i+βi2𝐱logpi(𝐱i)]+βi𝐳i
11βi[𝐱i+βi2𝐱logpi(𝐱i)]+βi𝐳i, (103)

其中 pi(𝐱) 是在时间 i𝐱 的概率密度函数。 为了实际实现,我们可以用估计的分数函数 𝐬𝜽(𝐱i) 替换 𝐱logpi(𝐱i)

因此,我们恢复了与Song和Ermon在[8]中定义的DDPM迭代一致的DDPM迭代。 这是一个有趣的结果,因为它允许我们使用得分函数连接 DDPM 的迭代。 Song 和 Ermon [8] 将 SDE 称为方差保留 (VP) SDE。

示例 根据前面的示例,我们使用以下命令执行反向扩散方程 𝐱i1=11βi[𝐱i+βi2𝐱logpi(𝐱i)]+βi𝐳i, 其中 𝐳i𝒩(0,𝐈) 迭代的轨迹如下所示。[Uncaptioned image]

4.4 SMLD 的随机微分方程

分数匹配 Langevin Dynamics 模型也可以通过 SDE 来描述。 首先,我们注意到在 SMLD 设置中,并不存在真正的“前向扩散步骤”。 然而,我们可以粗略地认为,如果我们将 SMLD 训练中的噪声尺度划分为 N 级别,那么递归应该遵循马尔可夫链

𝐱i=𝐱i1+σi2σi12𝐳i1,i=1,2,,N. (104)

这并不难看出。 如果我们假设 𝐱i1 的方差为 σi12,那么我们可以证明

Var[𝐱i] =Var[𝐱i1]+(σi2σi12)
=σi12+(σi2σi12)=σi2.

因此,给定一系列噪声水平,方程式 (104) 将确实生成估计值 𝐱i,以使噪声统计量满足所需的属性。

如果我们同意方程式 (104),那么很容易推导出与方程式 (104) 相关的 SDE。 假设在极限 {σi}i=1N 成为 0t1 的连续时间 σ(t),并且 {𝐱i}i=1N 成为 𝐱(t),其中 𝐱i=𝐱(iN) 如果我们令 t{0,1N,,N1N} 然后我们有

𝐱(t+Δt) =𝐱(t)+σ(t+Δt)2σ(t)2𝐳(t)
𝐱(t)+d[σ(t)2]dtΔt𝐳(t).

在极限 Δt0 时,方程收敛到

d𝐱=d[σ(t)2]dtd𝐰.

我们将结果总结如下。 SMLD 的前向采样方程可以写成 SDE: d𝐱=d[σ(t)2]dtd𝐰. (105) 将其映射到方程式 (96),我们认识到

𝐟(𝐱,t)=0,andg(t)=d[σ(t)2]dt.

因此,如果我们写出反向方程 Eqn (97),我们应该有

d𝐱 =[𝐟(𝐱,t)g(t)2𝐱logpt(𝐱)]dt+g(t)d𝐰¯
=(d[σ(t)2]dt𝐱logpt(𝐱(t)))dt+d[σ(t)2]dtd𝐰¯.

这将为我们提供以下逆方程: SMLD 的逆采样方程可以写成 SDE: d𝐱=(d[σ(t)2]dt𝐱logpt(𝐱(t)))dt+d[σ(t)2]dtd𝐰¯. (106) 对于离散时间迭代,我们首先定义 α(t)=d[σ(t)2]dt 然后,使用与 DDPM 情况相同的一组离散化设置,我们可以证明

𝐱(t+Δt)𝐱(t) =(α(t)𝐱logpt(𝐱))Δtα(t)Δt𝐳(t)
𝐱(t) =𝐱(t+Δt)+α(t)Δt𝐱logpt(𝐱)+α(t)Δt𝐳(t)
𝐱i1 =𝐱i+αi𝐱logpi(𝐱i)+αi𝐳i (107)
𝐱i1 =𝐱i+(σi2σi12)𝐱logpi(𝐱i)+(σi2σi12)𝐳i,

这与SMLD反向更新方程相同。 Song 和 Ermon [8] 将 SDE 称为方差爆炸 (VE) SDE。

4.5求解SDE

在本小节中,我们简要讨论如何数值求解微分方程。 为了使我们的讨论稍微容易一些,我们将重点关注 ODE。 考虑以下常微分方程

d𝐱(t)dt=𝐟(𝐱(t),t). (108)

如果 ODE 是一个标量 ODE,那么 ODE 是 dx(t)dt=f(x(t),t)

欧拉方法 欧拉方法是求解 ODE 的一阶数值方法。 给定 dx(t)dt=f(x(t),t)x(t0)=x0,欧拉方法通过对 i=0,1,,N1 的迭代方案来解决问题,使得

xi+1=xi+αf(xi,ti),0,1,,N1,

其中 α 是步长。 让我们考虑一个简单的例子。

示例 [18,示例 2.2] 考虑以下 ODE dx(t)dt=x(t)+t22t+1. 如果我们应用步长为 α 的欧拉方法,那么迭代将采用以下形式 xi+1=xi+αf(xi,ti)=xi+α(xi+ti22)ti+1.

龙格-库塔(RK)方法 另一种常用的 ODE 求解器是 Runge-Kutta (RK) 方法。 经典的 RK-4 算法通过迭代求解 ODE

xi+1=xi+α6(k1+2k2+2k3+k4),i=1,2,,N,

其中数量 k1k2k3k4 定义为

k1 =f(xi,ti),
k2 =f(xi+αk12,ti+α2),
k3 =f(xi+αk22,ti+α2),
k4 =f(xi+αk3,ti+α).

详细内容可以查阅[18]等数值方法教材。

预测校正算法 由于不同的数值求解器在近似误差方面有不同的行为,因此将 ODE(或 SDE)放入现成的数值求解器将导致不同程度的误差[19] 然而,如果我们特别试图解决反向扩散方程,我们可以使用数值 ODE/SDE 求解器以外的技术来进行适当的修正,如图 22 所示。

Refer to caption
图 22: 预测和校正算法。

我们以 DDPM 为例。 在 DDPM 中,反向扩散方程由下式给出

𝐱i1=11βi[𝐱i+βi2𝐱logpi(𝐱i)]+βi𝐳i.

我们可以将其视为反向扩散的欧拉方法。 然而,如果我们已经训练了分数函数 𝐬𝜽(𝐱i,i),我们可以运行分数匹配方程,即

𝐱i1=𝐱i+ϵi𝐬𝜽(𝐱i,i)+2ϵi𝐳i,

M 次进行修正。 算法 1 总结了这个想法。 (请注意,我们已将得分函数替换为估计值。)

算法1 DDPM 的预测校正算法。
   𝐱N=𝒩(0,𝐈)
   for i=N1,,0 do
     
(Prediction)𝐱i1=11βi[𝐱i+βi2𝐬𝜽(𝐱i,i)]+βi𝐳i. (109)
      for m=1,,M do
        
(Correction)𝐱i1=𝐱i+ϵi𝐬𝜽(𝐱i,i)+2ϵi𝐳i, (110)
      end for
   end for

对于 SMLD 算法,两个方程为:

𝐱i1 =𝐱i+(σi2σi12)𝐬𝜽(𝐱i,σi)+σi2σi12𝐳 Prediction,
𝐱i1 =𝐱i+ϵi𝐱𝐬𝜽(𝐱i,σi)+ϵi𝐳 Correction.

我们可以像 DDPM 的预测校正算法一样,通过重复校正迭代几次来将它们配对。

加速 SDE 求解器 虽然通用 ODE 求解器可用于求解 ODE,但我们遇到的正向和反向扩散方程非常特殊。 事实上,它们的形式是

d𝐱(t)dt=𝐚(t)𝐱(t)+𝐛(t),𝐱(t0)=𝐱0, (111)

对于某些函数 𝐚(t)𝐛(t) 的选择,初始条件为 𝐱(t0)=𝐱0 这不是一个复杂的 ODE。 它只是一阶 ODE。 [20] 中,Lu 等人观察到,由于 ODE 的特殊结构(他们称之为半线性结构),可以分别处理 𝐚(t)𝐱(t)𝐛(t) 为了理解事情是如何运作的,我们使用如下所示的教科书结果。 定理 [常数的变化]([21,定理 1.2.3])。 考虑 [s,t] 范围内的 ODE: dx(t)dt=a(t)x(t)+b(t),wherex(t0)=x0. (112) 解由下式给出 x(t)=x0eA(t)+eA(t)t0teA(τ)b(τ)𝑑τ. (113) 其中 A(t)=t0ta(τ)𝑑τ 我们可以通过注意到进一步简化上面的第二项

eA(t)A(τ) =et0ta(r)𝑑rt0τa(r)𝑑r=eτta(r)𝑑r.

[20] 中提出的特别有趣的是从 [8] 导出的反向扩散方程:

d𝐱(t)dt=f(t)𝐱(t)+g2(t)2σ(t)ϵ𝜽(𝐱(t),t),𝐱(t)𝒩(0,σ~2𝐈),

其中 f(t)=dlogα(t)dt,以及 g2(t)=dσ(t)2dt2dlogα(t)dtσ(t)2 利用常量变分定理,我们可以通过以下公式精确求解时间 t 的 ODE

𝐱(t)=estf(τ)𝑑τ𝐱(s)+st(eτtf(r)𝑑rg2(τ)2σ(τ)ϵ𝜽(𝐱(τ),τ))𝑑τ.

然后,通过定义 λt=logα(t)/σ(t),并在 [20] 中概述的额外简化下,这个方程可以简化为

𝐱(t)=α(t)α(s)𝐱(s)α(t)st(dλτdτ)σ(τ)α(τ)ϵ𝜽(𝐱(tau))𝑑τ.

要评估该方程,只需运行数值积分器即可进行右侧所示的积分。 当然,还有其他数值加速方法来求解 ODE,为简洁起见,我们将跳过这些方法。

恭喜! 我们完了。 这就是 SDE 的全部内容。

有些人可能想知道:为什么我们要将迭代方案映射到微分方程? 有几个原因,有些是合理的,有些是推测的。

  • 通过将多个扩散模型统一到同一个 SDE 框架,人们可以比较算法。 在某些情况下,可以通过借鉴 SDE 文献以及概率抽样文献的思想来改进数值方案。 例如,[8] 中的预测校正器方案是与马尔可夫链蒙特卡罗结合的混合 SDE 求解器。

  • 根据[22]等一些论文,将扩散迭代映射到 SDE 可以提供更大的设计灵活性。

  • 在上下文扩散算法之外,一般随机梯度下降算法都有相应的 SDE,例如 Fokker-Planck 方程。 人们已经演示了如何以精确的封闭形式从理论上分析估计值的极限分布。 这减轻了通过分析明确定义的极限分布来分析随机算法的难度。

5结论

本教程涵盖了最近文献中支持基于扩散的生成模型的开发的一些基本概念。 考虑到文献数量巨大(并且正在迅速扩大),我们发现描述基本思想而不是重复使用 Python 演示尤为重要。 我们从编写本教程中学到的一些教训是:

  • 同一个扩散思想可以从多个角度独立推导,即VAE、DDPM、SMLD和SDE。 尽管有些人可能有不同的争论,但没有特别的理由说明为什么一个人比另一个人更优越/更差。

  • 去噪扩散起作用的主要原因是其增量很小,这在 GAN 和 VAE 时代是无法实现的。

  • 尽管迭代去噪是当前最先进的技术,但该方法本身似乎并不是最终的解决方案。 人类不会从纯粹的噪声中生成图像。 此外,由于扩散模型的增量性质较小,尽管已经在知识蒸馏方面做出了一些努力来改善这种情况,但速度仍将是一个主要障碍。

  • 关于从非高斯生成噪声的一些问题可能需要论证。 如果引入高斯分布的全部原因是为了使推导变得更容易,那么为什么我们要通过让我们的生活变得更加困难而转向另一种类型的噪声呢?

  • 扩散模型在反问题中的应用是很容易实现的。 对于任何现有的逆解算器,例如即插即用 ADMM 算法,我们可以用显式扩散采样器替换降噪器。 人们已经证明了基于这种方法改进的图像恢复结果。

参考