可解释性论文阅读笔记1-Tree Regularization
作者: HardenHuang
学校: 清华大学
研究方向: 自然语言处理
知乎专栏: 模型可解释性论文专栏
AAAI2018的一篇关于深度学习模型可解释性的文章,文章的主要亮点是引入tree regularization在训练深度学习网络的同时给出一个具有可解释性能力的决策树
Beyond Sparsity: Tree Regularization of Deep Models for Interpretability arxiv.org
link: https://arxiv.org/pdf/1711.06178.pdf
1. 模型可解释性的重要性
(1)深度模型的进一步应用需要模型可解释性,例如金融、医疗等领域
(2)仿生性模型更容易被人理解和应用,例如决策树模型
2. Model Interpretion Introduction
做模型可解释性有两套思路:一是为已经训练好的模型寻找可解释性;二是训练出更具有解释性的模型
(1)为已经训练好的模型寻找可解释性: 寻找神经网络模型的决策树表示 [1] ,对输入和输入的梯度进行敏感性分析 [2] ,寻找模型的编程化表示 [3] ,寻找模型的规则集合表示 [4]
(2) 训练出更具解释性的模型:惩罚更不相关的特征获得稀疏特征 [5] ,从文本输入中找到高亮部分 [6]
3. Related work
利用了各类正则化:L1正则化方法 [7] ,binary network [8] ,Edge and node regularization [9]
4. 方法的具体描述
key idea :训练深度学习模型同时,获取准确率高与复杂度小的决策树,决策树的复杂度作为正则项;
决策树的复杂度 :利用APL (average path length),训练集中样例做出决策平均经过的节点数,APL即为Tree Regularization。
问题 :L1正则化 L2正则化等正则化项可以由一个函数 计算出来, 为深度模型 的权重,Tree Regularization却不能简单的由 表示
解决方法 :引入一个替代网络 (多层感知机)去近似表示决策树 , 的输入为 训练过程中的权重 , 输出APL的预测值 , 为模型的参数,标签为的APL输出 ,
与 的权重产生关联,因此可以用梯度下降方法进行优化
具体训练过程 :
step1:训练深度学习模型,输入特征 及真实标签 ,输出标签预测值 ,损失函数如下,其中 来自于的输出
step2: 训练决策树模型,输入为特征 和标签预测值 ,输出为 ,训练过程与一般的决策树分裂过程相同
step3:训练替代网络, 输入为各步训练过程中得到的权重,输出为APL的预测值 ,标签来自于决策树模型的,损失函数如下,

如此重复训练,可以同时得到深度学习模型与具有可解释性的模型 ,如下是一个决策树的示例

5. 实验及结果
语音识别任务

脓毒症重症监护(Sepsis Critical Care)

艾滋病治疗结果(HIV Therapy Outcome)

6. 代码解读
tree-regularization-public github.com
link: https://github.com/dtak/tree-regularization-public
训练过程:可以看到是 的依次训练过程
def train(self, X_train, F_train, y_train, iters_retrain=25, num_iters=1000, batch_size=32, lr=1e-3, param_scale=0.01, log_every=10): npr.seed(42) num_retrains = num_iters // iters_retrain for i in xrange(num_retrains): self.gru.objective = self.objective # carry over weights from last training init_weights = self.gru.weights if i > 0 else None print('training deep net... [%d/%d], learning rate: %.4f' % (i + 1, num_retrains, lr)) self.gru.train(X_train, F_train, y_train, num_iters=iters_retrain, batch_size=batch_size, lr=lr, param_scale=param_scale, log_every=log_every, init_weights=init_weights) print('building surrogate dataset...') W_train = deepcopy(self.gru.saved_weights.T) APL_train = self.average_path_length_batch(W_train, X_train, F_train, y_train) print('training surrogate net... [%d/%d]' % (i + 1, num_retrains)) self.mlp.train(W_train[:self.gru.num_weights, :], APL_train, num_iters=3000, lr=1e-3, param_scale=0.1, log_every=250) self.pred_fun = self.gru.pred_fun self.weights = self.gru.weights # save final decision tree self.tree = self.gru.fit_tree(self.weights, X_train, F_train, y_train) return self.weights
参考
-
^train decision trees for pretrained neural network (Craven and Shavlik (1996))
-
^Adler, P.; Falk, C.; Friedler, S. A.; Rybeck, G.; Scheidegger, C.; Smith, B.; and Venkatasubramanian, S. 2016. Auditing black-box models for indirect influence. In ICDM
-
^Ribeiro, M. T.; Singh, S.; and Guestrin, C. 2016. Why should I trust you?: Explaining the predictions of any classifier. In KDD
-
^Lakkaraju, H.; Bach, S. H.; and Leskovec, J. 2016. Interpretable decision sets: A joint framework for description and prediction. In KDD.
-
^Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI
-
^Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI
-
^Zhang, Y.; Lee, J. D.; and Jordan, M. I. 2016. l1-regularized neural networks are improperly learnable in polynomial time. In ICML
-
^Tang, W.; Hua, G.; and Wang, L. 2017. How to train a compact binary neural network with high accuracy? In AAAI.
-
^Ochiai, T.; Matsuda, S.; Watanabe, H.; and Katagiri, S. 2017. Automatic node selection for deep neural networks using group lasso regularization. In ICASSP
本文由作者授权AINLP原创发布于公众号平台,点击’阅读原文’直达原文链接,欢迎投稿,AI、NLP均可。
原文链接:
https://zhuanlan.zhihu.com/p/99384386
推荐阅读
AINLP-DBC GPU 云服务器租用平台建立,价格足够便宜
我们建了一个免费的知识星球:AINLP芝麻街,欢迎来玩,期待一个高质量的NLP问答社区
关于AINLP
AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLP君微信(id:AINLP2),备注工作/研究方向+加群目的。