具有K折交叉验证的BERT文本分类返回“目标3超出范围”错误

问题描述

我正在努力训练一个句子数据集,这些句子将被分为三个类别之一。我正在尝试根据这篇非常有用的中篇文章https://medium.com/swlh/k-fold-as-cross-validation-with-a-bert-text-classification-example-4017f76a863a)中的代码进行此操作。我的代码几乎相同:

n=5
kf = KFold(n_splits=n,shuffle=True)

results = []

for train_index,val_index in kf.split(data):
  # splitting Dataframe (dataset not included)
    train_df = data.iloc[train_index]
    val_df = data.iloc[val_index]
    # Defining Model
    model = ClassificationModel('bert','bert-base-uncased',use_cuda=False) 
  # train the model
    model.train_model(train_df)
  # validate the model 
    result,model_outputs,wrong_predictions = model.eval_model(val_df,acc=accuracy_score)
    print(result['acc'])
  # append model score
    results.append(result['acc'])

在“ model.train_model(train_df)”行,我不断收到索引错误"IndexError: Target 3 is out of bounds."

有人可以帮助我理解为什么会这样吗?

解决方法

您没有在模型构造函数中指定目标类的数量,因此它默认为默认配置文件中的任何内容(请参见the code where it happens)。当目标类别中出现数字3时,它突然超出了预测范围。

在构造函数中使用num_labels参数(请参见the documentation)。