为什么Bahdanau中的编码器隐藏状态形状与编码器输出形状不同

问题描述

此问题与此处显示的神经机器翻译有关: Neural Machine Translation

这里:

批量大小 = 64

输入长度(示例输入语句中的单词数,也称为不同的时间步长)= 16

RNN单元数(这也是隐藏状态向量的长度或每个时间步的隐藏状态向量的维数)= 1024

这被解释为:

在每批中(总共64个),对于每个输入词(总计16个),每个时间步都有一个1024维矢量。这个1024维向量代表编码过程中特定时间步的输入单词。 这个1024维向量称为每个单词的隐藏状态。

我的问题是:

为什么(64,1024)的隐藏状态维度与(64,16,1024)的编码器输出维度不同?两者应该不一样,因为对于每批,在输入句子中有16个单词,对于输入句子中的每个单词,我们具有1024维隐藏状态矢量。因此,在编码步骤结束时,我们获得了形状为(64,16,1024)的累积隐藏状态向量,这也是编码器的输出。两者相同。

进一步提供尺寸为(64,1024)的编码器隐藏输出作为解码器的第一个隐藏状态输入。

一个相关问题:

如果输入长度为16个字,而不是使用16个单位,那么在编码器中使用1024个单位的原因是什么?

解决方法

“为什么隐藏状态维数为(64,1024)”。

在您的RNN模型中,每个单词的输出是形状的矢量(GRU单元数= 1024),如果批处理为64,则我们给模型中的64个单词批处理中每个示例的一个单词,这为我们的每个输入提供了一个形状为(64,1024)的输出向量。

现在,要消耗所有序列,我们将下一个单词输入最多16个,以获取RNN层的正常3d输出(64、16、1024)。

对于第二个问题,在RNN模型中,GRU单元的数量(例如1024)不取决于序列的长度,我们在RNN层中添加更多的单元以捕获序列的复杂性

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...