图神经网络 GNN 的可解释性问题与解释方法最新进展

前言

本文是一篇稍有深度的教程,假设读者具备图神经网络的基础知识和一点计算化学的知识。如果你想为本文做好准备,我在下面列出了一些有用的文章。

译注:计算化学(computational chemistry),是理论化学的一个分支,主要目的是利用有效的数学近似以及电脑进程计算分子的性质,例如总能量、偶极矩、四极矩、振动频率、反应活性等,并用以解释一些具体的化学问题。计算化学这个名词有时也用来表示计算机科学与化学的交叉学科。

本文最初发表于 TowardsDataScience 博客,经原作者 Kacper Kubara 授权,InfoQ 中文站翻译并分享。

什么是图神经网络

图 1:图像卷积与图卷积。(
来源

卷积神经网络(Convolutional Neural Network,CNN)和图神经网络(Graph Neural Network,GNN)的主要区别是什么?

简单来说,就是 输入数据

你可能还记得,CNN 所需的输入是一个固定大小的向量或矩阵。然而,某些类型的数据自然是用图表示的,如分子、引用网络或社交媒体连接网络都可以用图数据来表示。在过去 GNN 还没有普及的时候,图数据往往需要经过这样的转换,才可以直接提供给 CNN 作为输入。例如,分子结构仍被转换为固定大小的指纹,每个位表示是否存在某种分子子结构 [1] 。这是一种将分子数据插入 CNN 的好方法,但它会不会导致信息损失呢?

图 2:通过从 NH 原子对获取不同的分子子结构,将分子映射到位指纹,所述分子子结构由不同的半径范围 (h=0,h=1,…) 定义。(
来源

GNN 利用图数据,省去了数据预处理步骤,充分利用了数据中包含的信息。现在,有许多不同的 GNN 架构,其背后的理论也很快变得复杂起来。然而,GNN 可以分为两类:空间方法(Spatial approach)和谱方法(Spectral approach)。空间方法更为直观,因为它将池化和卷积操作(如 CNN)重新定义到图域。谱方法从一个稍有不同的角度来处理这个问题,因为它侧重于处理被定义为使用傅里叶变换(Fourier transform)的图网络的信号。

如果你想了解更多信息,可以阅读 Thomas Kipf 写的 博文 ,这篇博文深入介绍了 GCN。另一篇值得阅读的有趣论文是《 MoleculeNet:分子机器学习的基准 》[2] ,它很好地介绍了 GNN,并描述了流行的 GNN 架构。

当前 GNN 可解释性问题

在撰写本文时,对 GNN 解释方法有贡献的论文屈指可数。尽管如此,这仍然是一个非常重要的主题,最近变得越来越流行。

GNN 的流行要比标准神经网络晚得多。虽然这一领域有很多有趣的研究,但还不是很成熟。GNN 的库和工具目前仍然处于“实验阶段”,我们现在真正需要的是,让更多的人使用它来发现 Bug 和错误,并转向生产就绪的模型。

创建生产就绪的模型的一种方法是更好地理解它们所做的预测。可以使用不同的解释方法来完成。我们已经看到了许多应用于 CNN 的有趣的可解释性方法,例如梯度归因(Gradient attribution)、显著性映射(Saliency maps)或类激活映射(Class activation mapping)等。那么,为什么不将它们重新用于 GNN 呢?

图 3:神经网络和其他机器学习模型常用的解释方法及其特性。(
来源

事实上,这就是目前正在发生的事情。最初用于 CNN 的解释性方法正在重新设计并应用于 GNN。虽然无需重新发明轮子,但我们仍然必须通过重新定义数学操作来调整这些方法,使它们适用于图数据。这里的主要问题是,关于这一主题的研究还相当新,第一篇论文发表于 2019 年。然而,随着时间的推移,它将会变得越来越流行,目前几乎还没有什么有趣的可解释性方法可用于 GNN 模型。

在本文中,我们将研究新的图解释方法,并了解如何将它们应用于 GNN 模型。

第一次尝试:可视化节点激活

图 4:
左图 :Duvenaud 等人提出的三层叠加神经网络指纹模型的计算图可视化。此处,节点表示原子,边表示原子键。
右图 :更详细的图,包括每个操作中使用的键信息。(
来源

2015 年,Duvenaud 等人 [3] 发表了关于 GNN 解释技术的开创性论文。该论文主要贡献是提出了一种新的 神经图指纹模型 ,同时也为这种结构创造了一种解释方法。该模型背后的主要思想是直接从图数据本身创建可区分的指纹。为实现这一点,作者不得不重新定义图的池化和平滑操作。这些操作然后被用来创建一个单层。

如图 4 所示,这些层叠加 n 次以产生向量输出。层深度也与相邻节点的半径相对应,从这些节点收集和汇集节点特征(在本例中为和函数)。这是因为,对于每一层,池化操作都从邻近节点收集信息,如图 1 所示。对于更深的层,池化操作的传播会延伸到更远的邻近节点。与正常的指纹相反,这种方法是可区分的,这使得反向传播可以以类似 CNN 的方式更新其权值。

除了 GNN 模型之外,他们还创建了一个简单的方法,可以显示节点激活及其邻近的节点。遗憾的是,这种方法在论文中没有很好地解释,因此有必要查看它们的 代码实现 以理解其底层机制。然而,他们在分子数据上运行它来预测溶解度,它突出了部分对溶解度预测有交稿预测能力的分子。

图 5:论文作者在溶解度数据集上测试了他们的模型,以突出部分影响溶解度的分子。通过他们的解释方法,他们能够确定使分子更容易溶解的分子子结构(如 R-OH 基团)和那些使其不易溶解的分子子结构(如非极性重复环结构)。(
来源

到底是怎么工作的呢?为计算节点激活量,我们需要进行以下计算。对于每个分子,让我们通过每一层向前传递数据,类似于在一个典型 CNN 网络中对图像的处理。然后,我们用 softmax() 函数提取每个指纹位各层的贡献。然后,我们就能够将一个节点(原子)与它周围对特定指纹位贡献最大的邻居(取决于层深)关联起来。

这种方法相对简单,但没有留下很好的文档记录。但该论文所做的初步工作是相当有前途的,研究人员随后进行了更为详细的尝试,将 CNN 的解释方法转化为图域。

如果你想进一步了解它,可以参阅他们的 论文代码仓库 或我在这个 Github 问题 底部对方法的详细解释。

卷积神经网络的重用方法

敏感度分析 (Sensitive Analysis)、 类激活映射激发反向传播 (Excitation Backpropagation)都是已经成功应用于 CNN 的解释技术的示例。目前针对可解释 GNN 的研究试图将这种方法转换成图域。这一领域的大部分工作都是在这些论文 [4][5] 中完成的。

本文不再着重于这些方法的数学解释,我将为你提供这些方法的直观解释,并简要讨论这些方法的结果。

归纳 CNN 的解释方法

为了重用 CNN 的解释方法,让我们将 CNN 输入的数据(即图像)视为一个格子形状的图。图 6 说明了这个想法。

图 6:图像可以概括为格子形状的图。红十字架只显示了图像的一小部分,可以表示为一个图。每个像素可以被认为是节点
V ,节点 V 具有 3 个特征(RGB 值),并且连接到具有边
E 的相邻像素。

如果我们牢记图像的图推广,我们可以说 CNN 解释方法不怎么关注边缘(像素之间的连接),而是关注节点(像素值)。这里的问题是,图数据中的边缘包含了很多有用的信息 [4][5] ,而不是以网格状的顺序。

哪些方法已经转换为图域?

到目前为止,文献已经讨论过以下解释方法:

  • 敏感度分析(Sensitivity Analysis)[4]
  • 引导反向传播 (Guided Backpropagation)[4]
  • 层次相关传播 (Layer-wise Relevance Propagation)[4]
  • 基于梯度的热图 (Gradient-based heatmaps)[5]
  • 类激活映射(Class Activation Maps)[5]
  • 梯度加权类激活映射(Gradient-weighted Class Activation Mapping (Grad-CAM))[5]
  • 激发反向传播 (Excitation Backpropagation)[5]

请注意,其中两篇论文 [4][5] 的作者并未提供这些方法的开源实现,因此目前尚无法使用它们。

模型无关方法:GNNExplainer

本文的代码仓库可以在 这里 找到。

不用担心,实际上有一个 GNN 解释工具可以供你使用!

GNNExplainer 是一种模型无关的开源 GNN 解释方法。它也是一个相当通用的,因为它可以应用于节点分类、图分类和边预测。这是创建图特定解释方法的首次尝试,由斯坦福大学的研究人员提出 [6] 。

它是如何工作的?

作者声称,重用之前应用于 CNN 的解释方法是一种糟糕的方法,因为它们未能纳入关系信息,而关系信息是图数据的本质。此外,基于梯度的方法对于离散输入并不是特别有效,而 GNN 通常就是这种情况(例如, 邻接矩阵 是二进制矩阵)。

为了克服这些问题,他们创建了一种与模型无关的方法,该方法可以找到输入数据的子图,这些子图以最重要的方式影响 GNN 的预测。说得更具体一些,子图的选择是为了最大化与模型预测的互信息(Mutual information)。下图显示了 GNNExplainer 如何处理由体育活动组成的图数据的一个示例。

图 7:图数据说明了不同的体育活动。目的是解释下一个体育活动中心节点的预测。对于红色图形,Vi 被预测为篮球。GNNExplainer 能够选择一个子图,该子图可以通过预测使互信息最大化。在这种情况下,它从图中提取了一个有用的信息:当某人过去玩球类运动时,选择篮球作为下一项体育活动的可能性很高。(
来源

作者提出的一个非常重要的假设是 GNN 模型的公式。模型的架构或多或少是任意的,但它需要实现 3 个关键计算:

  • 相邻节点间的神经信息计算。
  • 来自节点邻域的消息聚合。
  • 聚合消息的非线性变换及节点表示。

这些所需的计算在某种程度上是有限制的,但大多数现代 GNN 架构无论如何,都是基于消息传递架构的。[6]

论文给出了 GNNExplainer 的数学定义和优化框架的描述。然而,我会省去一些细节,向你展示一些可以通过 GNNExplainer 获得的有趣结果。

GNNExplainer 在实践中的应用

为比较结果,他们使用了 3 种不同的解释方法: GNNExplainer基于梯度的方法 (Gradient-based method,GRAD)和 图注意力模型 (Graph attention model,GAT)。在本文前面部分提到了 GRAD,但 GAT 模型需要一些解释。这是另一个 GNN 架构,它学习边的注意力权重,这将帮助我们确定图网络中哪些边对节点分类实际上很重要。这可以作为另一种解释技术使用,但它只适用于这个特定的模型,而且它不解释节点特性,这与 GNNExplainer 相反。

首先让我们看一下应用于合成数据集的解释方法的性能。图 8 显示了在两个不同的数据集上运行的实验。

BA-Shapes 数据集基于 Barabasi-Albert (BA) 图,它是一种图网络,我们可以通过改变它的一些参数来自由调整大小。在这个基础图上,我们将附加一些小的房屋结构图(motifs),这些图在图 8 中显示为真相(Ground Truth)。此房屋结构的节点有 3 个不同的标签:顶部、中部和底部。这些节点标签只是表示节点在房子中的位置。因此,对于一个类房屋的节点,我们有 1 个顶部节点、2 个中部节点和 2 个底部节点。还有一个额外的标签,表示该节点不属于类房屋图结构。总体而言,我们有 4 个标签和一个 BA 图,其中包含 300 个节点和 80 个类房屋结构,这些结构被添加到 BA 图的随机节点中。我们还通过添加 0.1N 个随机边来增加一些随机性。

BA-Community是两个 BA-Shapes 数据集的联合,因此总共有 8 个不同的标签(每个 BA-Shapes 数据集有 4 个标签)和两倍的节点。

让我们来看一看结果。

图 8:目的是为红色节点的预测提供解释。解释方法的设置是为了找到一个 5 节点的子图,该子图可以最准确地解释结果(标记为绿色)。(
来源

结果似乎很有希望。GNNExplainer 似乎以最准确的方式解释了结果,因为所选择的子图与真相相同。Grad 和 Att 方法未能提供类似的解释。

如图 9 所示,该实验也是在真实数据集上进行的。这一次的任务是对整个图网络进行分类,而不是对单个节点进行分类。

Mutag是一个由分子组成的数据集,这些分子是根据对某类细菌的诱变效应进行分类的。数据集有很多不同的标签,它包含 4337 个分子图。

Reddit-Binary是一个代表 Reddit 中在线讨论主题的数据集。在这个图网络中,用户用节点表示,边表示对另一个用户评论的响应。根据用户交互的类型,有两种可能的标签,它可以是问答互动,也可以是在线讨论互动。总的来说,它包含 2000 个图。

图 9:GNNExplainer 运行图分类任务。对于 Mutag 数据集,颜色表示节点特征(原子)。对于 Reddit-Binary,任务是对图是否为在线讨论进行分类。(
来源

对于 Mutag 数据集,GNNExplainer 正确识别已知具有致突变性的化学基团(如 NO2、NH2)。GNNExplainer 还解释了归类为 Reddit-Binary 数据集在线讨论分类图。这种类型的交互通常可以用树状的模式表示(查看真相)。

参考文献

作者介绍:

Kacper Kubara,南安普顿大学(University of Southampton)即将毕业的学生,电子工程专业,同时也是 IT 创新中心的研究助理。在业余时间喜欢摆弄数据或调试深度学习模型。喜欢徒步旅行。

原文链接:

https://towardsdatascience.com/towards-explainable-graph-neural-networks-45f5e3912dd0