tf.data.TFRecordDataset保持为空

问题描述

我正在尝试读取以FloatList形式存储在tfrecords中的numpy数组。这就是我正在使用的代码

def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES)
    dataset = dataset.repeat() 
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) 
    return dataset
def load_dataset(filenames):
    tf_op = tf.data.Options()
    tf_op.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=AUTO)
    print(dataset)
    dataset = dataset.with_options(tf_op)
    dataset = dataset.map(read_tfrecord,num_parallel_calls=AUTO)
    
    return dataset
def read_tfrecord(example):
    tfrec_format = {
        "x" : tf.io.FixedLenFeature([],tf.float32),"y" : tf.io.FixedLenFeature([],tf.float32)
    }
    print(example)
    example = tf.io.parse_single_example(example,tfrec_format)
    print(example['y'].shape)
    x = tf.reshape(example['x'],[224,224,3])
    y = to_categorical(tf.reshape(example['y'],2]),num_classes=32)
    
    y = {"output1":y[:,:,0],"output2":y[:,1]}
    
    return x,y
print("Training data shapes:")
for image,label in get_training_dataset().take(3):
    print(image.numpy().shape,label.numpy().shape)
print("Validation data shapes:")
for image,label in get_validation_dataset().take(3):
    print(image.numpy().shape,label.numpy().shape)

印刷品的输出

训练数据形状:

张量(“ args_0:0”,shape =(),dtype = string) ()

为什么形状显示为空? tfrecords不为空。我也收到以下错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-40-53f01a0d2c13> in <module>
      1 print("Training data shapes:")
----> 2 for image,label in get_training_dataset().take(3):
      3     print(image.numpy().shape,label.numpy().shape)
      4 print("Validation data shapes:")
      5 for image,label in get_validation_dataset().take(3):

<ipython-input-22-f4dcf023fbcc> in get_training_dataset()
      1 def get_training_dataset():
----> 2     dataset = load_dataset(TRAINING_FILENAMES)
      3     dataset = dataset.repeat()
      4     dataset = dataset.shuffle(2048)
      5     dataset = dataset.batch(BATCH_SIZE)

<ipython-input-39-16e1046af942> in load_dataset(filenames)
      5     print(dataset)
      6     dataset = dataset.with_options(tf_op)
----> 7     dataset = dataset.map(read_tfrecord,num_parallel_calls=AUTO)
      8 
      9     return dataset

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self,map_func,num_parallel_calls,deterministic)
   1626           num_parallel_calls,1627           deterministic,-> 1628           preserve_cardinality=True)
   1629 
   1630   def flat_map(self,map_func):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self,input_dataset,deterministic,use_inter_op_parallelism,preserve_cardinality,use_legacy_function)
   4018         self._transformation_name(),4019         dataset=input_dataset,-> 4020         use_legacy_function=use_legacy_function)
   4021     if deterministic is None:
   4022       self._deterministic = "default"

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self,func,transformation_name,dataset,input_classes,input_shapes,input_types,input_structure,add_to_graph,use_legacy_function,defun_kwargs)
   3219       with tracking.resource_tracker_scope(resource_tracker):
   3220         # Todo(b/141462134): Switch to using garbage collection.
-> 3221         self._function = wrapper_fn.get_concrete_function()
   3222 
   3223         if add_to_graph:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self,*args,**kwargs)
   2530     """
   2531     graph_function = self._get_concrete_function_garbage_collected(
-> 2532         *args,**kwargs)
   2533     graph_function._garbage_collector.release()  # pylint: disable=protected-access
   2534     return graph_function

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self,**kwargs)
   2494       args,kwargs = None,None
   2495     with self._lock:
-> 2496       graph_function,args,kwargs = self._maybe_define_function(args,kwargs)
   2497       if self.input_signature:
   2498         args = self.input_signature

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self,kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args,kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function,kwargs

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self,kwargs,override_flat_arg_shapes)
   2665             arg_names=arg_names,2666             override_flat_arg_shapes=override_flat_arg_shapes,-> 2667             capture_by_value=self._capture_by_value),2668         self._function_attributes,2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name,python_func,signature,func_graph,autograph,autograph_options,add_control_dependencies,arg_names,op_return_value,collections,capture_by_value,override_flat_arg_shapes)
    979         _,original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args,**func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors,CompositeTensors,/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   3212           attributes=defun_kwargs)
   3213       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 3214         ret = _wrapper_helper(*args)
   3215         ret = structure.to_tensor_list(self._output_structure,ret)
   3216         return [ops.convert_to_tensor(t) for t in ret]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   3154         nested_args = (nested_args,)
   3155 
-> 3156       ret = autograph.tf_convert(func,ag_ctx)(*nested_args)
   3157       # If `func` returns a list of tensors,`nest.flatten()` and
   3158       # `ops.convert_to_tensor()` would conspire to attempt to stack

/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args,**kwargs)
    263       except Exception as e:  # pylint:disable=broad-except
    264         if hasattr(e,'ag_error_Metadata'):
--> 265           raise e.ag_error_Metadata.to_exception(e)
    266         else:
    267           raise

ValueError: in user code:

    <ipython-input-37-7bad4da2b4f1>:9 read_tfrecord  *
        x = tf.reshape(example['x'],3])
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py:193 reshape  **
        result = gen_array_ops.reshape(tensor,shape,name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py:8087 reshape
        "Reshape",tensor=tensor,shape=shape,name=name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:744 _apply_op_helper
        attrs=attr_protos,op_def=op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:595 _create_op_internal
        compute_device)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3327 _create_op_internal
        op_def=op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1817 __init__
        control_input_ops,op_def)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1657 _create_c_op
        raise ValueError(str(e))

    ValueError: Cannot reshape a tensor with 1 elements to shape [224,3] (150528 elements) for '{{node Reshape}} = Reshape[T=DT_FLOAT,Tshape=DT_INT32](ParseSingleExample/ParseExample/ParseExampleV2,Reshape/shape)' with input shapes: [],[3] and with input tensors computed as partial shapes: input[1] = [224,3].

我在做什么错?我是tfrecords的新手,请帮助 这是我正在使用的数据集https://www.kaggle.com/gameatro/colorization-data

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)