如何使用蒸馏器模型预测测试序列?

问题描述

我尝试使用带有蒸馏器模型的 Ktrain 来预测测试序列,我的代码如下所示:

trn,val,preproc = text.texts_from_array(x_train=x_train,y_train=y_train,x_test=x_test,y_test=y_test,class_names=train_b.target_names,preprocess_mode='distilbert',maxlen=350)
model = text.text_classifier('distilbert',train_data=trn,preproc=preproc,multilabel=True)
learner = ktrain.get_learner(model,val_data=val,batch_size=64)
y_pred = learner.model.predict(val,verbose = 0)

在 Ktrain 的 nbsvm、fasttext、bigru 等模型的其他实现中,它非常简单,因为 texts_from_array 函数返回一个 numpy 数组,但使用 distilbert 返回一个 TransformerDataset,因此无法使用 learner.model.predict 预测序列() 因为它会生成一个 python 索引异常。鉴于我有标签分类问题,我也不可能使用 validate() 方法生成混淆矩阵。我的问题是如何使用 distilbert 使用 Ktrain 对测试序列进行测试,我对此的需求来自这样一个事实,即我的度量函数是基于 sklearn.metric 库实现的,它需要 numpy 格式的测试和验证序列。

解决方法

您可以使用 tutorial 中所示的 Predictor 实例。

Predictor 只是使用 preproc 对象将原始文本转换为模型期望的格式并将其提供给模型。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...