Hugging facepytorch变压器上的GPT2 RuntimeError:只能为标量输出隐式创建grad

问题描述

我正在尝试使用我的自定义数据集微调gpt2。我使用拥抱式变压器的文档创建了一个基本示例。我收到提到的错误。我知道这是什么意思:(基本上是在非标量张量上向后调用),但是由于我几乎只使用API​​调用,所以我不知道如何解决此问题。有什么建议吗?

from pathlib import Path
from absl import flags,app
import IPython
import torch
from transformers import GPT2LMHeadModel,Trainer,TrainingArguments
from data_reader import GetDataAsPython

# this is my custom data,but i get the same error for the basic case below
# data = GetDataAsPython('data.json')
# data = [data_point.GetText2Text() for data_point in data]
# print("Number of data samples is",len(data))

data = ["this is a trial text","this is another trial text"]

train_texts = data

from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {'pad_token': '<PAD>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
train_encodigs = tokenizer(train_texts,truncation=True,padding=True)


class BugFixDataset(torch.utils.data.Dataset):
    def __init__(self,encodings):
        self.encodings = encodings
    
    def __getitem__(self,index):
        item = {key: torch.tensor(val[index]) for key,val in self.encodings.items()}
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

train_dataset = BugFixDataset(train_encodigs)

training_args = TrainingArguments(
    output_dir='./results',num_train_epochs=3,per_device_train_batch_size=1,per_device_eval_batch_size=1,warmup_steps=500,weight_decay=0.01,logging_dir='./logs',logging_steps=10,)

model = GPT2LMHeadModel.from_pretrained('gpt2',return_dict=True)
model.resize_token_embeddings(len(tokenizer))

trainer = Trainer(
    model=model,args=training_args,train_dataset=train_dataset,)

trainer.train()

解决方法

我终于明白了。问题在于数据样本不包含目标输出。即使艰难的gpt也是自我监督的,也必须明确告知模型。

您必须添加以下行:

item['labels'] = torch.tensor(self.encodings['input_ids'][index])

访问数据集类的 getitem 函数,然后运行正常!

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...