调整图像大小并为 TensorFlow 2 数据集创建 tf.Example 时出错

问题描述

我正在使用 TensorFlow 2.2,并尝试将模型转换为 TensorRT。我参照了一个示例,该示例对以图像作为输入的模型可以正常工作。不幸的是,我冻结的模型接受的输入是 tf.Example 而不是图像。现在,为它构建 tf.data 数据集管道成了一场噩梦。

我的代码是:

def get_dataset(images_dir,annotation_path,batch_size,input_size,dtype=tf.float32):
    """Build a batched tf.data pipeline of serialized tf.train.Example strings.

    Args:
        images_dir: Directory containing the image files named in the COCO
            annotation file.
        annotation_path: Path to the COCO annotation file.
        batch_size: Number of elements per batch.
        input_size: Side length of the square the images are resized to, or
            None to pass the raw file contents through unchanged.
        dtype: Unused; kept so existing callers keep working.

    Returns:
        A `(dataset, image_ids)` tuple, where `image_ids` is the list
        returned by `coco.getImgIds()`.
    """
    coco = COCO(annotation_file=annotation_path)
    image_ids = coco.getImgIds()
    image_paths = [
        os.path.join(images_dir, coco.imgs[image_id]['file_name'])
        for image_id in image_ids
    ]
    dataset = tf.data.Dataset.from_tensor_slices(image_paths)

    def conv_jpeg_to_tfexample_tensor(input_img_):
        # NOTE: when this function is passed to Dataset.map, `input_img_` is
        # a symbolic string Tensor rather than Python bytes, so the
        # bytes_feature call below raises
        # "TypeError: ... has type Tensor, but expected one of: bytes".
        # This is the failure the question is about and is left in place.
        feature_dict = {
            'image/encoded': dataset_util.bytes_feature(input_img_)
        }
        temp_var = tf.train.Features(feature=feature_dict)
        file_ex = tf.train.Example(features=temp_var).SerializeToString()
        return tf.convert_to_tensor(file_ex)

    def preprocess_fn(path):
        image = tf.io.read_file(path)
        if input_size is not None:
            image = tf.image.decode_jpeg(image, channels=3)
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, size=(input_size, input_size))
            # The pixels are floats in [0, 1] here; a plain tf.cast to uint8
            # would truncate them to 0 or 1, so rescale back to [0, 255].
            image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
            image = tf.image.encode_jpeg(image)
        return image

    dataset = dataset.map(map_func=preprocess_fn, num_parallel_calls=3)
    dataset = dataset.map(map_func=conv_jpeg_to_tfexample_tensor, num_parallel_calls=3)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(count=1)
    return dataset, image_ids

以下调用会导致错误:

# Build the dataset from the values held on `args`, one keyword per line.
dataset, image_ids = get_dataset(
    images_dir=args.data_dir,
    annotation_path=args.annotation_path,
    batch_size=args.batch_size,
    input_size=args.input_size)

错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-152-193739a79a8a> in <module>
      5   batch_size=args.batch_size,
----> 6   input_size=args.input_size)

<ipython-input-151-1d1f15019758> in get_dataset(images_dir,dtype)
     76     dataset = dataset.map(map_func=preprocess_fn,num_parallel_calls=3)
---> 77     dataset = dataset.map(map_func=conv_jpeg_to_tfexample_tensor,num_parallel_calls=3)
     78     dataset = dataset.batch(batch_size)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self,map_func,num_parallel_calls,deterministic)
   1626           num_parallel_calls,
   1627           deterministic,
-> 1628           preserve_cardinality=True)
   1629 
   1630   def flat_map(self,map_func):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self,input_dataset,deterministic,use_inter_op_parallelism,preserve_cardinality,use_legacy_function)
   4018         self._transformation_name(),
   4019         dataset=input_dataset,
-> 4020         use_legacy_function=use_legacy_function)
   4021     if deterministic is None:
   4022       self._deterministic = "default"

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self,func,transformation_name,dataset,input_classes,input_shapes,input_types,input_structure,add_to_graph,use_legacy_function,defun_kwargs)
   3219       with tracking.resource_tracker_scope(resource_tracker):
   3220         # Todo(b/141462134): Switch to using garbage collection.
