为什么torch.nn.MultiheadAttention中的W_q矩阵是平方的

问题描述

我正在尝试在我的网络中实现nn.MultiheadAttention。根据{{​​3}},

embed_dim –模型的总尺寸。

但是,根据docs

embed_dim必须可被num_heads整除

self.q_proj_weight =参数(torch.Tensor(embed_dim,embed_dim))

如果我理解正确,这意味着每个头部仅每个查询的一部分功能,因为矩阵是二次的。是实现的错误还是我的理解是错误的?

解决方法

每个头都使用投影查询向量的不同部分。您可以想象一下,好像查询被分解为num_heads个向量一样,这些向量独立地用于计算缩放的点积注意力。因此,每个头都对查询中的功能(以及键和值)的不同线性组合进行操作。线性投影是使用self.q_proj_weight矩阵完成的,并将投影的查询传递到F.multi_head_attention_forward函数。

F.multi_head_attention_forward中,它是通过对查询向量进行重塑和转置来实现的,因此可以efficiently by matrix multiplication来计算各个头部的独立注意力。

注意头的大小是PyTorch的设计决定。从理论上讲,您可以使用不同的头大小,因此投影矩阵的形状为embedding_dim×num_heads * head_dims。转换器的某些实现(例如用于机器翻译的基于C ++的MarianHuggingface's Transformers)允许这样做。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...