Dynamic Spatial Convolution is not supported on TPU

Problem description

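I am trying to use a Keras (TF 2.3.1) model for image classification with multiple binary labels as output. The model consists of an Xception CNN + an attention layer + a dense classifier, and only on some TPUs does it fail with the error UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported. It fails on the Kaggle TPU but not on Colab, tested with TF 2.3.1 in both cases.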

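I looked here, but the suggested solution implies that the image dimensions are not set, which is not the case here. train_df is of type <PrefetchDataset shapes: ((None,750,750,3),(None,11)),types: (tf.float32,tf.int64)>, so every image is 750x750x3. According to the model summary below, every layer has a defined output shape, so each following layer should correctly infer its input shape.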
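To see why the spatial dimensions should be static here, the 24x24 feature-map size in the summary can be derived from the 750x750 input alone. A minimal sketch, assuming the layer layout of tf.keras.applications.Xception (two unpadded 3x3 convolutions at the start, the first with stride 2, then four stride-2 stages with 'same' padding); the helper names are hypothetical:

```python
import math


def conv_valid(size: int, kernel: int, stride: int = 1) -> int:
    """Output length of an unpadded ('valid') convolution."""
    return (size - kernel) // stride + 1


def downsample_same(size: int, stride: int = 2) -> int:
    """Output length of a stride-2 layer with 'same' padding."""
    return math.ceil(size / stride)


def xception_feature_size(input_size: int) -> int:
    """Spatial size of Xception's last feature map (include_top=False)."""
    size = conv_valid(input_size, kernel=3, stride=2)  # block1_conv1: 3x3, stride 2, valid
    size = conv_valid(size, kernel=3)                   # block1_conv2: 3x3, stride 1, valid
    for _ in range(4):  # blocks 2, 3, 4 and 13 each halve the map ('same' padding)
        size = downsample_same(size)
    return size


print(xception_feature_size(750))  # 24 -> a 24x24x2048 feature map
print(xception_feature_size(299))  # 10 -> Xception's default 299x299 input
```

Only the batch dimension is left as None; everything after the backbone works on a fixed 24x24 grid.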

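The op that fails in the error message is model_8/conv2d_69/Conv2D, whose input is the output of the LocallyConnected2D layer, so the problem seems to be in the layer defined by attn_layer = LocallyConnected2D(.... Passing implementation = 2 is a workaround that lets training complete, but it is not suitable for large models (see the LocallyConnected2D documentation).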
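The scaling concern with implementation = 2 can be made concrete with a little arithmetic. The sketch below assumes kernel_size=(1,1) and padding='valid' for the attention layer (consistent with the 9,792 parameters reported in the summary), and assumes implementation = 2 stores one dense input-by-output matrix, as the documentation's "dense but sparsely-populated" description suggests; the helper names are hypothetical and this is not TensorFlow's code:

```python
def locally_connected2d_params(out_rows, out_cols, kernel_rows, kernel_cols,
                               in_channels, filters, use_bias=True):
    """Trainable weights of a LocallyConnected2D layer (one unshared kernel per output pixel)."""
    per_location = kernel_rows * kernel_cols * in_channels * filters
    if use_bias:
        per_location += filters  # one bias per filter per output location
    return out_rows * out_cols * per_location


def dense_impl2_entries(in_rows, in_cols, in_channels, out_rows, out_cols, filters):
    """Entries of a dense matrix connecting every input value to every output value."""
    return (in_rows * in_cols * in_channels) * (out_rows * out_cols * filters)


# Attention layer of this model: 24x24x16 in, 24x24x1 out, 1x1 kernel.
params = locally_connected2d_params(24, 24, 1, 1, 16, 1)
entries = dense_impl2_entries(24, 24, 16, 24, 24, 1)
print(params)               # 9792, as in the model summary
print(entries)              # 5308416
print(entries * 4 / 2**20)  # 20.25 (MiB as float32)
```

Doubling the feature map to 48x48 only quadruples the true parameter count (39,168), but multiplies the dense-matrix size by 16 (84,934,656 entries), which is why this storage scheme stops being practical for larger inputs.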

Model code

import numpy as np
import tensorflow as tf
from tensorflow.keras import models,layers
from tensorflow.keras.callbacks import ModelCheckpoint,EarlyStopping,ReduceLROnPlateau
from tensorflow.keras.applications import Xception
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Dropout,Flatten,Input,Conv2D,multiply,LocallyConnected2D,Lambda,BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import mean_absolute_error

# TARGET_SIZE (750), tpu_strategy, EPOCHS, STEPS_PER_EPOCH, VALIDATION_STEPS,
# train_df and valid_df are defined elsewhere in the notebook.

def create_model():
    input_shape = (TARGET_SIZE,TARGET_SIZE,3)
    in_lay = Input(input_shape)
    conv_base = Xception(include_top = False,weights = 'imagenet',input_shape = input_shape)
    pt_features = conv_base(in_lay)
    bn_features = BatchNormalization()(pt_features)

    # here we do an attention mechanism to turn pixels in the GAP on and off
    attn_layer = Conv2D(64,kernel_size = (1,1),padding = 'same',activation = 'relu')(bn_features)
    attn_layer = Conv2D(16,kernel_size = (1,1),padding = 'same',activation = 'relu')(attn_layer)
    attn_layer = LocallyConnected2D(1,kernel_size = (1,1),padding = 'valid',activation = 'sigmoid')(attn_layer)
    # fan it out to all of the channels with a frozen all-ones 1x1 convolution
    pt_depth = conv_base.get_output_shape_at(0)[-1]
    up_c2_w = np.ones((1,1,1,pt_depth))  # Conv2D kernel shape: (rows, cols, in_channels, out_channels)
    up_c2 = Conv2D(pt_depth,kernel_size = (1,1),padding = 'same',activation = 'linear',use_bias = False,weights = [up_c2_w])
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)

    mask_features = multiply([attn_layer,bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # to account for missing values from the attention model
    gap = Lambda(lambda x: x[0]/x[1],name = 'RescaleGAP')([gap_features,gap_mask])
    gap_dr = Dropout(0.5)(gap)
    dr_steps = Dropout(0.25)(Dense(1024,activation = 'elu')(gap_dr))
    out_layer = Dense(11,activation = 'sigmoid')(dr_steps)
    model = Model(inputs = [in_lay],outputs = [out_layer])
    model.compile(optimizer = Adam(learning_rate = 0.002),loss = 'binary_crossentropy',metrics = ["AUC"])
    return model


with tpu_strategy.scope():
    model = create_model()
model.summary()

history = model.fit(
    train_df,epochs = EPOCHS,steps_per_epoch = STEPS_PER_EPOCH,validation_data = valid_df,validation_steps = VALIDATION_STEPS
)
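The RescaleGAP Lambda divides the global average of the masked features by the global average of the mask itself, which turns the pooling into an attention-weighted mean over pixels. A minimal pure-Python sketch of that arithmetic for a single channel (in the model, up_c2 copies the same mask to all 2048 channels), assuming nested lists stand in for tensors; global_average and rescaled_gap are hypothetical names, not Keras functions:

```python
def global_average(grid):
    """Global average pooling over an H x W grid (one channel)."""
    values = [v for row in grid for v in row]
    return sum(values) / len(values)


def rescaled_gap(features, mask):
    """GAP(mask * features) / GAP(mask), i.e. a mask-weighted mean of the features."""
    masked = [[f * m for f, m in zip(f_row, m_row)]
              for f_row, m_row in zip(features, mask)]
    return global_average(masked) / global_average(mask)


features = [[1.0, 2.0],
            [3.0, 6.0]]
print(rescaled_gap(features, [[1.0, 0.0], [0.0, 1.0]]))  # 3.5, mean of the two attended pixels
print(rescaled_gap(features, [[1.0, 1.0], [1.0, 1.0]]))  # 3.0, the plain global average
```

Without the division, a mask that switches off half the pixels would also halve the pooled value; dividing by GAP(mask) removes that bias.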

Resulting model summary

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_19 (InputLayer)           [(None,750,750,3)]   0
__________________________________________________________________________________________________
xception (Model)                (None,24,24,2048)    20861480    input_19[0][0]
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None,24,24,2048)    8192        xception[1][0]
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None,24,24,64)      131136      batch_normalization_49[0][0]
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None,24,24,16)      1040        conv2d_67[0][0]
__________________________________________________________________________________________________
locally_connected2d_9 (LocallyC (None,24,24,1)       9792        conv2d_68[0][0]
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None,24,24,2048)    2048        locally_connected2d_9[0][0]
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None,24,24,2048)    0           conv2d_69[0][0]
                                                                 batch_normalization_49[0][0]     
__________________________________________________________________________________________________
global_average_pooling2d_23 (Gl (None,2048)         0           multiply_9[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d_24 (Gl (None,2048)         0           conv2d_69[0][0]                  
__________________________________________________________________________________________________
RescaleGAP (Lambda)             (None,2048)         0           global_average_pooling2d_23[0][0]
                                                                 global_average_pooling2d_24[0][0]
__________________________________________________________________________________________________
dropout_18 (Dropout)            (None,2048)         0           RescaleGAP[0][0]                 
__________________________________________________________________________________________________
dense_17 (Dense)                (None,1024)         2098176     dropout_18[0][0]                 
__________________________________________________________________________________________________
dropout_19 (Dropout)            (None,1024)         0           dense_17[0][0]                   
__________________________________________________________________________________________________
dense_18 (Dense)                (None,11)           11275       dropout_19[0][0]                 
==================================================================================================
Total params: 23,123,139
Trainable params: 23,062,467
Non-trainable params: 60,672
__________________________________________________________________________________________________
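The parameter counts in the summary can be cross-checked by hand. This sketch takes Xception's count from the summary as given and recomputes the rest, assuming 1x1 kernels throughout the attention branch and a 24x24 feature map; the helper names are hypothetical:

```python
def conv2d_params(kernel_rows, kernel_cols, in_ch, out_ch, use_bias=True):
    """Weights (+ optional bias) of a Conv2D layer."""
    return kernel_rows * kernel_cols * in_ch * out_ch + (out_ch if use_bias else 0)


def dense_params(in_units, out_units):
    """Weights + bias of a Dense layer."""
    return in_units * out_units + out_units


def batchnorm_params(channels):
    """gamma, beta, moving mean and moving variance per channel."""
    return 4 * channels


layer_params = {
    "xception": 20861480,  # copied from the summary, not recomputed
    "batch_normalization_49": batchnorm_params(2048),
    "conv2d_67": conv2d_params(1, 1, 2048, 64),
    "conv2d_68": conv2d_params(1, 1, 64, 16),
    "locally_connected2d_9": 24 * 24 * (1 * 1 * 16 * 1 + 1),
    "conv2d_69": conv2d_params(1, 1, 1, 2048, use_bias=False),  # frozen all-ones kernel
    "dense_17": dense_params(2048, 1024),
    "dense_18": dense_params(1024, 11),
}
print(sum(layer_params.values()))  # 23123139, the reported total
```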

Full stack trace + error message:

---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-53-5130a0bcf331> in <module>
     19     validation_data = valid_df,
     20     validation_steps = VALIDATION_STEPS,
---> 21     callbacks = [model_save,early_stop,reduce_lr]
     22 )

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self,*args,**kwargs)
     64   def _method_wrapper(self,*args,**kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self,*args,**kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self,x,y,batch_size,epochs,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing)
    853                 context.async_wait()
    854               logs = tmp_logs  # No error,Now safe to assign to logs.
--> 855               callbacks.on_train_batch_end(step,logs)
    856         epoch_logs = copy.copy(logs)
    857 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self,batch,logs)
    387     """
    388     if self._should_call_train_batch_hooks:
--> 389       logs = self._process_logs(logs)
    390       self._call_batch_hook(ModeKeys.TRAIN,'end',logs=logs)
    391 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _process_logs(self,logs)
    263     """Turns tensors into numpy arrays or Python scalars."""
    264     if logs:
--> 265       return tf_utils.to_numpy_or_python_type(logs)
    266     return {}
    267 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
    521     return t  # Don't turn ragged or sparse tensors to NumPy.
    522 
--> 523   return nest.map_structure(_to_single_numpy_or_python_type,tensors)
    524 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func,*structure,**kwargs)
    615 
    616   return pack_sequence_as(
--> 617       structure[0],[func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
    615 
    616   return pack_sequence_as(
--> 617       structure[0],[func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
    517   def _to_single_numpy_or_python_type(t):
    518     if isinstance(t,ops.Tensor):
--> 519       x = t.numpy()
    520       return x.item() if np.ndim(x) == 0 else x
    521     return t  # Don't turn ragged or sparse tensors to NumPy.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
    959     """
    960     # Todo(slebedev): Consider avoiding a copy for non-cpu or remote tensors.
--> 961     maybe_arr = self._numpy()  # pylint: disable=protected-access
    962     return maybe_arr.copy() if isinstance(maybe_arr,np.ndarray) else maybe_arr
    963 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
    927       return self._numpy_internal()
    928     except core._NotOkStatusException as e:
--> 929       six.raise_from(core._status_to_exception(e.code,e.message),None)
    930 
    931   @property

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value,from_value)

UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported: %convolution.30660 = f32[<=8,2048]{3,2,0} convolution(f32[<=8,1]{3,0} %add.30633,f32[1,0} %get-tuple-element.354),window={size=1x1},dim_labels=b01f_01io->b01f,Metadata={op_type="Conv2D" op_name="model_8/conv2d_69/Conv2D"}
    TPU compilation Failed
     [[{{node tpu_compile_succeeded_assert/_17367812259898276239/_5}}]]
