问题描述
我正在学习 Transformer。这是 MultiheadAttention 的 pytorch 文档。在他们的 implementation 中,我看到有一个约束:
assert self.head_dim * num_heads == self.embed_dim,"embed_dim must be divisible by num_heads"
为什么需要约束:embed_dim must be divisible by num_heads?
如果我们回到等式
假设:
Q
、K
、V
是 n x emded_dim
矩阵;所有的权重矩阵 W
都是 emded_dim x head_dim
,
那么,concat [head_i,...,head_h]
将是一个 n x (num_heads*head_dim)
矩阵;
W^O
大小为 (num_heads*head_dim) x embed_dim
[head_i,head_h] * W^O
将成为 n x embed_dim
输出
我不知道为什么我们需要 embed_dim must be divisible by num_heads
。
假设我们有 num_heads=10000
,结果是一样的,因为矩阵-矩阵乘积会吸收这些信息。
解决方法
当您有一个 seq_len x emb_dim
序列(即 20 x 8
)并且您想使用 num_heads=2
时,该序列将沿 emb_dim
维度拆分。因此,您会得到两个 20 x 4
序列。您希望每个头部都具有相同的形状,如果 emb_dim
不能被 num_heads
整除,这将不起作用。以序列 20 x 9
和 num_heads=2
为例。然后你会得到 20 x 4
和 20 x 5
,它们不是同一个维度。