-> 3221         self._function = wrapper_fn.get_concrete_function()
   3222 
   3223         if add_to_graph:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self,*args,**kwargs)
   2530     """
   2531     graph_function = self._get_concrete_function_garbage_collected(
-> 2532         *args,**kwargs)
   2533     graph_function._garbage_collector.release()  # pylint: disable=protected-access
   2534     return graph_function

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self,**kwargs)
   2494       args,kwargs = None,None
   2495     with self._lock:
-> 2496       graph_function,args,kwargs = self._maybe_define_function(args,kwargs)
   2497       if self.input_signature:
   2498         args = self.input_signature

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self,kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args,kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function,kwargs

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self,kwargs,override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name,python_func,signature,func_graph,autograph,autograph_options,add_control_dependencies,arg_names,op_return_value,collections,capture_by_value,override_flat_arg_shapes)
    979         _,original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args,**func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors,CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   3212           attributes=defun_kwargs)
   3213       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 3214         ret = _wrapper_helper(*args)
   3215         ret = structure.to_tensor_list(self._output_structure,ret)
   3216         return [ops.convert_to_tensor(t) for t in ret]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   3154         nested_args = (nested_args,)
   3155 
-> 3156       ret = autograph.tf_convert(func,ag_ctx)(*nested_args)
   3157       # If `func` returns a list of tensors,`nest.flatten()` and
   3158       # `ops.convert_to_tensor()` would conspire to attempt to stack

/opt/conda/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args,**kwargs)
    263       except Exception as e:  # pylint:disable=broad-except
    264         if hasattr(e,'ag_error_metadata'):
--> 265           raise e.ag_error_metadata.to_exception(e)
    266         else:
    267           raise

TypeError: in user code:

    <ipython-input-143-1d1f15019758>:53 conv_jpeg_to_tfexample_tensor  *
        feature_dict = {
    /opt/conda/lib/python3.7/site-packages/object_detection/utils/dataset_util.py:34 bytes_feature  *
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    TypeError: <tf.Tensor 'args_0:0' shape=() dtype=string> has type Tensor,but expected one of: bytes

解决方法

我能够使用一张简单的鸟类图片重现您遇到的错误。

重现该错误的代码:

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array,array_to_img
from matplotlib import pyplot as plt
import numpy as np
from object_detection.utils import dataset_util

def load_file_and_process(path):
    """Read a JPEG file, centre-crop it by a random fraction and re-encode it.

    Args:
        path: Path of the JPEG file to load.

    Returns:
        The result of `tf.image.encode_jpeg` applied to the cropped image.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes,channels=3)
    # Fraction of the image kept by the central crop, drawn from [0.50, 1.00).
    crop_fraction = np.random.uniform(0.50,1.00)
    cropped = tf.image.central_crop(decoded,crop_fraction)
    as_uint8 = tf.cast(cropped,tf.uint8)
    return tf.image.encode_jpeg(as_uint8)

# List the files matching the pattern and map each one through
# load_file_and_process.
train_dataset = (
    tf.data.Dataset.list_files('/content/bird.jpg')
    .map(load_file_and_process)
)

def conv_jpeg_to_tfexample_tensor(input_img_):
  """Wrap encoded image data in a serialized tf.train.Example tensor.

  Args:
    input_img_: Encoded image stored under the 'image/encoded' feature key.

  Returns:
    The serialized tf.train.Example converted to a tensor.
  """
  feature_dict = {
      'image/encoded': dataset_util.bytes_feature(input_img_),
  }
  example_features = tf.train.Features(feature=feature_dict)
  example = tf.train.Example(features=example_features)
  serialized = example.SerializeToString()
  return tf.convert_to_tensor(serialized)

# This call raises the TypeError shown in the output below: the mapped
# function receives a symbolic string Tensor (`args_0:0` in the error
# message), which bytes_feature cannot place in a BytesList.
train_dataset = train_dataset.map(map_func=conv_jpeg_to_tfexample_tensor,num_parallel_calls=3)

输出-

<MapDataset shapes: (),types: tf.string>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-44-89ae6292ad21> in <module>()
     28   return tf.convert_to_tensor(file_ex)
     29 
---> 30 train_dataset = train_dataset.map(map_func=conv_jpeg_to_tfexample_tensor,num_parallel_calls=3)

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args,**kwargs)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e,'ag_error_metadata'):
--> 258           raise e.ag_error_metadata.to_exception(e)
    259         else:
    260           raise

TypeError: in user code:

    <ipython-input-44-89ae6292ad21>:23 conv_jpeg_to_tfexample_tensor  *
        feature_dict = {
    /usr/local/lib/python3.6/dist-packages/object_detection/utils/dataset_util.py:30 bytes_feature  *
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    TypeError: <tf.Tensor 'args_0:0' shape=() dtype=string> has type Tensor,but expected one of: bytes

建议您参考 TensorFlow 官方教程《TFRecord 和 tf.train.Example》,其中给出了使用 TFRecords 读取和写入图像数据的端到端示例。

参照该教程,我为单张图像写入了一个 TFRecord 文件。

代码-

# This is an example, just using the bird image.
# Read the raw JPEG bytes inside a context manager so the file handle is
# closed as soon as the read finishes.
with open('/content/bird.jpg','rb') as image_file:
  image_string = image_file.read()

def _bytes_feature(value):
  """Wrap a single string / bytes value in a tf.train.Feature holding a BytesList."""
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value,eager_tensor_type):
    # BytesList cannot unpack a string from an EagerTensor, so take its
    # underlying value first.
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint in a tf.train.Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

# Build a tf.train.Example holding the image's dimensions and raw bytes.
def image_example(image_string):
  """Describe one JPEG image as a tf.train.Example.

  Args:
    image_string: The encoded JPEG contents of the image.

  Returns:
    A tf.train.Example with 'height', 'width' and 'depth' int64 features
    taken from the decoded image's shape, plus an 'image_raw' bytes feature
    holding `image_string`.
  """
  decoded_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(decoded_shape[0]),
      'width': _int64_feature(decoded_shape[1]),
      'depth': _int64_feature(decoded_shape[2]),
      'image_raw': _bytes_feature(image_string),
  }

  example_features = tf.train.Features(feature=feature)
  return tf.train.Example(features=example_features)

# Print the first 15 lines of the example's text representation.
example_text = str(image_example(image_string))
preview_lines = example_text.split('\n')[:15]
for line in preview_lines:
  print(line)

# Serialize the example for the bird image and write it to a TFRecord file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
    # Read the image inside its own context manager so the file handle is
    # closed instead of being left for the garbage collector.
    with open('/content/bird.jpg','rb') as image_file:
        image_string = image_file.read()
    tf_example = image_example(image_string)
    writer.write(tf_example.SerializeToString())

输出-

features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 426
      }

enter image description here