在测试 bert 模型期间分配权重

问题描述

我有一个基本的概念性疑问。当我在句子上训练 bert 模型时说:

Train: "went to get loan from bank" 
Test :"received education loan from bank"

测试句子如何为每个标记分配权重,因为我没有通过准确的句子进行测试,并且稍微添加了一些像“教育”这样的词,这会稍微改变上下文

假设在我的模型中没有训练这样的上下文,在我进一步微调之前如何为我的 bert 中的每个标记分配权重

如果我对我的问题感到困惑,简单地说,我试图了解如果未经过训练的上下文发生轻微变化,则在测试期间如何分配权重。

解决方法

token 的向量表示(记住 token != word)存储在嵌入层中。当我们加载 'bert-base-uncased' 模型时,我们可以看到它“知道”了 30522 个标记,并且每个标记的向量表示由 768 个元素组成:

from transformers import BertModel
bert= BertModel.from_pretrained('bert-base-uncased')
print(bert.embeddings.word_embeddings)

输出:

Embedding(30522,768,padding_idx=0)

这个嵌入层不知道任何字符串,但知道 id。例如,id 101 的向量表示为:

print(bert.embeddings.word_embeddings.weight[101])

输出:

tensor([ 1.3630e-02,-2.6490e-02,-2.3503e-02,-7.7876e-03,8.5892e-03,-7.6645e-03,-9.8808e-03,6.0184e-03,4.6921e-03,-3.0984e-02,1.8883e-02,-6.0093e-03,-1.6652e-02,1.1684e-02,-3.6245e-02,8.3482e-03,-1.2112e-03,1.0322e-02,1.6692e-02,-3.0354e-02,...
         5.4162e-03,-3.0037e-02,8.6773e-03,-1.7942e-03,6.6826e-03,-1.1929e-02,-1.4076e-02,1.6709e-02,1.6860e-03,-3.3842e-03,8.6805e-03,7.1340e-03,1.5147e-02],grad_fn=<SelectBackward>)

BERT 无法处理“已知”ID 之外的所有内容。要回答您的问题,我们需要查看将字符串映射到 id 的组件。该组件称为标记器。有不同的标记化 approaches。 BERT 使用 WordPiece 分词器,它是一种子字算法。该算法将所有无法从其词汇表中创建的内容替换为词汇表的一部分未知标记(原始实现中的[UNK],id:100 ).

请查看以下小示例,其中从头开始训练 WordPiece 分词器以确认该行为:

from tokenizers import BertWordPieceTokenizer
path ='file_with_your_trainings_sentence.txt'
tokenizer = BertWordPieceTokenizer()
tokenizer.train(files=path,vocab_size=30000,special_tokens=['[UNK]','[SEP]','[PAD]','[CLS]','[MASK]'])
otrain = tokenizer.encode("went to get loan from bank")
otest =  tokenizer.encode("received education loan from bank")

print('Vocabulary size: {}'.format(tokenizer.get_vocab_size()))
print('Train tokens: {}'.format(otrain.tokens))
print('Test tokens: {}'.format(otest.tokens))

输出:

Vocabulary size: 27
Train tokens: ['w','##e','##n','##t','t','##o','g','l','##an','f','##r','##m','b','##k']
Test tokens: ['[UNK]','[UNK]','##k']

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...