[pdf][supp]

这篇论文发布在18年的NeurIPS上,内容偏理论,比较晦涩难懂,借助Youtube上一些比较优质的讲解视频,我总算是搞懂了这篇论文的核心思想,至于其他的部分,因为视频中并没有讲解,可以认为不是这篇论文的核心,所以等有需要时再来研究

Method

首先我们来看普通的ResNet,如图

输入一个xx,由网络获得一个输出F(x)\mathcal F(x),最终获得的结果为F(x)+x\mathcal F(x)+x。本文中,为了与论文符号保持一致,我们用ff表示函数F\mathcal F,用zz表示输入xx

这是ResNet其中的一层,而对于一个多层的ResNet,我们用tt来表示层数,其公式可以表示为

zt+1=zt+f(zt,θt)z_{t+1}=z_t+f(z_t,\theta_t)

ztz_t是第tt层的输入,该层网络可以认为是一个黑匣子,由ResNet模拟了一个函数ff,这个函数的输出f(zt,θt)f(z_t,\theta_t),与该层的输入ztz_t和该层的网络参数θt\theta_t有关

从公式我们看到,zz在不同的层具有不同的输出,那我们是否可以将zz看作一个关于参数tt的函数z(t)z(t),此时,以上公式就变成了一个常微分方程

limε0z(t+ε)z(t)=dz(t)dt=f(z(t),t,θ)\lim_{\varepsilon\rightarrow0}z(t+\varepsilon)-z(t)=\frac{dz(t)}{dt}=f(z(t),t,\theta)

该方程的解就是z(t)z(t),此时的tt将不限于整数,而是任意一个实数,获得输出的过程如下

假设初始时刻为t0t_0,则初始的输入为z(t0)z(t_0),而我们需要t1t_1时刻的输出,那么获得输出并计算loss的方法如下

L(z(t1))=L(z(t0)+t0t1f(z(t),t,θ)dt)=L(ODESolve(z(t0),f,t0,t1,θ))L(z(t_1))=L\left(z(t_0)+\int_{t_0}^{t_1}f(z(t),t,\theta)dt\right)=L(ODESolve(z(t_0),f,t0,t1,\theta))

如果需要更多的输出,我们可以增加t2,t3,...t_2,t_3,...

Optimize

方法与之前不同,优化方法也与之前不一样,首先来看传统的ResNet的优化方式,当我们获得了loss LL,我们需要计算的LL关于参数θ\theta的梯度

以从ttt+1t+1为例,已知

zt+1=zt+f(zt,θ)z_{t+1}=z_t+f(z_t,\theta)

则反向传播为

Lzt=Lzt+1zt+1zt=Lzt+1zt(zt+f(zt,θ))=Lzt+1(1+f(zt,θ)zt)=Lzt+1+Lzt+1f(zt,θ)zt\begin{aligned} \frac{\partial L}{\partial z_t} &=\frac{\partial L}{\partial z_{t+1}}\frac{\partial z_{t+1}}{\partial z_t}\\ &=\frac{\partial L}{\partial z_{t+1}}\frac{\partial}{\partial z_t}(z_t+f(z_t,\theta))\\ &=\frac{\partial L}{\partial z_{t+1}}(1+\frac{\partial f(z_t,\theta)}{\partial z_t})\\ &=\frac{\partial L}{\partial z_{t+1}}+\frac{\partial L}{\partial z_{t+1}}\frac{\partial f(z_t,\theta)}{\partial z_t} \end{aligned}

LL关于参数θ\theta的梯度为

Lθ=Lzt+1zt+1θ=Lzt+1θ(zt+f(zt,θ))=Lzt+1f(zt,θ)θ\begin{aligned} \frac{\partial L}{\partial \theta} &=\frac{\partial L}{\partial z_{t+1}}\frac{\partial z_{t+1}}{\partial \theta}\\ &=\frac{\partial L}{\partial z_{t+1}}\frac{\partial}{\partial \theta}(z_t+f(z_t,\theta))\\ &=\frac{\partial L}{\partial z_{t+1}}\frac{\partial f(z_t,\theta)}{\partial \theta} \end{aligned}

而当zz是连续的时,反向传播是无法用链式法则计算的,或者说,使用链式法则的准确度不高,而且需要消耗巨大内存,因此使用了另一个微分方程求解以计算梯度

首先,另a(t)=Lz(t)a(t)=\frac{\partial L}{\partial z(t)},则有结论如下

Forward:

z(t+1)=z(t)+tt+1f(z(t),t,θ)dtz(t+1)=z(t)+\int_{t}^{t+1}f(z(t),t,\theta)dt

