正确的debug你的TensorFlow代码(不用这么痛苦)
会话加载并通过预先训练的模型进行预测。这就是瓶颈,我花了几周的时间来理解、调试和修复它。我想高度关注这个问题,并描述两种重新加载预训练模型(图和会话)并使用它的可能技术。
首先,当我们谈论加载模型时,我们真正的意思是什么?当然,为了做到这一点,我们需要事先训练并保存它。后者通常是通过 tf.train.Saver.save
完成的。因此,我们有3个二进制文件 .index
、 .meta
和 .data-00000-of-00001
,其中包含恢复会话和图所需的所有数据。
要加载以这种方式保存的模型,需要通过 tf.train.import_meta_graph()
(参数是 .meta
的文件)来恢复。按照前一段描述的步骤之后,所有变量(包括所谓的“隐藏”变量,稍后将讨论)都将被移植到当前图中。检索某个有自己名字的张量(记住,它可能不同于你初始化它时使用的张量,这取决于创建张量的范围和操作的结果)应该执行 graph.get_tensor_by_name()
。这是第一种方法。
第二种方法更显式,也更难实现(对于我一直在使用的模型的架构,我还没有成功用起来),它的主要思想是将图的边(张量)显式地保存到 .npy
或 .npz
文件中,然后将它们加载回图中(并根据创建它们的范围分配适当的名称)。这种方法的问题在于它有两个巨大的缺点:首先,当模型架构变得非常复杂时,它也变得很难控制和保存所有的权重矩阵。其次,有一种“隐藏的”张量,它是在没有显式初始化的情况下创建的。例如,当你创建 tf.nn.rnn_cell.BasicLSTMCell
时。它创建了所有需要的权值和偏差来实现LSTM cell。变量名也是自动分配的。
这种行为看起来还可以,但实际上,在很多情况下,并不是很好用。这种方法的主要问题是,当你查看图的集合时,看到一堆变量,你不知道它们的来源,你实际上不知道应该保存什么以及在哪里加载它们。坦率地说,很难将隐藏变量放到图中正确的位置并适当地操作它们。