Theme NexT works best with JavaScript enabled

学习飞翔的企鹅

What's the point in living if I have to hide ?

0%

Neural ODEs

因为想学习Neural ODE视角下的diffusion方法,所以先从入门Neural ODE开始。本篇文章从ODE的定义开始,通过Euler数值法求解ODE与ResNet架构之间的关系引出Neural ODE,并推导其训练所需的reverse-time算法。

一、Ordinary Differential Equation (ODE)

常微分方程(ordinary differential equation, ODE)是未知函数只含有一个自变量的微分方程,例如简单的一阶常微分方程有以下形式: y(t)=f(t,y),y(t0)=y0 其中y(t)表示y是以t为自变量的函数,y(t)是其导数,通常会需要给出某点的初值y0,才能解出y(t)。当f(x,y)取某些特定形式的时候,我们可以得到y(t)的解析解,例如f(t,y)=y(t),我们知道y(t)=Aet是满足条件的,其中A为待定系数,通过代入给定的初值即可得到。但是大多数时候,f(t,y)比较复杂, 无法求得解析解。幸运的是,很多场景中我们的目标只是能够在给定t时获得y(t)的取值即可,因此只需要数值解。在给定初值时,任意一点t处的y(t)取值可以由积分得到: y(t)=y(t0)+t0ty(t)dt=y0+t0tf(t,y)dt

二、Euler法与ResNet

ODE的数值解法是已经被广泛探索的领域,本质就是各种求t0tf(t,y)dt的数值方法,但之所以想到用ODE来进行数据建模,还是因为最基本的Euler法。Euler法把积分区间[t0,t]平均分成若干片段,把积分过程变为在切线方向上逐步前进的过程(类似将区间分割为若干足够小的分区间求和的数值积分方法)。

euler

令片段长度为δ,则tn+1=tn+δ,由一阶近似方法可以得到: y(tn+1)=y(tn)+(tn+1tn)y(tn)=y(tn)+δf(t,y(tn)) 显然,h越小,即分割的片段越多,则得到的结果越精确。而如果我们仔细看递推表达式,会发现与ResNet的定义形式非常相似: Euler Method: y(tn+1)=y(tn)+δf(t,y(tn))ResNet: ht+1=ht+f(θt,ht) 其中ht为第t层的隐状态,f是前向网络,θt为第t层的网络权重。如此相似的形式告诉我们ResNet一定程度上可以看作是在用Euler法求解ODE初值问题,其中积分域t是离散的,在ResNet中代表网络的深度。但有一点不同的是,ResNet每层的权重不一样,对应的ODE中的f每层都不同,因此只能说是形式上的相近。

三、Neural ODE

1. 定义

那很自然的,我们会想到用ODE来做更高层次的抽象,用神经网络来表达f(t,y(tn)),或者说f(t,h(tn),θ),则整个过程可以表示为: y=ODESolver(h(t0),f,t0,t1,θ),h(t0)=embed(x) 其中h(t0)为初值条件,是根据输入x变换得到的(e.g. embedding过程),t0t1是积分上下限的超参(但是也可以做成可训练的),θ指模型f的参数。而ODESolver则可以使用各种各样的求解器,前文所述的Euler法就是最基础的一种。使用ODE的视角做建模有两个非常显著的优势:

  • Memory efficient:ODE视角只需要用一个网络来建模导数,因此只需要一组参数,而ResNet这样的则是每层都有独立的参数。另外,在训练中,ODE的reverse算法与前向算法一致,都只需求解一次ODE,因此也不需要存储中间状态(下文再展开),可以大幅减少训练的内存需求。
  • Adaptive computation:很多先进的ODESolver都能对递推步数(一定程度上可以理解为模型深度)根据f的各种性质进行动态调整,因此模型的计算复杂度能够自适应地根据问题的复杂度调整。

2. 训练

如果仍然使用backpropagation算法计算梯度来进行训练的话,内存占用会非常大,因为需要记录每一步的中间输出,而且传导过程可能与具体采用的ODESolver算法有关。而伴随灵敏度方法(adjoint sensitivity method)可以很优雅地解决这个问题,它将梯度计算的流程划归到与前向计算一致,都只需求解一次ODE,因此不仅内存占用小(不用存中间态),同时也与ODESolver的算法选择解藕。

