Tensorflow Hub-获取模型的输入形状和问题域?

问题描述

我正在使用最新版本的tensorflow集线器,想知道人们如何获取有关模型的预期输入形状以及模型所属的集合类型的信息。 例如,以这种方式在Python中加载模型后,是否可以获取有关预期图像形状的信息?

model = hub.load("https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1")

还是这种方式?

model = hub.KerasLayer("https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1")

在任何情况下,模型对象似乎都不知道预期的形状-在图像高度/宽度和批次大小方面。另一方面,对于较旧的TF模型,可以通过load_module_spec找到此信息...

一个问题:是否有一种方法可以通过编程方式获取模型所属的“问题域”的信息?可以在https://tfhub.dev/上查找它,但是如果需要从模型对象本身或通过tensorflow_hub函数访问该信息怎么办?

谢谢!

解决方法

您可以通过访问模型的第一层并访问该层的input_shape属性来获得模型期望的输入形状

layers = model.layers
first_layer = layers[0] # usually the first layer is the input layer
print(first_layer.input_shape)

输出:

[(None,100,3)] # sample output

无->这指定批处理大小的大小,可以推断出批处理大小可以是您指定的任何大小

(100,100,3)->高度,宽度和通道可能会有所不同,您输入的数据应严格相同。

通过编程发现训练后的模型的域有点棘手,您可以使用tensorflow.keras.util.plot_model绘制模型图,并可以从模型的体系结构推断域。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...