(0) Unavailable: {{function_node __inference_train_function_53748}}

Problem description

I'm working in Colab with a TPU, but unfortunately it doesn't work: the model fails with an error as soon as fitting starts.

Here is my code:

import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# tpu_address, create_model, HEIGHT/WIDTH/CANAL, N_CLASSES, the patience and
# learning-rate constants, EPOCHS and the train/valid generators are defined
# in earlier notebook cells.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
tf.config.experimental_connect_to_cluster(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  model = create_model(input_shape=(HEIGHT, WIDTH, CANAL), n_out=N_CLASSES)

  # Freeze every layer, then unfreeze only the last five for warm-up training.
  for layer in model.layers:
    layer.trainable = False

  for i in range(-5, 0):
    model.layers[i].trainable = True

  es = EarlyStopping(monitor='val_loss', mode='min', patience=ES_PATIENCE,
                     restore_best_weights=True, verbose=1)
  rlrop = ReduceLROnPlateau(monitor='val_loss', patience=RLROP_PATIENCE,
                            factor=DECAY_DROP, min_lr=1e-6, verbose=1)

  callback_list = [es, rlrop]
  optimizer = optimizers.Adam(learning_rate=LEARNING_RATE)

  model.compile(optimizer=optimizers.Adam(learning_rate=WARMUP_LEARNING_RATE),
                loss='categorical_crossentropy', metrics=['accuracy'])

  model.summary()

# Number of full batches per epoch for each generator.
STEP_SIZE_TRAIN = train_generator.n // train_generator.batch_size
STEP_SIZE_VALID = valid_generator.n // valid_generator.batch_size

history_finetunning = model.fit_generator(generator=train_generator,
                                          steps_per_epoch=STEP_SIZE_TRAIN,
                                          epochs=EPOCHS,
                                          validation_data=valid_generator,
                                          validation_steps=STEP_SIZE_VALID,
                                          verbose =1)



Here is the error:

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:1915: UserWarning: `Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.
  warnings.warn('`Model.fit_generator` is deprecated and '
Epoch 1/40
---------------------------------------------------------------------------
UnavailableError                          Traceback (most recent call last)
<ipython-input-41-1c157bad2449> in <module>()
      4                                           validation_data=valid_generator,
      5                                           validation_steps=STEP_SIZE_VALID,
----> 6                                           verbose =1)

14 frames
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in fit_generator(self,generator,steps_per_epoch,epochs,verbose,callbacks,validation_data,validation_steps,validation_freq,class_weight,max_queue_size,workers,use_multiprocessing,shuffle,initial_epoch)
   1930         use_multiprocessing=use_multiprocessing,
   1931         shuffle=shuffle,
-> 1932         initial_epoch=initial_epoch)
   1933 
   1934   def evaluate_generator(self,

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in fit(self,x,y,batch_size,validation_split,sample_weight,initial_epoch,validation_batch_size,use_multiprocessing)
   1161               logs = tmp_logs  # No error,Now safe to assign to logs.
   1162               end_step = step + data_handler.step_increment
-> 1163               callbacks.on_train_batch_end(end_step,logs)
   1164               if self.stop_training:
   1165                 break

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in on_train_batch_end(self,batch,logs)
    434     """
    435     if self._should_call_train_batch_hooks:
--> 436       self._call_batch_hook(ModeKeys.TRAIN,'end',logs=logs)
    437 
    438   def on_test_batch_begin(self,logs=None):

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in _call_batch_hook(self,mode,hook,logs)
    276       self._call_batch_begin_hook(mode,logs)
    277     elif hook == 'end':
--> 278       self._call_batch_end_hook(mode,logs)
    279     else:
    280       raise ValueError('Unrecognized hook: {}'.format(hook))

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in _call_batch_end_hook(self,logs)
    296       self._batch_times.append(batch_time)
    297 
--> 298     self._call_batch_hook_helper(hook_name,logs)
    299 
    300     if len(self._batch_times) >= self._num_batches_for_timing_check:

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in _call_batch_hook_helper(self,hook_name,logs)
    336       hook = getattr(callback,hook_name)
    337       if getattr(callback,'_supports_tf_logs',False):
--> 338         hook(batch,logs)
    339       else:
    340         if numpy_logs is None:  # Only convert once.

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in on_train_batch_end(self,logs)
   1042 
   1043   def on_train_batch_end(self,logs=None):
-> 1044     self._batch_update_progbar(batch,logs)
   1045 
   1046   def on_test_batch_end(self,logs=None):

/usr/local/lib/python3.7/dist-packages/keras/callbacks.py in _batch_update_progbar(self,logs)
   1106     if self.verbose == 1:
   1107       # Only block async when verbose = 1.
-> 1108       logs = tf_utils.sync_to_numpy_or_python_type(logs)
   1109       self.progbar.update(self.seen,list(logs.items()),finalize=False)
   1110 

/usr/local/lib/python3.7/dist-packages/keras/utils/tf_utils.py in sync_to_numpy_or_python_type(tensors)
    505     return t  # Don't turn ragged or sparse tensors to NumPy.
    506 
--> 507   return tf.nest.map_structure(_to_single_numpy_or_python_type,tensors)
    508 
    509 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in map_structure(func,*structure,**kwargs)
    865 
    866   return pack_sequence_as(
--> 867       structure[0], [func(*x) for x in entries],
    868       expand_composites=expand_composites)
    869 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
    865 
    866   return pack_sequence_as(
--> 867       structure[0],
    868       expand_composites=expand_composites)
    869 

/usr/local/lib/python3.7/dist-packages/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
    501   def _to_single_numpy_or_python_type(t):
    502     if isinstance(t,tf.Tensor):
--> 503       x = t.numpy()
    504       return x.item() if np.ndim(x) == 0 else x
    505     return t  # Don't turn ragged or sparse tensors to NumPy.

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in numpy(self)
   1092     """
   1093     # Todo(slebedev): Consider avoiding a copy for non-cpu or remote tensors.
-> 1094     maybe_arr = self._numpy()  # pylint: disable=protected-access
   1095     return maybe_arr.copy() if isinstance(maybe_arr,np.ndarray) else maybe_arr
   1096 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
   1060       return self._numpy_internal()
   1061     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1062       six.raise_from(core._status_to_exception(e.code,e.message),None)  # pylint: disable=protected-access
   1063 
   1064   @property

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value,from_value)

UnavailableError: 3 root error(s) found.
  (0) Unavailable: {{function_node __inference_train_function_53748}} Failed to connect to all addresses
Additional GRPC error information from remote target /job:localhost/replica:0/task:0/device:cpu:0:
:{"created":"@1626347736.544045826","description":"Failed to pick subchannel","file":"third_party/grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":5420,"referenced_errors":[{"created":"@1626347735.785465323","description":"Failed to connect to all addresses","file":"third_party/grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":398,"grpc_status":14}]}
     [[{{node MultideviceIteratorGetNextFromShard}}]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[cond_11/switch_pred/_107/_76]]
  (1) Unavailable: {{function_node __inference_train_function_53748}} Failed to connect to all addresses
Additional GRPC error information from remote target /job:localhost/replica:0/task:0/device:cpu:0:
:{"created":"@1626347736.544045826","grpc_status":14}]}
     [[{{node MultideviceIteratorGetNextFromShard}}]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
     [[cluster_train_function/_execute_2_0/_333]]
  (2) Unavailable: {{function_node __inference_train_function_53748}} Failed to connect to all addresses
Additional GRPC error information from remote target /job:localhost/replica:0/task:0/device:cpu:0:
:{"created":"@1626347736.544045826","grpc_status":14}]}
     [[{{node MultideviceIteratorGetNextFromShard}}]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]]
0 successful operations.
6 derived errors ignored. 

Using the same TPU configuration code with a tfds dataset from tensorflow_datasets, and fitting the model with model.fit, does not raise any error.

Solution

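Model.fit_generator is deprecated and does not work with TPUs. The traceback also shows that fit_generator simply forwards to Model.fit, and that the failure happens in MultideviceIteratorGetNextFromShard, i.e. while fetching the next input batch, so the problem appears to be the Python-generator input pipeline rather than the training call itself.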
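The deprecation itself is mechanical to address: Model.fit accepts a generator plus the same step arguments. The sketch below is a minimal illustration, not the asker's setup: `batch_generator` is a hypothetical stand-in for the ImageDataGenerator iterators (random data, one-hot labels), and the tiny model is only there to make the example runnable on whatever default device is available. By itself this swap only silences the warning; the input still comes from a Python generator.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's train_generator: yields
# (images, one-hot labels) batches forever, like a Keras image iterator.
def batch_generator(batch_size=8, height=32, width=32, channels=3, n_classes=2):
    rng = np.random.default_rng(0)
    while True:
        x = rng.random((batch_size, height, width, channels), dtype=np.float32)
        labels = rng.integers(0, n_classes, size=batch_size)
        y = tf.keras.utils.to_categorical(labels, n_classes)
        yield x, y

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss="categorical_crossentropy", metrics=["accuracy"])

# Same arguments as fit_generator, minus the `generator=` keyword: the
# generator is passed positionally, and the step counts must be explicit
# because an infinite generator has no length.
history = model.fit(batch_generator(), steps_per_epoch=4, epochs=2,
                    validation_data=batch_generator(), validation_steps=2,
                    verbose=0)
```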

Try loading the images with tf.keras.preprocessing.image_dataset_from_directory, which returns a tf.data.Dataset, train with model.fit, and combine the dataset with Keras preprocessing layers.
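A minimal sketch of that approach, under several assumptions. `IMG_SIZE`, `N_CLASSES` and `BATCH_SIZE` stand in for the question's constants; `make_toy_image_dir` is a hypothetical helper that writes random PNGs so the example runs anywhere; `make_strategy` is a hypothetical helper that builds a TPUStrategy only when an address is passed (it also calls tf.tpu.experimental.initialize_tpu_system, as the TensorFlow TPU guide does, which the question's code does not) and otherwise falls back to the default strategy. Following the Keras preprocessing guidance for TPUs, the random augmentation (RandomFlip) is mapped over the tf.data pipeline while Rescaling sits inside the model. One further caveat, also an assumption: on the Colab TPU runtime the input pipeline may run on the TPU host, which cannot read the VM's local disk, so the image directory may need to live in a Google Cloud Storage bucket.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumed stand-ins for the question's (HEIGHT, WIDTH), N_CLASSES and batch size.
IMG_SIZE = (32, 32)
N_CLASSES = 2
BATCH_SIZE = 4


def make_toy_image_dir(root, n_per_class=8):
    """Hypothetical helper: write random PNGs to root/class_<k>/ so the sketch is self-contained."""
    rng = np.random.default_rng(0)
    for k in range(N_CLASSES):
        class_dir = os.path.join(root, f"class_{k}")
        os.makedirs(class_dir, exist_ok=True)
        for i in range(n_per_class):
            pixels = rng.integers(0, 256, size=(*IMG_SIZE, 3), dtype=np.uint8)
            tf.io.write_file(os.path.join(class_dir, f"{i}.png"), tf.io.encode_png(pixels))


def make_strategy(tpu_address=None):
    """Hypothetical helper: TPUStrategy when an address is given, else the default strategy."""
    if tpu_address is None:
        return tf.distribute.get_strategy()
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)


data_dir = tempfile.mkdtemp()
make_toy_image_dir(data_dir)

# One-hot labels ("categorical") to match loss='categorical_crossentropy'.
common = dict(validation_split=0.25, seed=42, image_size=IMG_SIZE,
              batch_size=BATCH_SIZE, label_mode="categorical")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="training", **common)
valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="validation", **common)

# Random augmentation runs in the tf.data pipeline; Rescaling stays in the model.
augment = tf.keras.layers.RandomFlip("horizontal")
train_ds = train_ds.map(lambda x, y: (augment(x, training=True), y),
                        num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)

strategy = make_strategy()  # on Colab: make_strategy(tpu_address)
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(*IMG_SIZE, 3)),
        tf.keras.layers.Rescaling(1.0 / 255),
        tf.keras.layers.Conv2D(8, 3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(N_CLASSES, activation="softmax"),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss="categorical_crossentropy", metrics=["accuracy"])

# model.fit consumes the datasets directly; their length sets the steps per epoch.
history = model.fit(train_ds, validation_data=valid_ds, epochs=2, verbose=0)
```

Because the datasets are finite, no steps_per_epoch or validation_steps are needed, unlike the generator version in the question.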