最近看的论文需要用到DDPM和DDIM,没把两个都看完确实不太好理解,上周看懂了DDPM,所以打算先写一下,然后看完DDIM再写论文的汇报,原本昨天能写完来着,但是过程中发现了几个想不通的点想了很久,不得已拖到今天
前向过程
Diffusion具有前向和逆向两个过程,前向过程中,给定真实图像x0∼q(x),经过T步添加噪声的过程,变成标准的高斯噪声。在这个过程中,每个时刻t只与它的前一时刻t−1有关,所以可以看作马尔可夫过程,也就是说,对于每一时刻都有
q(xt∣xt−1)=q(xt∣x0:t−1)
那么整个过程根据贝叶斯公式可以写成:
q(xt∣xt−1)=N(xt;1−βtxt−1,βtI),q(x1:T∣x0)=t=1∏Tq(xt∣xt−1)
其中βt是设定好的超参数,那么整个加噪声的过程是确定的,根据公式
xt=1−βtxt−1+βtϵ,ϵ∼N(0,I)
另αt=1−βt,且αˉt=∏s=1tαs,经过递推可以获得
xt=αˉtx0+1−αˉtϵˉ,ϵˉ∼N(0,I)
根据这个公式,因为α都是确定的值,所以只要我们确定了ϵˉ,那我们就能根据x0获得任意xt,并且x0,xt,,ϵˉ三者只要知道其中两个就可推出第三个。而因为αt<0,所以这个过程中t越来越大,xt越接近标准高斯噪声
关于为什么每一步中均值要乘上1−βt,知乎专栏的作者是这样解释的
“一开始笔者一直不清楚为什么Eq(1)中diffusion的均值每次要乘上1−βt。明明βt只是方差系数,怎么会影响均值呢?替换为任何一个新的超参数,保证它<1,也能够保证值域并且使得最后均值收敛到0(但是方差并不为1). 然而通过Eq(3)(4),可以发现当T→∞,xt∼N(0,I)。所以1−βt的均值系数能够稳定保证xT最后收敛到方差为1的标准高斯分布,且在Eq(4)的推导中也更为简洁优雅。(注:很遗憾,笔者并没有系统地学习过随机过程,也许1−βt就是diffusion model前向过程收敛到标准高斯分布的唯一解,读者有了解也欢迎评论)”
关于这个解释我有个地方不太理解,为什么更换超参数后,均值收敛到0之后,方差并不为1,我的理解是,应该是方差不一定为1。我一开始也有所疑惑,为什么要让q(xt∣xt−1)的分布的均值和方差分别为1−βtxt−1和βt,这个设定对于最终的目的而言并不是必须的,但是经过思考,我认为更好的解释是,它能够保证均值收敛为0和方差收敛为1,是同时达到的,如果更换其他超参数,两者必不可能同时达成,因为在实际计算过程中,T是可能无限大的,所以这样设置会使计算更加稳定,避免出现意料之外的问题。
逆向过程
前向过程是一个添加噪声的过程,它作为一个前提,认为图片已经经过了前向过程成为了标准高斯噪声,那么我们的目的就是从标准高斯分布中恢复出图片,即我们的逆向过程
逆向过程的推导十分复杂,罗列了各种乱七八糟的公式,一开始看的确实眼花缭乱,但是经过这两天的思考,发现它的本质并不复杂
首先我们先从标准高斯分布采样一个xT,我们希望获得q(xt−1∣xt)的分布,这样就可以一路回推回到x0。但是这个分布仅通过当前条件是不可知的,即只有xt,不知道xt−1推到xt时添加的噪声,是无法知道xt−1是多少的。所以我们用网络来模拟这个分布
pθ(xt−1∣xt)=N(xt;μθ(xt,t),Σθ(xt,t))
训练
网络的目标有了,我们要思考该怎么优化,在训练过程中,我们除了xt,还可以获得x0,因为x0推到xt的过程中所有参数都是已知的,那么通过两者是可以推出xt−1的,贝叶斯公式也可以证明并解出这个分布
q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)
其中
μ~t(xt,x0)=1−αˉtαˉt−1βtx0+1−αˉtαˉt(1−αˉt−1)xt,β~t=1−αˉt1−αˉt−1βt
我们发现,在采样过程也就是模型的使用时,该分布的均值和方差中,只有x0是未知的,也就是说我们只要使用网络去模拟x0,就可以获得与这个分布pθ(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0(xt,t)),β~tI),从而一步步推导出x0。这个思路只能用逆天来形容,为了使其看上去不那么逆天,我们使用之前获得的公式xt=αˉtx0+1−αˉtϵˉ来将x0替换成xt,ϵˉ
μ~t(xt,ϵˉ)=αt1(xt−1−αˉtβtϵˉ)
这样我们就可以用网络只去模拟ϵˉ了,即
μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
到此为止整个过程已经很清晰了,从xt推出xt−1只需要预测出ϵˉ就行了,那么loss也就很好设计,计算ϵˉ和ϵθ(xt,t)的MSE就行了
∣∣ϵˉ−ϵθ(xt,t)∣∣2
此时我认为训练方案有两个,一个是从标准高斯分布任取xT,然后一步一步推到x0,每一步有xt可以计算出对应的ϵˉ,即可计算出对应的loss。但是这种训练策略感觉会有训练时间长,难以收敛的问题,而且必须先规定好变成标准高斯分布的时间步T,更进一步说可能根本没有一个变成标准高斯噪声的时间步,或者每张图片都有不同的T,所以这种训练策略是不合适的。
另一种就是文章中使用的训练策略,公式如下
Ex0,ϵ∣∣ϵˉ−ϵθ(αˉtx0+1−αˉtϵˉ,t)∣∣2
任取ϵˉ∼N(0,I),计算出对应的xt,然后和t一起输入到网络中预测出来的噪声要与ϵˉ相同
训练结束后,我们获得分布N(αt1(xt−1−αˉtβtϵθ(xt,t)),σt),使用以下公式即可推出xt−1
xt−1=αt1(xt−1−αˉtβtϵθ(xt,t))+σtz
z也是符合标准高斯分布的噪声,σt是方差,文中用的是q(xt−1∣xt,x0)的方差
总结
以上所写其实并不是DDPM的准确推导过程,而是我经过思考,想出的不借助复杂的公式,该如何理解Diffusion。总结来说,我给图片加上了一个噪声,Diffusion就要使用加噪声之后的图片,知道我加了什么噪声,但是通篇看来,我还有两个问题没有解决。
首先,为什么不直接从xt推到x0,通过loss可知,我们模型预测的噪声ϵθ(xt,t)可以认为就是从x0推到xt所加上的噪声,那么为什么不直接使用公式xt=αˉtx0+1−αˉtϵˉ,直接将xt还原回x0,而是一步一步的往前推,我猜测可能是效果不好之类的原因,需要我尝试后才能知道结果
其次,在逆推的过程中,明明也添加了噪声z,为什么噪声却越来越少呢,而且正向过程中的噪声是被网络预测出来的,是已经确定的了,而逆推过程中反而添加了不确定的部分z,这不是使整个逆推变成了一个不确定的过程了,这是为什么我还是没有想通