在Keras Transformer官方示例中解释关注

问题描述

我已经实现了一个模型,如(使用Transformer进行文本分类https://keras.io/examples/nlp/text_classification_with_transformer/

中所述

我想访问特定示例的注意力值。

我知道注意力是在这一点上计算出来的:

class TransformerBlock(layers.Layer):
    [...]

def call(self,inputs,training):
    attn_output = self.att(inputs)
    attn_output = self.dropout1(attn_output,training=training)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    ffn_output = self.dropout2(ffn_output,training=training)
    return self.layernorm2(out1 + ffn_output)

[...]

embed_dim = 32  # Embedding size for each token

num_heads = 2  # Number of attention heads
ff_dim = 32  # Hidden layer size in Feed forward network inside transformer

inputs = layers.Input(shape=(maxlen,))
embedding_layer = TokenAndPositionEmbedding(maxlen,vocab_size,embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim,num_heads,ff_dim)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20,activation="relu")(x)
x = layers.Dropout(0.1)(x)
outputs = layers.Dense(2,activation="softmax")(x)

如果我这样做:

A=(model.layers[2].att(model.layers[1](model.layers[0]((X_train[0,:])))))

我可以检索大小为maxlen x num_heads的矩阵。

我应该如何解释这些系数?

解决方法

编辑:如果您想使用注意力来解释分类输出

据我所知,不可能完全解释Transformer在分类中的作用。 Transformer所做的只是看每个输入如何相互关联,而不是每个单词对标签的贡献。如果希望找到可解释的模型,请尝试着眼LSTM进行分类。

好的,所以我读了您的代码,并在您致电model.layers[1]时发现了一些错误。首先,您需要了解该模型正在批量处理数据。因此,您的输入应为(batch_size,seq_len)格式。但是,您的输入形状似乎掉落了第一个维度(即批处理),这使您的模型认为您给模型提供了200个句子,序列长度为1。因此,从图中可以看出,输出形状看起来很奇怪。 / p>

Code Demo

正确的方法是在第一个维度上添加一个额外的维度(使用tf.expand_dims)。

现在,用于解释结果。您需要知道,Transformer块会进行自我注意(它会找到句子中每个单词与其他单词的分数)并对其进行加权求和。因此,输出将与嵌入层相同,并且您将无法解释(因为它是网络生成的隐藏矢量)。

但是,您可以使用以下代码查看每个头部的注意力得分:

import seaborn as sns
import matplotlib.pyplot as plt

head_num=1
inp = tf.expand_dims(x_train[0,:],axis=0)
emb = model.layers[1](model.layers[0]((inp)))

self_attn = model.layers[2].att
# compute Q,K,V
query = self_attn.query_dense(emb)
key = self_attn.key_dense(emb)
value = self_attn.value_dense(emb)
# separate heads
query = self_attn.separate_heads(query,1) # batch_size = 1
key = self_attn.separate_heads(key,1) # batch_size = 1
value = self_attn.separate_heads(value,1) # batch_size = 1
# compute attention scores (QK^T)
attention,weights = self_attn.attention(query,key,value)

idx_word = {v: k for k,v in keras.datasets.imdb.get_word_index().items()}
plt.figure(figsize=(30,30))
sns.heatmap(
    weights.numpy()[0][head_num],xticklabels=[idx_word[idx] for idx in inp[0].numpy()],yticklabels=[idx_word[idx] for idx in inp[0].numpy()]
)

这是示例输出: AttentionScore