问题描述
我使用 simpletransformers
库训练了一个 T5 转换器。
pred_values = model.predict(input_values)
但是,它只返回top或greedy的预测,我怎样才能得到10个top结果?
解决方法
必需的参数是num_return_sequences
,表示要生成的样本数。但是,如果要使用波束搜索算法,还应该为波束搜索设置一个数字。
model_args = T5Args()
model_args.num_beams = 5
model_args.num_return_sequences = 2
或者,您可以使用 top_k
或 top_p
来生成和选择顶级样本,在这些情况下,您必须将 do_sample
设置为 True
。关于参数的更多信息参见[1]和[2],其中有详细说明。
model_args = T5Args()
model_args.do_sample = True
model_args.top_p = 0.9
model_args.num_return_sequences = 2