谷歌大脑开源Trax代码库,你的深度学习进阶路径

感觉深度学习建模只不过调库与堆叠层级?你需要谷歌大脑维护的这条路径 Trax,从头实现深度学习模型。

从最开始介绍卷积、循环神经网络原理,到后来展示各种最前沿的算法与论文,机器之心与读者共同探索着机器学习。我们会发现,现在读者对那些著名的深度学习模型已经非常熟悉了,经常也会推导或复现它们。

而对于最前沿的一些实现,包括 Transformer 或其它强化学习,我们通常都需要看原作者开源的代码,或者阅读大厂的复现。出于速度等方面的考虑,这些实现通常会显得比较「隐晦」,理解起来不是那么直白。这个时候,你就需要谷歌大脑维护的 Trax,它是 ML 开发者进阶高级 DL 模型的路径。

Trax 是一个开源项目,它的目的在于帮助我们挖掘并理解高一阶的深度学习模型。谷歌大脑表示,该项目希望 Trax 代码做到非常整洁与直观,并同时令 Reformer 这类高阶深度学习达到最好的效果。

项目地址:https://github.com/google/trax

什么是 Trax

简单来说,Trax 就是一个代码库,它有点类似于一个极简的深度学习框架。只不过 Trax 关注什么样的代码能让读者更好地理解模型,而不只是关注加速与优化。

Trax 代码及其组织方式希望让我们从头理解深度学习,而不只是简单地调库。整个项目从最基础的数学部分开始,然后向上依次构建层级运算、模型运算,以及有监督与强化学习训练任务。

因为是进阶深度学习高级建模,Trax 还囊括了最前沿的研究结果,例如在 ICLR 2020 上做演讲报告的 Reformer。如下展示的是该项目的代码文件结构:

如果要从头理解并进阶深度学习,那么 Trax 代码主要可以分为以下 6 部分:

  • math/:最基本的数学运算,以及通过 JAX 和TensorFlow加速运算性能的方法,尤其是在 GPU/TPU 上;

  • layers/:搭建神经网络的所有层级构建块;

  • models/:包含所有基础模型,例如 MLP、ResNet 和 Transformer,还包含一些前沿 DL 模型;

  • optimizers/:包含深度学习所需要的最优化器;

  • supervised/:包含执行监督学习的各种有用模块,以及整体的训练工具;

  • rl/:包含谷歌大脑在强化学习上的一些研究工作;

每一个文件夹下都有对应的实现,例如在 Layers 中,所有神经网络层级都继承自最基础的 Layer 类,实现这个类花了 700 行代码。而后新的层级在继承它后只要实现以下两个方法就行:

通过 900 行代码(包括 Err 处理),基础的 Layer 类能完成其它所有处理,包括初始化与调用等。

使用 Trax

我们可以将 Trax 作为 Python 脚本库或者JupyterNotebook 的基础,也可以作为命令行工具执行。Trax 包含很多深度学习模型,并且绑定了大量深度学习数据集,包括 Tensor2Tensor 和TensorFlow采用的数据集。同时,如果我们在 CPU、GPU 或 TPU 上运行这些模型,也不需要改变。

如果读者想要了解如何快速将 Trax 作为一个库来使用,那么可以看看如下 Colab 上的入门示例。它介绍了如何生成样本数据,并连接到 Trax 中的 Transformer 模型。在训练或推断时,我们可以选择 GPU,也可以选择 8 核心的免费 TPU。

入门简介地址:

https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb

如果要在命令行中使用 Trax,那么带上参数就可以了,例如模型类型、学习率等超参。谷歌大脑团队建议我们可以看看 gin-config,例如训练一个最简单的 MNIST 分类模型,可以看看 mlp_mnist.gin,然后如下运行就行了:

python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin

如果你觉得上面的训练太简单,也可以在 ImageNet64 上训练一下 Reformer:

python -m trax.trainer --config_file=$PWD/trax/configs/reformer_imagenet64.gin

最后,这个项目最重要的还是它的实现代码,我们并不是因为可以直接运行而使用它。相反,我们是因为它的代码直观简洁,能帮助我们一步步更深刻地理解模型而使用它。