Problem description
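When I load a model from the tensorflow-hub repository, I can see that it is in SavedModel format, but I cannot access the model architecture or the weights stored for each layer.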
import tensorflow_hub as hub
model = hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")
Is there an official way to do this?
None of the attributes I can get through model.__dict__ make it clear which of them belong to a specific layer of the original model:
{'_self_setattr_tracking': True,'_self_unconditional_checkpoint_dependencies': [TrackableReference(name='_model',ref=<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7fe4e4914710>),TrackableReference(name='signatures',ref=_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>})),TrackableReference(name='_self_saveable_object_factories',ref=DictWrapper({}))],'_self_unconditional_dependency_names': {'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),'_self_saveable_object_factories': {}},'_self_unconditional_deferred_dependencies': {},'_self_update_uid': 176794,'_self_name_based_restores': set(),'_self_saveable_object_factories': {},'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,'__call__': <tensorflow.python.saved_model.function_deserialization.RestoredFunction at 0x7fe315a28950>,'graph_debug_info':,'tensorflow_version': '2.4.0','tensorflow_git_version': 'unknown'}
I also tried model.signatures['serving_default'].__dict__, but the tensors of the individual layers are not visible there either:
[<tf.Tensor: shape=(),dtype=resource,numpy=<unprintable>>,<tf.Tensor: shape=(),numpy=<unprintable>>],
Solution
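The saved_model_cli command-line tool, which is installed together with TensorFlow (and is therefore also available after installing the tensorflow-serving-api package), can inspect a SavedModel. As a first step, I downloaded and cached the model: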
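from os import environ
import tensorflow_hub as hub
# Cache downloaded modules in a known directory so that saved_model_cli can be pointed at it
environ['TFHUB_CACHE_DIR'] = '/Users/you/.cache/tfhub_modules'
# Downloads the model on first use and stores it under TFHUB_CACHE_DIR
model = hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")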
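The cache directory used in the next command has a 40-character hexadecimal name. A minimal sketch of where it comes from, assuming that tensorflow_hub names each cached module after the SHA-1 hex digest of the handle URL; this is an internal detail of the library, not a public API, so `hub.resolve(handle)`, which returns the actual local path, is the safer choice in real code. The helper `expected_cache_dir` is hypothetical.

```python
import hashlib
import os


def expected_cache_dir(handle: str, cache_root: str) -> str:
    """Return the folder where tensorflow_hub is assumed to cache `handle`.

    Assumption: the folder name is the SHA-1 hex digest of the handle string
    (an internal tensorflow_hub detail that may change between versions).
    """
    digest = hashlib.sha1(handle.encode("utf8")).hexdigest()
    return os.path.join(cache_root, digest)


handle = "https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1"
cache_root = os.path.expanduser("~/.cache/tfhub_modules")
print(expected_cache_dir(handle, cache_root))
```

If the printed folder does not exist after hub.load, the naming assumption does not hold for the installed tensorflow_hub version, and hub.resolve(handle) should be used to find the directory to pass to saved_model_cli.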
Then I inspected the model's signatures:
saved_model_cli show --dir /Users/you/.cache/tfhub_modules/3085eb2fbe2ad0b69801d50844c97b7a7a5ecade --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_tensor'] tensor_info:
        dtype: DT_UINT8
        shape: (1,-1,3)
        name: serving_default_input_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100,4)
        name: StatefulPartitionedCall:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100)
        name: StatefulPartitionedCall:1
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100)
        name: StatefulPartitionedCall:2
    outputs['num_detections'] tensor_info:
        dtype: DT_FLOAT
        shape: (1)
        name: StatefulPartitionedCall:3
  Method name is: tensorflow/serving/predict
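To use this signature information from a script, the text printed by saved_model_cli can be parsed. A minimal sketch, assuming the line layout shown above: it relies only on the `signature_def[...]`, `inputs[...]`/`outputs[...]`, `dtype:`, `shape:`, `name:` and `Method name is:` prefixes, not on indentation. The helper `parse_signatures` is hypothetical, and the CLI's text format is not a stable interface; inside Python, the loaded ConcreteFunction's `structured_input_signature` and `structured_outputs` attributes carry the same information.

```python
import re
from typing import Dict, Optional

SIG_RE = re.compile(r"signature_def\['([^']+)'\]:")
TENSOR_RE = re.compile(r"(inputs|outputs)\['([^']+)'\] tensor_info:")
FIELD_RE = re.compile(r"(dtype|shape|name):\s*(.*)")

TensorInfo = Dict[str, str]


def parse_signatures(text: str) -> Dict[str, Dict[str, Dict[str, TensorInfo]]]:
    """Map signature name -> 'inputs'/'outputs' -> tensor key -> {dtype, shape, name}."""
    signatures: Dict[str, Dict[str, Dict[str, TensorInfo]]] = {}
    current_sig: Optional[Dict[str, Dict[str, TensorInfo]]] = None
    current_tensor: Optional[TensorInfo] = None
    for raw in text.splitlines():
        line = raw.strip()  # ignore indentation, rely on line prefixes only
        if m := SIG_RE.match(line):
            # Start of a new signature_def block
            current_sig = signatures.setdefault(m.group(1), {"inputs": {}, "outputs": {}})
            current_tensor = None
        elif line.startswith("Method name is:"):
            # End of the current signature_def block
            current_sig, current_tensor = None, None
        elif current_sig is not None and (m := TENSOR_RE.match(line)):
            current_tensor = current_sig[m.group(1)].setdefault(m.group(2), {})
        elif current_tensor is not None and (m := FIELD_RE.match(line)):
            current_tensor[m.group(1)] = m.group(2).strip()
    return signatures


# Excerpt of the saved_model_cli output shown above
sample = """\
signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_tensor'] tensor_info:
        dtype: DT_UINT8
        shape: (1,-1,3)
        name: serving_default_input_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100,4)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
"""

sigs = parse_signatures(sample)
print(sigs["serving_default"]["outputs"]["detection_boxes"]["shape"])  # (1,100,4)
```

On a real model, the text could come from subprocess.run(["saved_model_cli", "show", "--dir", path, "--all"], capture_output=True, text=True).stdout, where path is the cache directory used above.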
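After that, I used a debugger to understand how the loaded SavedModel works internally, and found that model.signatures['serving_default'] has the member fields variables and trainable_variables, which store the model's data (weights, ...). Here you can see the output of model.signatures['serving_default'].variables: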
A short summary of the answer: the model's weights can be accessed through model.signatures['serving_default'].variables.
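Once the variables are reachable, grouping them by name prefix gives an approximate per-layer view of the weights. A minimal sketch, assuming each element behaves like a `tf.Variable` (it has a `name` such as `'.../kernel:0'` and a fully known `shape`) and that the part of the name before the last `/` identifies the layer; this answer does not confirm that the restored CenterNet variables keep such names. The helper `summarize_by_layer` and the variable names in the demo are hypothetical, and stand-in objects replace a real model so the sketch runs without TensorFlow.

```python
import math
from collections import namedtuple
from typing import Dict, Iterable, Tuple


def summarize_by_layer(variables: Iterable) -> Dict[str, Tuple[int, int]]:
    """Group variable-like objects by the name prefix before the last '/'.

    Returns {layer_prefix: (number_of_variables, number_of_parameters)}
    in first-seen order.
    """
    summary: Dict[str, Tuple[int, int]] = {}
    for var in variables:
        name = var.name.split(":")[0]  # drop the ':0' output index, if any
        layer = name.rsplit("/", 1)[0]  # 'a/b/kernel' -> 'a/b'; 'step' -> 'step'
        n_params = math.prod(int(d) for d in var.shape)  # scalar shape () -> 1
        n_vars, total = summary.get(layer, (0, 0))
        summary[layer] = (n_vars + 1, total + n_params)
    return summary


# Stand-in objects with hypothetical names; real tf.Variable objects expose
# the same `name` and `shape` attributes.
FakeVar = namedtuple("FakeVar", ["name", "shape"])
demo = [
    FakeVar("block1/conv/kernel:0", (3, 3, 3, 64)),
    FakeVar("block1/conv/bias:0", (64,)),
    FakeVar("global_step:0", ()),
]
for layer, (n_vars, n_params) in summarize_by_layer(demo).items():
    print(layer, n_vars, n_params)
# prints "block1/conv 2 1792", then "global_step 1 1"
```

With the model from the question loaded, summarize_by_layer(model.signatures['serving_default'].variables) would apply the same grouping, and calling .numpy() on any single variable returns its weight values as a NumPy array.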