问题描述
我使用tensorflow2训练变压器,并使用edit_distance进行损耗
def convert2(lang,tensor):
result = ''
for t in tensor:
if t == 1 or t==2:continue
if t!=0:
result += lang.index_word[t]
return result
def distLoss(inp,tar):
inpS,tarS,losses = [],[],[]
for i in inp.numpy():
inpS.append(convert2(lang,i))
for i in tar.numpy():
tarS.append(convert2(lang,i))
for i,t in zip(inpS,tarS):
losses.append(1-Levenshtein.ratio(i,t))
return tf.convert_to_tensor(sum(losses)/len(losses),dtype=tf.float32)
def loss(inp,tar):
return tf.py_function(distLoss,[inp,tar],tf.float32)
convert2 函数将令牌张量转换为字符串, 我通过 softargmax
将转换器logit转换为令牌张量def softargmax(x,beta=1e10):
x_range = tf.range(x.shape.as_list()[-1],dtype=x.dtype)
return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range,axis=-1)
然后训练模型:
@tf.function
def train_step(inp,tar):
tar_inp = tar[:,:-1]
tar_real = tar[:,1:]
enc_padding_mask,combined_mask,dec_padding_mask = create_masks(inp,tar_inp)
with tf.GradientTape() as g_tape:
predictions,_ = gen(inp,tar_inp,True,enc_padding_mask,dec_padding_mask)
predicted_id = softargmax(predictions)
loss_g = Loss(predicted_id,tar_real)
g_gradients = g_tape.gradient(loss_g,gen.trainable_variables)
g_optimizer.apply_gradients(zip(g_gradients,gen.trainable_variables))
return loss_g
然后得到警告和错误:
WARNING:tensorflow:The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient,got tf.int32
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-101-e6f3bf8cf9d7> in <module>
7
8 for (batch,(inp,targ)) in enumerate(dataset.take(steps_per_epoch)):
----> 9 loss_g,loss_d = train_step(inp,targ)
10 total_g_loss += loss_g
11 total_d_loss += loss_d
~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self,*args,**kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args,**kwds)
781
782 new_tracing_count = self._get_tracing_count()
~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\def_function.py in _call(self,**kwds)
844 *args,**kwds)
845 # If we did not create any variables the trace we have is good enough.
--> 846 return self._concrete_stateful_fn._filtered_call(canon_args,canon_kwds) # pylint: disable=protected-access
847
848 def fn_with_cond(*inner_args,**inner_kwds):
~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\function.py in _filtered_call(self,args,kwargs,cancellation_manager)
1846 resource_variable_ops.BaseResourceVariable))],1847 captured_inputs=self.captured_inputs,-> 1848 cancellation_manager=cancellation_manager)
1849
1850 def _call_flat(self,captured_inputs,cancellation_manager=None):
~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\function.py in _call_flat(self,cancellation_manager)
1922 # No tape is watching; skip to running the function.
1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx,cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
1926 args,~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\function.py in call(self,ctx,cancellation_manager)
548 inputs=args,549 attrs=attrs,--> 550 ctx=ctx)
551 else:
552 outputs = execute.execute_with_cancellation(
~\AppData\Roaming\Python\python37\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name,num_outputs,inputs,attrs,name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle,device_name,op_name,---> 60 inputs,num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: TypeError: Cannot convert 0.0 to EagerTensor of dtype int32
Traceback (most recent call last):
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\ops\script_ops.py",line 242,in __call__
return func(device,token,args)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\ops\script_ops.py",line 143,in __call__
for (x,dtype) in zip(ret,self._out_dtypes)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\ops\script_ops.py",in <listcomp>
for (x,line 119,in _convert
return constant_op.constant(0.0,dtype=dtype)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\framework\constant_op.py",line 264,in constant
allow_broadcast=True)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\framework\constant_op.py",line 275,in _constant_impl
return _constant_eager_impl(ctx,value,dtype,shape,verify_shape)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\framework\constant_op.py",line 300,in _constant_eager_impl
t = convert_to_eager_tensor(value,dtype)
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\framework\constant_op.py",line 98,in convert_to_eager_tensor
return ops.EagerTensor(value,ctx.device_name,dtype)
TypeError: Cannot convert 0.0 to EagerTensor of dtype int32
[[node gradient_tape/EagerPyFunc (defined at <ipython-input-96-da6c53e9fb65>:30) ]]
(1) Invalid argument: TypeError: Cannot convert 0.0 to EagerTensor of dtype int32
Traceback (most recent call last):
File "C:\Users\islab\AppData\Roaming\Python\python37\site-packages\tensorflow\python\ops\script_ops.py",dtype)
TypeError: Cannot convert 0.0 to EagerTensor of dtype int32
[[node gradient_tape/EagerPyFunc (defined at <ipython-input-96-da6c53e9fb65>:30) ]]
[[GroupCrossDeviceControlEdges_0/Adam/Adam/Const/_85]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_step_57138]
Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/EagerPyFunc:
EagerPyFunc (defined at <ipython-input-89-05e412a738e7>:36)
Input Source operations connected to node gradient_tape/EagerPyFunc:
EagerPyFunc (defined at <ipython-input-89-05e412a738e7>:36)
Function call stack:
train_step -> train_step
WARNING:tensorflow:The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient,got tf.int32
是什么
意思 ?问题来自softargmax?如何将变压器logit转换为令牌?
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)