Transformer 是 CNN 是 GNN 是 RNN,Attention is all you need!

现在回头看 17 年那句 Attention is all you need,真是神预言,Transformer 模型从自然语言处理机器翻译开始,先是慢慢的影响(18 年中毕业论文我都还 LSTM),然后席卷整个 NLP 领域,特别是 BERT 出来后,其他啥的都丢一边去了,等 NLP 影响一圈,又开始辐射其他领域,比如 CV 领域也用得多起来,比如之前目标检测那篇。

无疑 Transformer 能传播开来,与它模型本身出色的 Flexibility 和 Scalability 都密切相关。Flexibility 让他能处理各种任务,而 Scalability 又能让它能充分撬动起不断增长的计算资源。因为受欢迎程度,大家也开始慢慢研究它本身,从各种视角看它了。

正如之前这篇2020-0713里面说的一句,研究范式就是

1)首先,大佬发明新方法或架构;
2)之后,有新锤子了,大家都换上新锤子,开始到处找钉子,这也就是所谓的低垂果实;
3)接着,好找的钉子都没了,大家就开始分析锤子本身,你看这锤又大又硬,但还是有一些缺点,我们还能怎么改改;
4)发布新的数据集,更好测试。

现在是大家 纷纷研究分析锤子本身的时候 ,你看这锤又大又硬,这一看,怎么还跟那啥有点像。

诶,跟那 CNN 还有点像

这个说法主要来源这篇 ICLR 666 论文: On the Relationship between Self-Attention and Convolutional Layers

该论文主要是证明了,当一个多头自注意力机制层在加上一个适当的相对位置编码的时候,可以表示任何卷积层。

其中多头是为了能关注类似 CNN 感知域中的每个像素位置,所以一个头数为 N 的多头自注意力可以模拟感知域为的 CNN 卷积核。

然后相对位置编码主要是为了保证 CNN 里的 Translation Invariance 性质。

虽然这篇明确指出,但最早察觉到 CNN 和 Transformer 关系时候,还并不是这篇,是在读 Dynamic CNN 那篇, Pay Less Attention with Lightweight and Dynamic Convolutions . 当时复现了 Dynamic CNN 后发现里面有些性质和 Transformer 很像。

比如说它里面用到的卷积核不是固定的,而是一个动态的,每个时序卷积时用的卷积核都是该时序词向量的函数,每个时序会算出一个类似局部 attention 的卷积核,来对局部词向量进行加权求和。

诶,跟那 GNN 也有点像

主要说法来源这篇博客:Transformers are Graph Neural Networks(https://graphdeeplearning.github.io/post/transformers-are-gnns/)

先来看看 GNN 的主要结构,对于一张有节点和边的图

算某个节点的特征表示时,是通过 neighbourhood aggregation 搜集相邻节点特征来更新自身表示,从而能学习到图上的局部结构。而和 CNN 类似,只要叠个几层,就能慢慢将学习范围扩大,传播至整张图。

所以最基本的计算形式是下面这个公式,和 都是学习参数,然后 是计算节点,是临近节点。式子里面的 就算是这里面的一个 Aggregation 函数,也可以用各种其他方式来做。最后还有个 非线性.

于是表示成图片是

这时将 Transformer 的自注意力机制用类似图表示出来

已经很类似了。不同之处在于,对原来最基础 GNN 里简单的加和 aggregation 函数进行了复杂化, 变成了基于注意力机制的,带权重的加和 ,计算过程也不是通过单纯学习一个矩阵,而是通过和临近节点的点乘(或其他一些方式)。

于是这样看来,在 NLP 里面自注意力机制就相当于把句子看成是一个全连接图,每个词都和句子中其他词相邻,因此计算时,将所有词都考虑进来。正因为这种全连接性质,才导致 Transformer 一个被诟病的性质,一旦增加句子长度,计算复杂度是成 O(N^2) 增长,当句长很长时这个代价是非常大的。

当然最近也有很多论文,比如 Reformer, Linformer 等等尝试解决这个问题。

而对于 Transformer 中的有些研究,想做 Local Attention,其实就相当于加入了先验认为,句子中相近的词算是相邻节点。

诶,怎么还和 RNN 有点像

最后这个来自最近这篇论文, Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention .

直接了当,Transformer 是 RNN, Q.E.D.

开玩笑,翻开论文会发现前面一大半都没说哪里像 RNN,反而一大片写如何优化 Transformer,给计算复杂度降到线性。包括用核函数来简化 Attention 计算过程,还有替换掉 Softmax. 但读到后面才发现,正因为这些优化特别是核函数优化,才能在最后处理 Causal Masking 时推出,与 RNN 计算过程的相似。

从头捋一遍就是,对于一般通用的 Attention 形式它的公式如下

熟悉的 QKV,Q 和 K 利用相似(sim)函数计算出一个分数,之后除以分数总和获得 Attention 权重,之后去 V 里取值。

接着用核函数,对 sim 函数进行处理,再利用结合律,将 Q 运算提出来

之后引入两个简化表示

整个方程就变成了

于是在 Causal Masking 时,对于 S 和 Z 每个时间步,之前时间步计算结果都能复用,所以可将 S 和 Z 看做是两个隐状态,能在每个时间步根据当前输入来更新自己的状态,而这恰恰就是 RNN 的主要思想。