保存matplotlib图像抛出尺寸错误

问题描述

首先:我是这一切的新手,请原谅我的无知。

我尝试从TFRecord-File加载图像,并且 am 可以用plt.show()显示它们,但是当我尝试用plt.imsave()保存图像时,我得到了错误

这就是我想要做的:

import tensorflow as tf
import matplotlib.pyplot as plt

reader = tf.data.TFRecordDataset(input_file)

for raw_record in reader.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    raw_record = example.features.feature['image/encoded']
    img = example.features.feature["image/encoded"].bytes_list.value[0]
    decoded = tf.io.decode_jpeg(img)

    plt.figure(figsize = (20,3))
    plt.imshow(decoded)
    plt.show()
    plt.imsave(output_file,decoded)

错误如下:

Traceback (most recent call last):
  File "/home/freddy/PycharmProjects/ocr/visualize_fsns.py",line 30,in <module>
    plt.imsave(flags.output_file,decoded)
  File "/home/freddy/.local/lib/python3.8/site-packages/matplotlib/pyplot.py",line 2235,in        imsave
    return matplotlib.image.imsave(fname,arr,**kwargs)
  File "/home/freddy/.local/lib/python3.8/site-packages/matplotlib/image.py",line 1567,in imsave
    rgba = sm.to_rgba(arr,bytes=True)
  File "/home/freddy/.local/lib/python3.8/site-packages/matplotlib/cm.py",line 305,in to_rgba
    xx = np.empty(shape=(m,n,4),dtype=x.dtype)
TypeError: data type not understood

您能帮我解决这个难题吗?

解决方法

matplotlib可能会感到困惑,因为tf.io.decode_jpeg()返回了张量;可以解释数据类型错误消息。绘图前,尝试使用decoded.numpy()转换为numpy数组。