Problem creating a TFLite model with metadata for object detection

Problem description

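I am trying to run a tflite model on Android for object detection. To that end:

  1. I have successfully trained the model on my image set, as follows:

(a) Training: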

!python3 object_detection/model_main.py \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--model_dir=training/

(The config file was modified to point to the location of my specific TFRecords.)

(b) Exporting the inference graph:

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
--output_directory={output_directory} \
--trained_checkpoint_prefix={last_model_path}

(c) Creating the tflite-ready graph:

!python /content/drive/'My Drive'/'Detecto Tutorial'/models/research/object_detection/export_tflite_ssd_graph.py \
  --pipeline_config_path=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/object_detection/samples/configs/ssd_mobilenet_v2_coco.config \
  --output_directory={output_directory} \
  --trained_checkpoint_prefix={last_model_path} \
  --add_postprocessing_op=true

  2. I created the tflite model from the graph file above using tflite_convert:

    !tflite_convert \
    --graph_def_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/tflite_graph.pb \
    --output_file=/content/drive/My\ Drive/Detecto\ Tutorial/models/research/fine_tuned_model/detect3.tflite \
    --output_format=TFLITE \
    --input_shapes=1,300,300,3 \
    --input_arrays=normalized_input_image_tensor \
    --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
    --inference_type=FLOAT \
    --allow_custom_ops

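The tflite model above has been validated independently and works fine (outside of Android).

Now I need to populate the tflite model with metadata so that the sample Android code at the link below can process it (otherwise, when running it in Android Studio, I get an error saying the model is not a valid Zip file and has no associated files).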

https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/android/README.md

The sample .tflite file provided at the same link is populated with metadata and works fine.
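A quick way to check whether a given .tflite file carries packed associated files is sketched below. It rests on the assumption that the metadata populator appends associated files to the model as a ZIP archive, which would explain the Zip-related error when none are present. `list_packed_files` is a hypothetical helper name, not part of any library.

```python
import zipfile


def list_packed_files(model_path):
    """Return the names of files packed into a .tflite model, or [] if none.

    Assumption: associated files (e.g. the label file) are appended to the
    model file as a ZIP archive.
    """
    if not zipfile.is_zipfile(model_path):
        return []
    with zipfile.ZipFile(model_path) as archive:
        return archive.namelist()
```

Running this on the sample model and on detect3.tflite should show whether the label file is present in one and missing in the other.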

I then tried to add metadata by following this guide: https://www.tensorflow.org/lite/convert/metadata#deep_dive_into_the_image_classification_example

populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
populator.populate()

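The rest of the code is practically the same as in that guide, except that the metadata description was changed to object detection (instead of image classification) and the location of labelmap.txt was specified. It produces the following error: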

<ipython-input-6-173fc798ea6e> in <module>()
      1 populator = _metadata.MetadataPopulator.with_model_file('/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/detect3.tflite')
----> 2 populator.load_metadata_buffer(metadata_buf)
      3 populator.load_associated_files(['/content/drive/My Drive/Detecto Tutorial/models/research/fine_tuned_model/labelmap.txt'])
      4 populator.populate()

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_lite_support/metadata/metadata.py in _validate_metadata(self, metadata_buf)
    540           "The number of output tensors ({0}) should match the number of "
    541           "output tensor metadata ({1})".format(num_output_tensors,
--> 542                                                 num_output_meta))
    543 
    544 

ValueError: The number of output tensors (4) should match the number of output tensor metadata (1)

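The 4 output tensors are the ones listed in output_arrays in step 2 (someone please correct me if I am wrong). I am not sure how to update the output tensor metadata accordingly.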
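For illustration, the rule behind this error can be sketched in a few stand-alone lines. This is not the library's actual code, only a restatement of the check: the metadata needs one entry per output tensor, paired by position. The function name is hypothetical, and the four metadata names are assumptions about what each output of the post-processing op holds.

```python
def check_output_metadata(output_tensors, output_metadata_names):
    """Raise ValueError unless there is one metadata entry per output tensor."""
    if len(output_tensors) != len(output_metadata_names):
        raise ValueError(
            "The number of output tensors ({0}) should match the number of "
            "output tensor metadata ({1})".format(
                len(output_tensors), len(output_metadata_names)))
    # Pair each tensor with its metadata entry by position.
    return dict(zip(output_tensors, output_metadata_names))


# The four outputs named in --output_arrays of the tflite_convert step.
OUTPUT_TENSORS = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]

# Assumed meaning of each output, in order.
OUTPUT_METADATA_NAMES = [
    "location", "category", "score", "number of detections"]

pairs = check_output_metadata(OUTPUT_TENSORS, OUTPUT_METADATA_NAMES)
```

Passing a single metadata name, as the image classification example effectively does, fails the check with the same message as above.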

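Can anyone who has recently done object detection with a custom model (and then deployed it on Android) help? Or help me understand how to update the output tensor metadata to 4 entries instead of 1?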

Solution

Here is a code snippet you can use to populate metadata for an object detection model; the result is compatible with the TFLite Android app.

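# Imports assume the tflite-support pip package is installed.
import os

import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

model_meta = _metadata_fb.ModelMetadataT()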
model_meta.name = "SSD_Detector"
model_meta.description = (
    "Identify which of a known set of objects might be present and provide "
    "information about their positions within the given image or a video "
    "stream.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

# Creates outputs info.
output_location_meta = _metadata_fb.TensorMetadataT()
output_location_meta.name = "location"
output_location_meta.description = "The locations of the detected boxes."
output_location_meta.content = _metadata_fb.ContentT()
output_location_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.BoundingBoxProperties)
output_location_meta.content.contentProperties = (
    _metadata_fb.BoundingBoxPropertiesT())
# A BOUNDARIES box has four elements, so the index needs four entries.
output_location_meta.content.contentProperties.index = [1, 0, 3, 2]
output_location_meta.content.contentProperties.type = (
    _metadata_fb.BoundingBoxType.BOUNDARIES)
output_location_meta.content.contentProperties.coordinateType = (
    _metadata_fb.CoordinateType.RATIO)
output_location_meta.content.range = _metadata_fb.ValueRangeT()
output_location_meta.content.range.min = 2
output_location_meta.content.range.max = 2

output_class_meta = _metadata_fb.TensorMetadataT()
output_class_meta.name = "category"
output_class_meta.description = "The categories of the detected boxes."
output_class_meta.content = _metadata_fb.ContentT()
output_class_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_class_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_class_meta.content.range = _metadata_fb.ValueRangeT()
output_class_meta.content.range.min = 2
output_class_meta.content.range.max = 2
label_file = _metadata_fb.AssociatedFileT()
# Should match the base name of the file passed to load_associated_files.
label_file.name = os.path.basename("labelmap.txt")
label_file.description = "Label of objects that this model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
output_class_meta.associatedFiles = [label_file]

output_score_meta = _metadata_fb.TensorMetadataT()
output_score_meta.name = "score"
output_score_meta.description = "The scores of the detected boxes."
output_score_meta.content = _metadata_fb.ContentT()
output_score_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_score_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())
output_score_meta.content.range = _metadata_fb.ValueRangeT()
output_score_meta.content.range.min = 2
output_score_meta.content.range.max = 2

output_number_meta = _metadata_fb.TensorMetadataT()
output_number_meta.name = "number of detections"
output_number_meta.description = "The number of the detected boxes."
output_number_meta.content = _metadata_fb.ContentT()
output_number_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_number_meta.content.contentProperties = (
    _metadata_fb.FeaturePropertiesT())

# Creates subgraph info.
group = _metadata_fb.TensorGroupT()
group.name = "detection result"
group.tensorNames = [
    output_location_meta.name, output_class_meta.name, output_score_meta.name
]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [
    output_location_meta, output_class_meta, output_score_meta,
    output_number_meta
]
subgraph.outputTensorGroups = [group]
model_meta.subgraphMetadata = [subgraph]

# Serialize the metadata into a flatbuffer for the populator.
b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
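
To make the bounding-box index more concrete, here is a small sketch of how a consumer might read one detection. It assumes the post-processing op emits each box as [ymin, xmin, ymax, xmax] in RATIO coordinates, that each index entry says which of (left, top, right, bottom) sits at that position of the raw box, and that the label file holds one label per line with the line number matching the category value. `decode_box` and `label_for` are hypothetical helpers, not part of any library.

```python
BOUNDARY_FIELDS = ("left", "top", "right", "bottom")


def decode_box(raw_box, index, width, height):
    """Convert one raw RATIO box into named pixel coordinates.

    index[i] says which of (left, top, right, bottom) sits at raw_box[i].
    """
    named = {BOUNDARY_FIELDS[field]: value
             for field, value in zip(index, raw_box)}
    # RATIO coordinates are fractions of the image size.
    return {
        "left": named["left"] * width,
        "top": named["top"] * height,
        "right": named["right"] * width,
        "bottom": named["bottom"] * height,
    }


def label_for(category, labels):
    """Map a category value from the model to its label (one per line)."""
    return labels[int(category)]


# Example: a raw box in the assumed [ymin, xmin, ymax, xmax] order.
box = decode_box([0.25, 0.5, 0.75, 1.0], [1, 0, 3, 2], width=400, height=200)
name = label_for(1.0, ["person", "bicycle"])
```

With an index of [0, 1, 2, 3] the same raw values would be read as left, top, right, bottom instead, which swaps the x and y coordinates for this kind of model.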