自定义仅用于培训过程的图层

问题描述

Env:Tensorflow2.3.0 python3.6

我正在尝试为训练过程自定义图层以进行图像增强。这是我的代码

class RandomLight(layers.Layer):
def __init__(self,factor=0.2):
    super(RandomLight,self).__init__()
    self.factor = factor

def call(self,input,training=None):
    return tf.cond(training,lambda: tf.clip_by_value(tf.image.random_brightness(input,self.factor),1),lambda: input)

以及当我要将其放入网络时:

import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.applications import VGG16

inputs = keras.Input(shape=(224,224,3))
vgg16 = VGG16(include_top=False,weights='imagenet',input_shape=(224,3))
data_augmentation = keras.Sequential(
[
    layers.experimental.preprocessing.Randomrotation(0.25),layers.experimental.preprocessing.RandomFlip(),RandomLight()
])
i1 = data_augmentation(inputs)
bn = layers.Batchnormalization()(i1)
x = vgg16(bn)
flat_out = layers.Flatten()(x)
h1 = layers.Dense(1024,activation='relu',name='fc1')(flat_out)
h2 = layers.Dropout(0.5)(h1)
h3 = layers.Dense(32,name='fc2')(h2)
h4 = layers.Dropout(0.5)(h3)
new_out = layers.Dense(1,activation='sigmoid',name='prediction')(h4)
vgg_ft = keras.Model(inputs,new_out)

错误似乎与“ training = None”有关

ValueError                                Traceback (most recent call last)
<ipython-input-290-966a2fabc71b> in <module>()
----> 1 inputs = data_augmentation(inputs)
  2 inputs = randomLight(inputs)
  3 bn = layers.Batchnormalization()(inputs)
  4 x = vgg16(bn)
  5 flat_out = layers.Flatten()(x)

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self,*args,**kwargs)
    924     if _in_functional_construction_mode(self,inputs,args,kwargs,input_list):
    925       return self._functional_construction_call(inputs,--> 926                                                 input_list)
    927 
    928     # Maintains info about the `Layer.call` stack.

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _functional_construction_call(self,input_list)
   1115           try:
   1116             with ops.enable_auto_cast_variables(self._compute_dtype_object):
-> 1117               outputs = call_fn(cast_inputs,**kwargs)
   1118 
   1119           except errors.OperatorNotAllowedInGraphError as e:

F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\autograph\impl\api.py in wrapper(*args,**kwargs)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e,'ag_error_Metadata'):
--> 258           raise e.ag_error_Metadata.to_exception(e)
    259         else:
    260           raise

ValueError: in user code:
<ipython-input-278-87ec004f05b3>:11 call  *
    lambda: input)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper  **
    return target(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\control_flow_ops.py:1396 cond_for_tf_v2
    return cond(pred,true_fn=true_fn,false_fn=false_fn,strict=True,name=name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
    return target(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\util\deprecation.py:507 new_func
    return func(*args,**kwargs)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\control_flow_ops.py:1180 cond
    return cond_v2.cond_v2(pred,true_fn,false_fn,name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\ops\cond_v2.py:74 cond_v2
    pred = ops.convert_to_tensor(pred)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\ops.py:1499 convert_to_tensor
    ret = conversion_func(value,dtype=dtype,name=name,as_ref=as_ref)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:338 _constant_tensor_conversion_function
    return constant(v,name=name)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:264 constant
    allow_broadcast=True)
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\constant_op.py:282 _constant_impl
    allow_broadcast=allow_broadcast))
F:\Anaconda3\envs\tf\lib\site-packages\tensorflow\python\framework\tensor_util.py:444 make_tensor_proto
    raise ValueError("None values not supported.")

ValueError: None values not supported.

我也尝试过training=False,但这也不起作用。

看来Sequential()在我的自定义图层上效果很好,但是如何以我的格式使用

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)