为什么嵌入维度必须可以被 MultiheadAttention 中的头数整除?

问题描述

我正在学习 Transformer。这是 MultiheadAttention 的 pytorch 文档。在他们的 implementation 中,我看到有一个约束:

 assert self.head_dim * num_heads == self.embed_dim,"embed_dim must be divisible by num_heads"

为什么需要约束:embed_dim must be divisible by num_heads? 如果我们回到等式

MultiHead(Q,K,V)=Concat(head1​,…,headh​)WOwhereheadi​=Attention(QWiQ​,KWiK​,VWiV​)

假设: QKVn x emded_dim 矩阵;所有的权重矩阵 W 都是 emded_dim x head_dim,

那么,concat [head_i,...,head_h] 将是一个 n x (num_heads*head_dim) 矩阵;

W^O 大小为 (num_heads*head_dim) x embed_dim

[head_i,head_h] * W^O 将成为 n x embed_dim 输出

我不知道为什么我们需要 embed_dim must be divisible by num_heads

假设我们有 num_heads=10000,结果是一样的,因为矩阵-矩阵乘积会吸收这些信息。

解决方法

当您有一个 seq_len x emb_dim 序列(即 20 x 8)并且您想使用 num_heads=2 时,该序列将沿 emb_dim 维度拆分。因此,您会得到两个 20 x 4 序列。您希望每个头部都具有相同的形状,如果 emb_dim 不能被 num_heads 整除,这将不起作用。以序列 20 x 9num_heads=2 为例。然后你会得到 20 x 420 x 5,它们不是同一个维度。