问题描述
所以我正在使用
tf.estimator.Estimator(
model_fn,model_dir=None,config=None,params=None,warm_start_from=None
)
我对参数params
感到困惑。
我知道它是dict
,根据一些示例代码,我假设params
类似于:
params = {"batch_size":128,"hidden_layer": 3
}
但是根据官方页面,params
是将传递给model_fn的超参数的命令。键是参数的名称,值是基本的python类型(offical page)。所以值应该是像int64,float64这样的python类型?
请给我一个清晰的解释。非常感谢您的帮助
解决方法
进一步docs:
params
参数包含超参数。它被传递给model_fn
,如果model_fn
有一个名为“ params”的参数,则输入 以相同的方式起作用。Estimator
仅传递参数,它确实 不检查它。params
的结构因此完全取决于 开发人员。
换句话说,合适的是您确定合适的 。如果您的模型加载了权重,则它可能是权重文件的字符串路径:weights_path = "model.h5"
。在0.
和1.
之间浮动以提供辍学率。像这样:
def model_fn(params):
...
x = Dense(params['units'])(x)
x = Dropout(params['dropout'])(x)
...
model.load_weights(params['weights_path'])
return model
TF检查model_fn
是否具有params
参数here,并相应地将其传递。 model_fn
也可以有其他任何参数。