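Problem loading a CSV file in TensorFlow when using a TPU

Problem description

I am trying to load a CSV file in TensorFlow (v2.4.1). I am using tf.data.experimental.make_csv_dataset. The call itself does not raise any error, but I get an error as soon as I try to iterate over the dataset.

I am running this in a Kaggle notebook with TPU acceleration. If I run the same code in a CPU or GPU environment, everything works fine.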

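import tensorflow as tf
from kaggle_datasets import KaggleDatasets

GCS_PATH = KaggleDatasets().get_gcs_path('mydsname')
fpath = GCS_PATH + '/train.csv'

train_ds = tf.data.experimental.make_csv_dataset(
    fpath, batch_size=64, select_columns=['sentence', 'label'],
    column_defaults=[tf.string, tf.float32], label_name='label',
    num_epochs=3, shuffle=False)

# The error is raised on the first iteration, not when the dataset is created.
for item in train_ds.take(1):
    print(item)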

Before that, I had also copied and pasted the code that activates the Google Cloud SDK:

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

This is the error I get:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    736         # Fast path for the case `self._structure` is not a nested structure.
--> 737         return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
    738       except AttributeError:

AttributeError: 'tuple' object has no attribute '_from_compatible_tensor_list'

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/context.py in execution_mode(mode)
   2112       ctx.executor = executor_new
-> 2113       yield
   2114     finally:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    738       except AttributeError:
--> 739         return structure.from_compatible_tensor_list(self._element_spec, ret)
    740 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in from_compatible_tensor_list(element_spec, tensor_list)
    243       lambda spec, value: spec._from_compatible_tensor_list(value),
--> 244       element_spec, tensor_list)
    245 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in _from_tensor_list_helper(decode_fn, element_spec, tensor_list)
    218     value = tensor_list[i:i + num_flat_values]
--> 219     flat_ret.append(decode_fn(component_spec, value))
    220     i += num_flat_values

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in <lambda>(spec, value)
    242   return _from_tensor_list_helper(
--> 243       lambda spec, value: spec._from_compatible_tensor_list(value),
    244       element_spec, tensor_list)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_spec.py in _from_compatible_tensor_list(self, tensor_list)
    176     assert len(tensor_list) == 1
--> 177     tensor_list[0].set_shape(self._shape)
    178     return tensor_list[0]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
   1213   def set_shape(self, shape):
-> 1214     if not self.shape.is_compatible_with(shape):
   1215       raise ValueError(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in shape(self)
   1174         # `EagerTensor`, in C.
-> 1175         self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
   1176       except core._NotOkStatusException as e:

InvalidArgumentError: Can't read header of file

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-12-935e50497dbb> in <module>
----> 1 for e in train_ds.take(1):
      2     pass

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in __next__(self)
    745   def __next__(self):
    746     try:
--> 747       return self._next_internal()
    748     except errors.OutOfRangeError:
    749       raise StopIteration

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    737         return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
    738       except AttributeError:
--> 739         return structure.from_compatible_tensor_list(self._element_spec, ret)
    740 
    741   @property

/opt/conda/lib/python3.7/contextlib.py in __exit__(self, type, value, traceback)
    128                 value = type()
    129             try:
--> 130                 self.gen.throw(type, value, traceback)
    131             except StopIteration as exc:
    132                 # Suppress StopIteration *unless* it's the same exception that

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/context.py in execution_mode(mode)
   2114     finally:
   2115       ctx.executor = executor_old
-> 2116       executor_new.wait()
   2117 
   2118 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/executor.py in wait(self)
     67   def wait(self):
     68     """Waits for ops dispatched in this executor to finish."""
---> 69     pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
     70 
     71   def clear_error(self):

InvalidArgumentError: Can't read header of file

fpath appears to be correct, because if I change its value, make_csv_dataset raises a different error.

Does anyone know what is causing this error?

Solution

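I found the root cause. I was running the code that activates the Google Cloud SDK before connecting to the TPU. As described in this post, the SDK must be activated after connecting to the TPU.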
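
For reference, here is a minimal sketch of the ordering that fix implies: connect to the TPU first, then set the credentials. The credential calls are the ones from the question. The TPU connection code is not shown in the original post, so the TPUClusterResolver / TPUStrategy boilerplate below is an assumption about how the notebook connects, and the helper names connect_to_tpu, set_gcs_credentials and prepare_session are made up for illustration.

```python
def connect_to_tpu():
    """Resolve and initialize the TPU, returning a TPUStrategy.

    Assumption: the notebook uses the usual TF 2.x TPU setup calls;
    the original post does not show this part.
    """
    import tensorflow as tf  # imported lazily, only needed when this runs

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)


def set_gcs_credentials():
    """Pass the Kaggle-managed Google Cloud credential to TensorFlow.

    These are the same calls as in the question, just wrapped in a function.
    """
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    user_credential = user_secrets.get_gcloud_credential()
    user_secrets.set_tensorflow_credential(user_credential)


def prepare_session(connect=connect_to_tpu, authenticate=set_gcs_credentials):
    """Run both setup steps with the TPU connection first, credentials second."""
    strategy = connect()   # step 1: connect to the TPU
    authenticate()         # step 2: only now activate the Google Cloud SDK
    return strategy
```

In a Kaggle TPU notebook, calling prepare_session() with no arguments before building the dataset should run the two steps in that order. The connect and authenticate parameters exist only so the ordering can be checked without a TPU.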