因为想学习Neural ODE视角下的diffusion方法,所以先从入门Neural
ODE开始。本篇文章从ODE的定义开始,通过Euler数值法求解ODE与ResNet架构之间的关系引出Neural
ODE,并推导其训练所需的reverse-time算法。
一、Ordinary Differential
Equation (ODE)
常微分方程(ordinary differential equation,
ODE)是未知函数只含有一个自变量的微分方程,例如简单的一阶常微分方程有以下形式:
其中表示是以为自变量的函数,是其导数,通常会需要给出某点的初值,才能解出。当取某些特定形式的时候,我们可以得到的解析解,例如,我们知道是满足条件的,其中为待定系数,通过代入给定的初值即可得到。但是大多数时候,比较复杂,
无法求得解析解。幸运的是,很多场景中我们的目标只是能够在给定时获得的取值即可,因此只需要数值解。在给定初值时,任意一点处的取值可以由积分得到:
二、Euler法与ResNet
ODE的数值解法是已经被广泛探索的领域,本质就是各种求的数值方法,但之所以想到用ODE来进行数据建模,还是因为最基本的Euler法。Euler法把积分区间平均分成若干片段,把积分过程变为在切线方向上逐步前进的过程(类似将区间分割为若干足够小的分区间求和的数值积分方法)。

令片段长度为,则,由一阶近似方法可以得到: 显然,越小,即分割的片段越多,则得到的结果越精确。而如果我们仔细看递推表达式,会发现与ResNet的定义形式非常相似:
其中为第层的隐状态,是前向网络,为第层的网络权重。如此相似的形式告诉我们ResNet一定程度上可以看作是在用Euler法求解ODE初值问题,其中积分域是离散的,在ResNet中代表网络的深度。但有一点不同的是,ResNet每层的权重不一样,对应的ODE中的每层都不同,因此只能说是形式上的相近。
三、Neural ODE
1. 定义
那很自然的,我们会想到用ODE来做更高层次的抽象,用神经网络来表达,或者说,则整个过程可以表示为: 其中为初值条件,是根据输入变换得到的(e.g. embedding过程),与是积分上下限的超参(但是也可以做成可训练的),指模型的参数。而则可以使用各种各样的求解器,前文所述的Euler法就是最基础的一种。使用ODE的视角做建模有两个非常显著的优势:
- Memory
efficient:ODE视角只需要用一个网络来建模导数,因此只需要一组参数,而ResNet这样的则是每层都有独立的参数。另外,在训练中,ODE的reverse算法与前向算法一致,都只需求解一次ODE,因此也不需要存储中间状态(下文再展开),可以大幅减少训练的内存需求。
- Adaptive computation:很多先进的都能对递推步数(一定程度上可以理解为模型深度)根据的各种性质进行动态调整,因此模型的计算复杂度能够自适应地根据问题的复杂度调整。
2. 训练
如果仍然使用backpropagation算法计算梯度来进行训练的话,内存占用会非常大,因为需要记录每一步的中间输出,而且传导过程可能与具体采用的算法有关。而伴随灵敏度方法(adjoint
sensitivity
method)可以很优雅地解决这个问题,它将梯度计算的流程划归到与前向计算一致,都只需求解一次ODE,因此不仅内存占用小(不用存中间态),同时也与的算法选择解藕。
Adjoint sensitivity method
整个过程的大致想法是通过伴随状态(adjoint
state)将计算梯度的过程也变为求解ODE的过程,再将伴随状态的初值与原问题的初值一起输入到中,即可在前向求解过程中将梯度与结果一同输出。
对初值的梯度
令最终的loss为,我们先考虑如何求解(该梯度会用于对输入变换的训练)。首先定义伴随状态(adjoint
state): 接下来我们需要得到关于的ODE。当产生的变化时,我们有: 根据,我们可以进一步得到:
因此: 由于即为关于网络最终输出的梯度,是很容易求的,因此可以作为ODE问题的初值,从而得到:
则对于的梯度为:
对以及的梯度
类似的,我们定义对于权重以及的伴随状态: 我们认为权重不随时间而变化,因此有: 接下来的想法非常巧妙,我们之前已经分析了如何通过求解,而输出的是,那如果我们对的输出进行扩增,使其同时也能输出和,就可以直接把求解的过程中一元微积分的内容替换成多元微积分,从而得到结果。即:
类似的,对伴随状态进行扩增:
扩增后的向量可以看成只是对之前的向量增加了若干维度,因此仍然可以用之前推导的结果:
而初值是: 其中项的初值为0是因为并未参与计算(用于计算)。这里的初值在论文中是带负号的,但感觉似乎不应该带负号,目前还没搞明白带负号的原因,github上也有一个issue讨论这个问题,目前没有讨论结果。
综合
将以上综合起来,我们只需要求解augmented后的ODE即可得到梯度: 之后就可以使用正常的训练过程了。
参考文献
[1] Neural
ODEs
[2] Understanding
Neural ODEs
[3] Neural Ordinary
Differential Equations (NeurIPS 2018)
[4] Understanding
Adjoint Method of Neural ODE
v1.5.2