【译】理解 LSTM

一直以来感觉自己对 LSTM 的理解缺了点什么。这次看到一篇不错的博客 Understanding LSTM Networks,对 LSTM 的机制和公式有逐条的解释。因此决定翻译一下,帮助自己理解和记忆公式。主要靠意译,省略了一点无关紧要的内容,有条件可以自己看看原文。

背景

人类不是每分每秒都从头开始思考,而是基于所学所想进一步思考,有一定的持续性。RNN 可以通过前一个单元将消息传递给后一个单元,一定程度上模拟这个过程。

使用 RNN 面临着一个问题,有时候我们需要非常长的上下文,例如通过非常靠前的相关信息来预测下一个词。在实践中,RNN 难以处理这种“长期依赖”,原因和梯度爆炸(exploding gradient)/梯度消失(vanishing gradient)有关。不过,LSTM 可以解决这个问题,它会默认地长时间记忆信息。

上图展示了包含了单个层的标准 RNN 单元组成的链式结构。LSTM 有着同样的链式结构,但是重复的组件结构不同,它有着四个交互的层:

核心思想

LSTM 的核心是 cell state(单元状态),即贯穿图顶部的水平线。cell state 有点像传送带,沿着整个链向后,只有一些小的线形相互作用。因此,信息很容易不加改动地沿着它流动。

而向 cell state 删除或者添加的信息通过门(gates)结构精心控制。门结构能够选择性地让信息通过,它有一个 sigmoid 层和一个 pointwise 乘积操作组成:

sigmoid 层输出一个 0 到 1 之间的实数,来描述允许多少信息通过,0 代表不允许任何信息通过,而 1 代表让全部信息通过。LSTM 有三个类似的门机构,用来保护和控制 cell state。

一步一步看公式

LSTM 的第一步是决定从 cell state 中丢弃什么信息,这个决定通过称为“忘记门(forget gate)”的 sigmoid 层作出。输入是 $h_{t-1}$ 和 $x_t$,并为 cell state $C_{t-1}$ 中的每个数字输出一个 0 到 1 之间的值,1 表示完全保留,而 0 表示完全遗忘。

举一个基于前面所有单词来预测下一个单词的例子,在这个问题中,cell state 可能包括当前主语的性别,以便使用正确的代词。当我们看到一个新的主语时,我们想要忘记旧主语的性别。

第二步我们决定将在 cell state 中存储什么新信息。首先,一个称为输入门(input gate)的 sigmoid 层决定我们将要更新的值。然后,一个 tanh 层创建一个候选的 $\tilde C_i$,准备加到 cell state 中。在下一步,我们将这两个值合并,来为 cell state 做一个更新。

继续之前的例子,我们想要将新主语的性别加入到 cell state,以取代我们遗忘的内容。

接下来我们将旧的 cell state $C_{t-1}$ 更新为新的 cell state $C_t$。在例子中,我们在这一步中正式地丢弃关于旧主语性别的信息并添加新信息。

最后,我们决定要输出的内容。输出取决于我们的 cell state,不过要经过一层过滤。首先,我们还是通过一个 sigmoid 层来决定 cell state 中哪些部分要被输出。之后,我们将 state cell 通过 tanh 来将值缩放到 -1 与 1 之间,并用 sigmoid 门的输出做乘法,决定输出的部分。

对于语言模型示例,由于它只看到了一个主语,因此可能希望输出与谓词相关的信息,以防接下来会出现与谓词相关的信息。例如,它可以输出主语是单数还是复数,这样我们就知道如果接下来是动词,动词应该用什么形式。

LSTM 的变体

以上所述就是一个标准的 LSTM。但是,几乎所有的论文使用的 LSTM 都有或多或少的修改。虽然差异不大,但其中有些版本值得一提。

Gers & Schmidhuber (2000) 中的版本增加了窥视孔连接(peephole connections),使得门层可以观察 cell state。

上图在所有的门上都加了窥视孔。也有些论文只给一些门增加窥视孔。

另一种变体是将遗忘门和输入门耦合。我们不再单独决定要忘记什么或者应该添加什么新信息,而是共同做出这些决定,并且只在准备遗忘旧信息时添加新信息作为替代。

一个变化稍大的变体是 Cho, et al. (2014) 提出的门控循环单元(Gated Recurrent Unit, GRU)。它将遗忘和输入合并到一个单独的“更新门”,并将 cell state 和 hidden state 合并到一块。GRU 比标准的 LSTM 更简单,并且也得到广泛的应用。

除了几个最值得注意的变体,还有很多版本。Jozefowicz, et al. (2015) 测试了超过一万种 RNN 架构,发现一部分会在特定任务上比 LSTM 表现更好。

推荐资料