从检查点创建Estimator并另存为SavedModel,无需进一步培训

问题描述

我已经从TF Slim resnet V2检查点创建了一个估算器,并对其进行了测试以进行预测。我所做的主要工作基本上与普通的Estimator相似,并且与Assign_from_checkpoint_fn一样:

def model_fn(features,labels,mode,params):
  ...
  slim.assign_from_checkpoint_fn(os.path.join(checkpoint_dir,'resnet_v2_50.ckpt'),slim.get_model_variables('resnet_v2')
  ...
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
      'class_ids': predicted_classes[:,tf.newaxis],'probabilities': tf.nn.softmax(logits),'logits': logits,}
  return tf.estimator.EstimatorSpec(mode,predictions=predictions)

要将估算器导出为SavedModel,我进行了以下serving_input_fn:

def image_preprocess(image_buffer):
    image = tf.image.decode_jpeg(image_buffer,channels=3)
    image_preprocessing_fn = preprocessing_factory.get_preprocessing('inception',is_training=False)
    image = image_preprocessing_fn(image,FLAGS.image_size,FLAGS.image_size)
    return image

def serving_input_fn():
    input_ph = tf.placeholder(tf.string,shape=[None],name='image_binary')
    image_tensors = image_preprocess(input_ph)
    return tf.estimator.export.ServingInputReceiver(image_tensors,input_ph)

在主函数中,我使用export_saved_model尝试将Estimator导出为SavedModel格式:

def main():
    ...
    classifier = tf.estimator.Estimator(model_fn=model_fn)
    classifier.export_saved_model(dir_path,serving_input_fn)

但是,当我尝试运行代码时,它显示“在/ tmp / tmpn3spty2z找不到经过训练的模型”。据我了解,这个export_saved_model试图找到训练有素的Estimator模型以导出到SavedModel。但是,我想知道是否可以通过任何方法将经过预训练的检查点还原到Estimator中,并将Estimator导出到SavedModel而不进行任何进一步的培训?

解决方法

我已经解决了我的问题。要将带有TF 1.14的TF Slim Resnet检查点导出到SavedModel,可以将热启动与export_savedmodel一起使用,如下所示:

config = tf.estimator.RunConfig(save_summary_steps = None,save_checkpoints_secs = None)
warm_start = tf.estimator.WarmStartSettings(checkpoint_dir,checkpoint_name)
classifier = tf.estimator.Estimator(model_fn=model_fn,warm_start_from = warm_start,config = config)
classifier.export_savedmodel(export_dir_base = FLAGS.output_dir,serving_input_receiver_fn =  serving_input_fn)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...