读取一个TFRecord文件,其中用于编码的功能未知

问题描述

我是TensorFlow的新手,这可能是一个非常初学者的问题。我看到了一些示例,其中使用一个人想要使用的功能(例如“图像”,“标签”)将自定义数据集转换为TFRecord文件。而且,在解析此TFRecord文件时,必须事先了解功能(即“图像”,“标签”)才能使用此数据集。

我的问题是-我们如何在事先不知道功能的情况下解析TFRecord文件?假设有人给了我一个TFRecord文件,我想以此解码所有相关的功能

我要引用的一些示例是:Link 1Link 2

解决方法

这可能会有所帮助。该功能会遍历记录文件并保存有关功能的可用信息。您可以对其进行修改,以仅查看第一条记录并返回该信息,尽管在某些情况下,如果仅某些可选特征或大小可变的特征存在可选特征,则查看所有记录可能会很有用。 >

import tensorflow as tf

def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key,value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value,kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match,use None otherwise
                kind2,size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind,size)
    return features

您可以这样使用它

import tensorflow as tf

tfrecords_path = 'data.tfrecord'
# Make some test records
with tf.io.TFRecordWriter(tfrecords_path) as writer:
    for i in range(10):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    # Fixed length
                    'id': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[i])),# Variable length
                    'data': tf.train.Feature(
                        float_list=tf.train.FloatList(value=range(i))),}))
        writer.write(example.SerializeToString())
# Print extracted feature information
features = list_record_features(tfrecords_path)
print(*features.items(),sep='\n')
# ('id',('int64_list',1))
# ('data',('float_list',None))