访问 Tensorflow Hub 中的权重和层

问题描述

当我尝试从 tensorflow-hub resporitory 获取模型时。 我可以将其视为已保存模型格式,但我无法访问模型架构以及每层的权重存储。

import tensorflow_hub as hub
model = hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")
)

有什么正式的工作方式吗?

对于原始模型中的特定层,我可以通过 model.__dict__ 获得的所有属性都不清楚。

{'_self_setattr_tracking': True,'_self_unconditional_checkpoint_dependencies': [TrackableReference(name='_model',ref=<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7fe4e4914710>),TrackableReference(name='signatures',ref=_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>})),TrackableReference(name='_self_saveable_object_factories',ref=DictWrapper({}))],'_self_unconditional_dependency_names': {'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),'_self_saveable_object_factories': {}},'_self_unconditional_deferred_dependencies': {},'_self_update_uid': 176794,'_self_name_based_restores': set(),'_self_saveable_object_factories': {},'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,'__call__': <tensorflow.python.saved_model.function_deserialization.RestoredFunction at 0x7fe315a28950>,'graph_debug_info':,'tensorflow_version': '2.4.0','tensorflow_git_version': 'unkNown'}

我也试过 model.signatures['serving_default'].__dict__,每层的 Tensor 代表不可见

  [<tf.Tensor: shape=(),dtype=resource,numpy=<unprintable>>,<tf.Tensor: shape=(),numpy=<unprintable>>],

解决方法

使用包 tensorflow-serving-api 提供的 CLI 工具 saved_model_cli,可以检查保存的模型。在第一步中,我下载并缓存了模型:

from os import environ
import tensorflow_hub as hub

environ['TFHUB_CACHE_DIR'] = '/Users/you/.cache/tfhub_modules'
hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")

然后我检查了签名和图层:

saved_model_cli show --dir /Users/you/.cache/tfhub_modules/3085eb2fbe2ad0b69801d50844c97b7a7a5ecade --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_tensor'] tensor_info:
        dtype: DT_UINT8
        shape: (1,-1,3)
        name: serving_default_input_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100,4)
        name: StatefulPartitionedCall:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100)
        name: StatefulPartitionedCall:1
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1,100)
        name: StatefulPartitionedCall:2
    outputs['num_detections'] tensor_info:
        dtype: DT_FLOAT
        shape: (1)
        name: StatefulPartitionedCall:3
  Method name is: tensorflow/serving/predict

之后,我使用调试器了解保存的模型在内部如何工作,并在 variables 中找到了成员字段 trainable_variablesmodel.signatures['serving_default'],它们存储数据(权重,...)模型的。在这里您可以看到 model.signatures['serving_default'].variables 的输出:

Inspection with the debugger

,

答案的简短摘要。我们可以通过model.signatures['serving_default'].variables

访问一个层的变量

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...