Problem description
This is the decoding part of tensor2tensor. You can also find it at https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py, from line 1155 to line 1298:

def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                init_cache_fn=_init_transformer_cache,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                sos_id=0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                scope_prefix="body/",
                cache=None):
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  cache = init_cache_fn(
      cache=cache,
      hparams=hparams,
      batch_size=batch_size,
      attention_init_length=0,
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      scope_prefix=scope_prefix)
  def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
    """One step of greedy decoding."""
    logits, cache = symbols_to_logits_fn(next_id, i, cache)
    log_probs = common_layers.log_prob_from_logits(logits)
    temperature = getattr(hparams, "sampling_temp", 0.0)
    keep_top = getattr(hparams, "sampling_keep_top_k", -1)
    if hparams.sampling_method == "argmax":
      temperature = 0.0
    next_id = common_layers.sample_with_temperature(
        logits, temperature, keep_top)
    hit_eos |= tf.equal(next_id, eos_id)

    log_prob_indices = tf.stack(
        [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
    log_prob += tf.gather_nd(log_probs, log_prob_indices)

    next_id = tf.expand_dims(next_id, axis=1)
    decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
    return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
  def is_not_finished(i, hit_eos, *_):
    finished = i >= decode_length
    if not force_decode_length:
      finished |= tf.reduce_all(hit_eos)
    return tf.logical_not(finished)
  decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
  hit_eos = tf.fill([batch_size], False)
  next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
  initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
  _, _, _, decoded_ids, cache, log_prob = tf.while_loop(
      is_not_finished,
      inner_loop, [
          tf.constant(0), hit_eos, next_id, decoded_ids, cache,
          initial_log_prob
      ],
      shape_invariants=[
          tf.TensorShape([]),
          tf.TensorShape([None]),
          tf.TensorShape([None, None]),
          tf.TensorShape([None, None]),
          nest.map_structure(beam_search.get_state_shape_invariants, cache),
          tf.TensorShape([None]),
      ])
  scores = log_prob
  return {"outputs": decoded_ids, "scores": scores, "cache": cache}
Is there any way to avoid generating repeated words? For example, the model may produce "我吃吃吃吃吃苹果" ("I eat eat eat eat eat an apple").
I think that after obtaining the logits via

logits, cache = symbols_to_logits_fn(next_id, i, cache)

there may be a way to reduce the logit values of previously generated ids, but I don't know how to do it.
Any ideas?
Thanks
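One common approach is exactly what the question suggests: a repetition penalty applied to the logits right after `symbols_to_logits_fn` returns and before `sample_with_temperature`. Below is a minimal NumPy sketch of the idea (the function name `penalize_repeats` and the default penalty value are my own, not part of tensor2tensor; in the TF1 graph code above you would express the same computation with tensor ops, e.g. a one-hot mask over `decoded_ids`):

```python
import numpy as np

def penalize_repeats(logits, decoded_ids, penalty=1.2):
    """Down-weight the logits of ids that were already generated.

    Positive logits are divided by `penalty`, negative ones multiplied,
    following the repetition penalty proposed with the CTRL model.
    logits: [batch, vocab]; decoded_ids: [batch, length].
    """
    logits = logits.copy()
    for b in range(logits.shape[0]):
        for token in set(int(t) for t in decoded_ids[b]):
            if logits[b, token] > 0:
                logits[b, token] /= penalty
            else:
                logits[b, token] *= penalty
    return logits

# Tokens 0 and 2 were already generated, so their logits shrink:
logits = np.array([[2.0, 1.0, -1.0]])
decoded = np.array([[0, 2]])
out = penalize_repeats(logits, decoded, penalty=2.0)
# token 0: 2.0 -> 1.0; token 1 unchanged; token 2: -1.0 -> -2.0
```

With `penalty > 1.0`, a token becomes less likely each time it has already appeared, which discourages loops like "吃吃吃吃" without forbidding legitimate repeats outright.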
Solution
No confirmed solution to this problem has been posted yet.
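A stronger variant, if a soft penalty is not enough, is to forbid any token that would recreate an n-gram already present in `decoded_ids` (the "no repeat n-gram" trick used by several decoder implementations). This is again a hedged NumPy sketch, not tensor2tensor code; the helper name `block_repeated_ngrams` is my own, and in the TF while-loop above you would apply the equivalent mask by adding a large negative constant to the banned logits:

```python
import numpy as np

def block_repeated_ngrams(logits, decoded_ids, n=2):
    """Set to -inf the logit of any token that would complete an n-gram
    already present in decoded_ids.

    With n=1 this bans every previously generated token outright.
    logits: [batch, vocab]; decoded_ids: [batch, length].
    """
    logits = logits.copy()
    for b in range(logits.shape[0]):
        ids = decoded_ids[b].tolist()
        if len(ids) < n - 1:
            continue  # not enough context yet to form an n-gram prefix
        # The (n-1)-token prefix the next token would extend.
        prefix = ids[len(ids) - (n - 1):] if n > 1 else []
        for start in range(len(ids) - n + 1):
            if ids[start:start + n - 1] == prefix:
                logits[b, ids[start + n - 1]] = -np.inf
    return logits

# The sequence [5, 3, 5] already contains the bigram (5, 3), and the last
# token is 5, so emitting 3 again is banned:
logits = np.zeros((1, 6))
decoded = np.array([[5, 3, 5]])
out = block_repeated_ngrams(logits, decoded, n=2)
```

Hard n-gram blocking guarantees no exact repeats of length n, at the cost of occasionally forbidding a repeat that would have been natural; the multiplicative penalty above degrades more gracefully, so the two are often tried in that order.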