想要获得 T5 变压器的顶级生成文本?

问题描述

我使用 simpletransformers 库训练了一个 T5 转换器。

这是获取预测的代码

pred_values = model.predict(input_values)

但是,它只返回top或greedy的预测,我怎样才能得到10个top结果?

解决方法

必需的参数是num_return_sequences,表示要生成的样本数。但是,如果要使用波束搜索算法,还应该为波束搜索设置一个数字。

model_args = T5Args()
model_args.num_beams = 5
model_args.num_return_sequences = 2

或者,您可以使用 top_ktop_p 来生成和选择顶级样本,在这些情况下,您必须将 do_sample 设置为 True。关于参数的更多信息参见[1]和[2],其中有详细说明。

model_args = T5Args()
model_args.do_sample = True
model_args.top_p = 0.9
model_args.num_return_sequences = 2

[1] https://simpletransformers.ai/docs/t5-model/

[2] https://huggingface.co/blog/how-to-generate