Adjoint sensitivity method

整个过程的大致想法是通过伴随状态(adjoint state)将计算梯度的过程也变为求解ODE的过程,再将伴随状态的初值与原问题的初值一起输入到ODESolver中,即可在前向求解过程中将梯度与结果一同输出。

对初值的梯度

令最终的loss为L,我们先考虑如何求解Lh(t0)(该梯度会用于对输入变换g的训练)。首先定义伴随状态(adjoint state): a(t)=Lh(t) 接下来我们需要得到关于a(t)的ODE。当t产生δ的变化时,我们有: a(t)=Lh(t+δ)h(t+δ)h(t)=a(t+δ)h(t+δ)h(t) 根据h(t+δ)=h(t)+tt+δf(s,h(s),θ)ds,我们可以进一步得到: a(t)=a(t+δ)h(t)(h(t)+tt+δf(s,h(s),θ)ds)=a(t+δ)[1+h(t)(tt+δf(s,h(s),θ)ds)] 因此: a(t)=limδ0+a(t+δ)a(t)δ=limδ0+a(t+δ)h(t)(tt+δf(s,h(s),θ)ds)δ=a(t)f(t,h(t),θ)h(t) 由于a(t1)即为L关于网络最终输出h(t1)的梯度,是很容易求的,因此可以作为ODE问题的初值,从而得到: a(t)=a(t)f(t,h(t),θ)h(t),a(t1)=Lh(t1) 则对于h(t0)的梯度为: a(t0)=Lh(t0)=a(t1)+t1t0a(t)f(t,h(t),θ)h(t)dt θ以及t的梯度

类似的,我们定义对于权重θ以及t的伴随状态: aθ(t)=Lθ(t)at(t)=Lt 我们认为权重不随时间而变化,因此有: dθ(t)dt=0,dtdt=1 接下来的想法非常巧妙,我们之前已经分析了如何通过a(t)求解Lh(t),而f(t,h(t),θ)输出的是h(t),那如果我们对f的输出进行扩增,使其同时也能输出θ(t)(t),就可以直接把求解Lh(t)的过程中一元微积分的内容替换成多元微积分,从而得到结果。即: [dh(t)dt,dθ(t)dt,dtdt]=faug(t,h(t),θ)=[f(t,h(t),θ),0,1] 类似的,对伴随状态进行扩增: aaug(t)=[a(t),aθ(t),at(t)] 扩增后的向量可以看成只是对之前的向量增加了若干维度,因此仍然可以用之前推导的结果: aaug(t)=aaug(t)faug[h(t),θ(t),t]=[a(t),aθ(t),at(t)][fh(t)fθ(t)ft0h(t)0θ(t)0t1h(t)1θ(t)1t]=[a(t),aθ(t),at(t)][fh(t)fθ(t)ft000000]=[a(t)fh(t),a(t)fθ(t),a(t)ft] 而初值是: aaug(t1)=[a(t1),aθ(t1),at(t1)]=[Lh(t1),Lθ(t1),Lt1]=[Lh(t1),0,a(t1)f(t1,h(t1),θ)] 其中θ项的初值为0是因为θ(t1)并未参与计算(θ(t0)用于计算h(t0+δ))。这里t的初值在论文中是带负号的,但感觉似乎不应该带负号,目前还没搞明白带负号的原因,github上也有一个issue讨论这个问题,目前没有讨论结果。

综合

将以上综合起来,我们只需要求解augmented后的ODE即可得到梯度: [Lh(t0),Lθ(t0),Lt0]=aaug(t0)=ODESolver(aaug(t1),faug,t1,t0,θ) 之后就可以使用正常的训练过程了。

参考文献

[1] Neural ODEs

[2] Understanding Neural ODEs

[3] Neural Ordinary Differential Equations (NeurIPS 2018)

[4] Understanding Adjoint Method of Neural ODE

Powered By Valine
v1.5.2