香侬读 | Set Transformer: 处理集合型数据的Transformer

文章标题:Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks

文章作者:Juho Lee,Yoonho Lee,Jungtaek Kim,Adam R. Kosiorek,Seungjin Choi, Yee Whye Teh

收录情况:Accepted by ICML 2019

文章链接:http://proceedings.mlr.press/v97/lee19d/lee19d.pdf

导言

本文介绍了一种“顺序无关”的,类似Transformer模型,它可以像集合一样去建模无关元素顺序的数据。总的来说,本文有以下几点贡献:

使用类似self-attention处理数据集中的每个元素,构成了类似Transformer的结构,用于建模集合类型的数据,将计算时间从 变为 ,这里 是一个预定义的参数。

类似矩阵分解,在 最大值回归、计数不同字符、混合高斯、集合异常检测 和 点云分类 五个任务上达到较好结果。

问题定义

自然语言属于一种“顺序相关”的任务,任意交换两个字符会改变整个句子的意思。

但也有一些任务无关数据中元素的位置,比如对给定集合求和。

一个 集合类型的输入 问题应该满足下面两个条件: 输入顺序无关 和 输入大小无关 。

而上述条件在现有的RNN或MLP结构下很难满足,需要设计专用的模型进行处理。

模型设计

首先定义下述操作:

上面三式即为Multihead Attention。

这里,MAB为Multihead Attention Block,同Transformer。

这里SAB为Set Attention Block,而计算复杂度为 。

ISAB为Induced Set Attention Block,这里 而 ,从而复杂度为 。

PMA为Pooling by Multihead Attention且 ,我们用PMA来聚合多个特征。

在pooling之后还可以进一步使用SAB建模 个输出之间的联系。 最后是整个模型部分:

下图是模型及各个部分概览:

实验部分

参与比较的有下述模型:

•rFF+Polling•rFFp-mean/rFFp-max+Pooling•rFF+Dotprod•SAB(ISAB)+Pooling(this)•rFF+PMA(this)•SAB(ISAB)+PMA(this)

任务一:最大值回归问题

在这个任务中,我们将 作为损失。结果表明,rFF+max-pooling效果最好,约0.14,其次是SAB+PMA,约0.21。

这是因为,max-pooling只需去学习一个恒等函数,而自注意力也可以较好地找到最大值。

任务二:不同字符计数

对该任务,SAB+PMA可以达到约0.60的准确率,其次是SAB+pooling的0.57与rFF+PMA的0.46,其他的模型都低于0.45。

下图是ISAB中m的取值对结果的影响:

任务三: 混合高斯的最大似然估计

给定一个数据集 ,其混合高斯(MoG)的最大似然估计为:

我们的任务就是要训练 使得上式最大。我们不用EM算法,相反,我们直接用一个神经网络学习到最大值参数的映射:

对于PMA,我们取 ,即由四个高斯混合。我们在两个数据集上训练测试。一是人工合成的2D混合高斯,二是CIFAR-100。

对于人工合成,每个数据集包含 个2D点,每个点都由4个高斯分布中的一个随机产生。

对于CIFAR-100,每个数据集包含 张随机图片,这些图片都由四个随机选定的类中产生。结果如下表:

而下图是可视化结果。左上角是rFF+Pooling,右上角是SAB+Pooling,左下角是rFF+PMA,右下角是SAB+PMA。显然,Set Transformer能更好地拟合混合高斯分布。

任务四: 集合异常检测

文章使用CelebA进行异常检测。该数据集有202599张图片,共40个属性。

我们随机选定两个属性,对这两个属性随机选定七张图片,它们都含有该属性,再选定另一张图片,不含这两个属性,将这八张图片作为一个集合去找出与众不同的图片。

实验结果表明,SAB+PMA达到0.5941的AUROC和0.4386的AURP值,普遍高于其他模型2~3个点。

任务五: 点云分类

对该任务,文章使用ModelNet40作为数据集,该数据集包含许多三维物体数据,共40类,每个物体用一个点云表示,即一个有 个三维向量的集合,这里 。

结果表明,对于 ,ISAB(16)+PMA取得最好效果。

对 ,ISAB(16)+Pooling最好,ISAB(16)+PMA其次。

而对 ,简单的rFFp-max+Pooling效果与ISAB(16)+Pooling最好,ISAB(16)+PMA反而较差。

这是因为,当n很大的时候,用Pooling已经能很好地进行分类,不再需要进行点点之间关系的建模。

小结

本文介绍了一种用于处理“顺序无关”类型,即集合类型数据的模型——Set Transformer,在五个任务上达到了比较好的效果。

该模型完全使用了当前流行的self-attention,或者说是基本上采用了Transformer的Encoder端,再加上Pooling与降维等操作,简单有效,但是缺乏足够多的创新。

在实际应用上有很多此类“顺序无关”的任务,还有很多“半顺序无关”的任务,比较数学运算。未来的工作可以进一步思考如何更好地处理这类任务,提高运行速度与效果。