关于tf.estimator.Estimator

问题描述

所以我正在使用

tf.estimator.Estimator(
    model_fn,model_dir=None,config=None,params=None,warm_start_from=None
)

我对参数params感到困惑。

我知道它是dict,根据一些示例代码,我假设params类似于:

params = {"batch_size":128,"hidden_layer": 3
}

但是根据官方页面params是将传递给model_fn的超参数的命令。键是参数的名称,值是基本的python类型(offical page)。所以值应该是像int64,float64这样的python类型?

请给我一个清晰的解释。非常感谢您的帮助

解决方法

进一步docs

params参数包含超参数。它被传递给 model_fn,如果model_fn有一个名为“ params”的参数,则输入 以相同的方式起作用。 Estimator仅传递参数,它确实 不检查它。 params的结构因此完全取决于 开发人员。

换句话说,合适的是您确定合适的 。如果您的模型加载了权重,则它可能是权重文件的字符串路径:weights_path = "model.h5"。在0.1.之间浮动以提供辍学率。像这样:

def model_fn(params):
    ...
    x = Dense(params['units'])(x)
    x = Dropout(params['dropout'])(x)
    ...
    model.load_weights(params['weights_path'])
    return model

TF检查model_fn是否具有params参数here,并相应地将其传递。 model_fn也可以有其他任何参数。