Backword:

a(t)=a(t+1)+t+1ta(t)f(z(t),t,θ)z(t)dta(t)=a(t+1)+\int_{t+1}^{t}-a(t)\frac{\partial f(z(t),t,\theta)}{\partial z(t)}dt

Params:

Lθ=tt+1a(t)f(z(t),t,θ)θdt\frac{\partial L}{\partial \theta}=\int_{t}^{t+1}a(t)\frac{\partial f(z(t),t,\theta)}{\partial \theta}dt

以下为证明:

首先需要证明的是

da(t)dt=a(t)f(z(t),t,θ)z(t)\frac{da(t)}{dt}=-a(t)\frac{\partial f(z(t),t,\theta)}{\partial z(t)}

过程如下,首先设一个很小的值ε\varepsilon,且

z(t+ε)=tt+εf(z(t),t,θ)dt+z(t)=Tε(z(t),t)z(t+\varepsilon)=\int_{t}^{t+\varepsilon}f(z(t),t,\theta)dt+z(t)=T_{\varepsilon}(z(t),t)

那么根据链式法则可以得到

a(t)=a(t+ε)Tε(z(t),t)z(t)a(t)=a(t+\varepsilon)\frac{\partial T_{\varepsilon}(z(t),t)}{\partial z(t)}

根据定义计算a(t)a(t)的导数

da(t)dt=limε0+a(t+ε)a(t)ε=limε0+a(t+ε)a(t+ε)z(t)Tε(z(t))ε=limε0+a(t+ε)a(t+ε)z(t)(z(t)+εf(z(t),t,θ)+O(ε2))ε=limε0+a(t+ε)a(t+ε)(I+εf(z(t),t,θ)z(t)+O(ε2))ε=limε0+a(t+ε)f(z(t),t,θ)z(t)+O(ε2)ε=a(t)f(z(t),t,θ)z(t)\begin{aligned} \frac{d a(t)}{dt} &=\lim_{\varepsilon\rightarrow 0^+}\frac{a(t+\varepsilon)-a(t)}{\varepsilon}\\ &=\lim_{\varepsilon\rightarrow 0^+}\frac{a(t+\varepsilon)-a(t+\varepsilon)\frac{\partial}{\partial z(t)}T_{\varepsilon}(z(t))}{\varepsilon}\\ &=\lim_{\varepsilon\rightarrow 0^+}\frac{a(t+\varepsilon)-a(t+\varepsilon)\frac{\partial}{\partial z(t)}(z(t)+\varepsilon f(z(t),t,\theta)+\mathcal O(\varepsilon^2))}{\varepsilon}\\ &=\lim_{\varepsilon\rightarrow 0^+}\frac{a(t+\varepsilon)-a(t+\varepsilon)(I+\varepsilon\frac{\partial f(z(t),t,\theta)}{\partial z(t)}+\mathcal O(\varepsilon^2))}{\varepsilon}\\ &=\lim_{\varepsilon\rightarrow 0^+}-a(t+\varepsilon)\frac{\partial f(z(t),t,\theta)}{\partial z(t)}+\frac{\mathcal O(\varepsilon^2)}{\varepsilon}\\ &=-a(t)\frac{\partial f(z(t),t,\theta)}{\partial z(t)} \end{aligned}

此时可以认为

Lθ=Lz(t)z(t)θ=t1t0a(t)f(z(t),t,θ)θdt\frac{\partial L}{\partial\theta}=\frac{\partial L}{\partial z(t)}\frac{\partial z(t)}{\partial \theta}=\int_{t_1}^{t_0} a(t)\frac{\partial f(z(t),t,\theta)}{\partial \theta}dt

文中的图可以很好地描述反向传播的过程

对于每一个loss都使用ODE求解器往前计算a(t)a(t)的值,具体的操作流程如下

总结

第一次读这篇文章时惊为天人,可以说是一点也没有读懂,但是体会到了具体的思路,后来才借助了一些讲解视频逐渐搞懂。不得不感叹作者十分厉害,不仅提出了新型的网络,还巧妙地使用了另一个ODE求解器解决了反向传播和训练的问题,十分的优雅

因为是为了理解另一篇文章才读的,当时有种感觉,那篇文章可能和talk-to-edit里图像编辑的思路很相近。后来果然发现,上述两篇文章的思想其实如出一辙,都是nonlinear image editing。而NueralODE确实十分适合解决这个问题,并且它适合的场景不限制于人脸编辑,比talk-to-edit中的方法要更强大。目前在做对于两者的人脸图像编辑的对比实验,找一找其中存在的问题