将 feed_dict 与 Tensorflow Estimator API 一起使用,“您必须为占位符张量‘Placeholder’提供一个值为 dtype float 和 shape [?,784]”

问题描述

我正在尝试将自定义 EstimatorFeed_dict 一起使用。根据几个相关问题,例如this one,我得出了以下代码。请注意,我从 dataset 返回 input_fn,而不是 next_example,next_label错误

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
     [[{{node Placeholder}}]]

完整的堆栈跟踪进一步向下。

我遗漏了一些关于数据 Xy 究竟是如何输入到图表中的基本概念。任何人都可以阐明我做错了什么吗?谢谢!

代码

import numpy as np
import tensorflow as tf

class IteratorInitializerHook(tf.compat.v1.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook,self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self,session,coord):
        # Initialize the iterator with the data Feed_dict
        self.iterator_initializer_func(session)

def get_inputs(X,y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.compat.v1.placeholder(X.dtype,[None,X.shape[1]])
        y_pl = tf.compat.v1.placeholder(y.dtype,y.shape[1]])

        dataset = tf.compat.v1.data.Dataset.from_tensor_slices((X_pl,y_pl))
        iterator = dataset.make_initializable_iterator()
        dataset = dataset.batch(32)

        iterator_initializer_hook.iterator_initializer_func = \
            lambda sess: sess.run(iterator.initializer,Feed_dict={X_pl: X,y_pl: y})

        return dataset

    return input_fn,iterator_initializer_hook

class MyMnist:
    def __init__(self,params,**kwargs):
        self.loss = 0
        self.optimizer = tf.compat.v1.train.AdamOptimizer()

        self.W = tf.compat.v1.Variable(tf.zeros([784,10]),trainable=True,name="W")
        self.b = tf.compat.v1.Variable(tf.zeros([10]),name="b")

    def build_model(self,features,labels,mode):
        """
        Build model and return output
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        output = tf.compat.v1.nn.softmax(
            tf.matmul(features,self.W) + self.b
        )
        return output

    def build_total_loss(self,model_outputs,mode):
        """
        Return computed loss
        """
        loss = tf.compat.v1.losses.softmax_cross_entropy(
            labels,model_outputs
        )
        return loss

    def build_optimizer(self):
        """
        Setup the optimizer.

        :returns: The optimizer
        """
        print("build_optimizer")
        lr = 0.01
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=lr,name="Adam"
        )
        return optimizer

    def build_train_ops(self,loss):
        """
        Setup optimizer and build train ops.

        :param Tensor loss: The loss tensor
        :return: Train ops
        """
        print("build_train_ops")
        self.optimizer = self.build_optimizer()
        return self.optimizer.minimize(
            loss,global_step=tf.compat.v1.train.get_global_step()
        )

def model_fn(features,mode,params):
    print('model_fn')

    model = MyMnist(params)
    output = model.build_model(features,mode)
    loss = model.build_total_loss(output,mode)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = model.build_train_ops(loss)

    log_hook = \
        tf.compat.v1.train.LoggingTensorHook(
            {"W is": model.W,"b is": model.b},every_n_iter=1)

    return tf.estimator.EstimatorSpec(
        mode=mode,loss=loss,train_op=train_op,training_hooks=[log_hook]
    )

def train(argv=None):
    print("train")
    params = { 'mode': 'train','model_dir': './model_dir','training': {'steps': 10 },'size': 100
    }

    est = tf.estimator.Estimator(
        model_fn,model_dir=params['model_dir'],params=params,)

    (X_train,l_train),(X_test,l_test) = tf.keras.datasets.mnist.load_data()
    y_train = np.zeros((l_train.shape[0],l_train.max()+1),dtype=np.float32)
    y_train[np.arange(l_train.shape[0]),l_train] = 1
    y_test = np.zeros((l_test.shape[0],l_test.max()+1),dtype=np.float32)
    y_test[np.arange(l_test.shape[0]),l_test] = 1

    X_train = X_train.reshape((X_train.shape[0],-1)).astype(np.float32)
    X_test = X_test.reshape((X_test.shape[0],-1))
    train_input_fn,train_iterator_initializer_hook = \
        get_inputs(X_train,y_train)
    test_input_fn,test_iterator_initializer_hook = get_inputs(X_test,y_test)

    if params['mode'] == 'train':

        est.train(
            input_fn=train_input_fn,hooks=[train_iterator_initializer_hook],steps=params['training']['steps']
        )

if __name__ == "__main__":
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.app.run(main=train)

堆栈跟踪:

$ python main.py 
train
INFO:tensorflow:Using default config.
I0128 18:51:12.061203 139970842568512 estimator.py:1822] Using default config.
INFO:tensorflow:Using config: {'_model_dir': './model_dir','_tf_random_seed': None,'_save_summary_steps': 100,'_save_checkpoints_steps': None,'_save_checkpoints_secs': 600,'_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    Meta_optimizer_iterations: ONE
  }
},'_keep_checkpoint_max': 5,'_keep_checkpoint_every_n_hours': 10000,'_log_step_count_steps': 100,'_train_distribute': None,'_device_fn': None,'_protocol': None,'_eval_distribute': None,'_experimental_distribute': None,'_experimental_max_worker_delay_secs': None,'_session_creation_timeout_secs': 7200,'_service': None,'_cluster_spec': ClusterSpec({}),'_task_type': 'worker','_task_id': 0,'_global_id_in_cluster': 0,'_master': '','_evaluation_master': '','_is_chief': True,'_num_ps_replicas': 0,'_num_worker_replicas': 1}
I0128 18:51:12.061676 139970842568512 estimator.py:191] Using config: {'_model_dir': './model_dir','_num_worker_replicas': 1}
WARNING:tensorflow:From /srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/training_util.py:235: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0128 18:51:12.383132 139970842568512 deprecation.py:317] From /srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/training_util.py:235: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From main.py:21: DatasetV1.make_initializable_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
This is a deprecated API that should only be used in TF 1 graph mode and legacy TF 2 graph mode available through `tf.compat.v1`. In all other situations -- namely,eager mode and inside `tf.function` -- you can consume dataset elements using `for elem in dataset: ...` or by explicitly creating iterator via `iterator = iter(dataset)` and fetching its elements via `values = next(iterator)`. Furthermore,this API is not available in TF 2. During the transition from TF 1 to TF 2 you can use `tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF 1 graph mode style iterator for a dataset created through TF 2 APIs. Note that this should be a transient state of your code base as there are in general no guarantees about the interoperability of TF 1 and TF 2 code.
W0128 18:51:12.397508 139970842568512 deprecation.py:317] From main.py:21: DatasetV1.make_initializable_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
This is a deprecated API that should only be used in TF 1 graph mode and legacy TF 2 graph mode available through `tf.compat.v1`. In all other situations -- namely,this API is not available in TF 2. During the transition from TF 1 to TF 2 you can use `tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF 1 graph mode style iterator for a dataset created through TF 2 APIs. Note that this should be a transient state of your code base as there are in general no guarantees about the interoperability of TF 1 and TF 2 code.
INFO:tensorflow:Calling model_fn.
I0128 18:51:12.404662 139970842568512 estimator.py:1162] Calling model_fn.
model_fn
build_train_ops
build_optimizer
INFO:tensorflow:Done calling model_fn.
I0128 18:51:12.476285 139970842568512 estimator.py:1164] Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
I0128 18:51:12.477094 139970842568512 basic_session_run_hooks.py:546] Create CheckpointSaverHook.
data size:0.033848 MB
data size:0.033848 MB
INFO:tensorflow:Graph was finalized.
I0128 18:51:12.530334 139970842568512 monitored_session.py:246] Graph was finalized.
2021-01-28 18:51:12.530598: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (onednN)to use the following cpu instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations,rebuild TensorFlow with the appropriate compiler flags.
2021-01-28 18:51:12.539575: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] cpu Frequency: 2400000000 Hz
2021-01-28 18:51:12.543713: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x559677c5ee30 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-01-28 18:51:12.543747: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host,Default Version
INFO:tensorflow:Running local_init_op.
I0128 18:51:12.571843 139970842568512 session_manager.py:505] Running local_init_op.
INFO:tensorflow:Done running local_init_op.
I0128 18:51:12.574033 139970842568512 session_manager.py:508] Done running local_init_op.
data size:0.052762 MB
data size:0.052762 MB
data size:0.052762 MB
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
I0128 18:51:12.677021 139970842568512 basic_session_run_hooks.py:613] Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./model_dir/model.ckpt.
I0128 18:51:12.677253 139970842568512 basic_session_run_hooks.py:618] Saving checkpoints for 0 into ./model_dir/model.ckpt.
data size:0.052762 MB
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
I0128 18:51:12.704515 139970842568512 basic_session_run_hooks.py:625] Calling checkpoint listeners after saving checkpoint 0...
Traceback (most recent call last):
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1365,in _do_call
    return fn(*args)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1349,in _run_fn
    return self._call_tf_sessionrun(options,Feed_dict,fetch_list,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 1441,in _call_tf_sessionrun
    return tf_session.TF_SessionRun_wrapper(self._session,options,tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
     [[{{node Placeholder}}]]

During handling of the above exception,another exception occurred:

Traceback (most recent call last):
  File "main.py",line 144,in <module>
    tf.compat.v1.app.run(main=train)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/platform/app.py",line 40,in run
    _run(main=main,argv=argv,flags_parser=_parse_flags_tolerate_undef)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/absl/app.py",line 303,in run
    _run_main(main,args)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/absl/app.py",line 251,in _run_main
    sys.exit(main(argv))
  File "main.py",line 136,in train
    est.train(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 349,in train
    loss = self._train_model(input_fn,hooks,saving_listeners)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1175,in _train_model
    return self._train_model_default(input_fn,line 1206,in _train_model_default
    return self._train_with_estimator_spec(estimator_spec,worker_hooks,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1495,in _train_with_estimator_spec
    with training.MonitoredTrainingSession(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 601,in MonitoredTrainingSession
    return MonitoredSession(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1034,in __init__
    super(MonitoredSession,self).__init__(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 749,in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1231,in __init__
    _WrappedSession.__init__(self,self._create_session())
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 1236,in _create_session
    return self._sess_creator.create_session()
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/training/monitored_session.py",line 909,in create_session
    hook.after_create_session(self.tf_sess,self.coord)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/util.py",line 86,in after_create_session
    session.run(self._initializer)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/client/session.py",line 957,in run
    result = self._run(None,fetches,options_ptr,line 1180,in _run
    results = self._do_run(handle,final_targets,final_fetches,line 1358,in _do_run
    return self._do_call(_run_fn,Feeds,targets,line 1384,in _do_call
    raise type(e)(node_def,op,message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must Feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
     [[node Placeholder (defined at main.py:17) ]]

Original stack trace for 'Placeholder':
  File "main.py",line 1201,in _train_model_default
    self._get_features_and_labels_from_input_fn(input_fn,ModeKeys.TRAIN))
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1037,in _get_features_and_labels_from_input_fn
    self._call_input_fn(input_fn,mode))
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow_estimator/python/estimator/estimator.py",line 1130,in _call_input_fn
    return input_fn(**kwargs)
  File "main.py",line 17,in input_fn
    X_pl = tf.compat.v1.placeholder(X.dtype,X.shape[1]])
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py",line 3100,in placeholder
    return gen_array_ops.placeholder(dtype=dtype,shape=shape,name=name)
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py",line 6808,in placeholder
    _,_,_op,_outputs = _op_def_library._apply_op_helper(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py",line 742,in _apply_op_helper
    op = g._create_op_internal(op_type_name,inputs,dtypes=None,File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/ops.py",line 3478,in _create_op_internal
    ret = Operation(
  File "/srv/scratch/packages/spack/opt/spack/linux-rhel8-skylake_avx512/gcc-8.3.1/anaconda3-2020.07-weugqkfkxd6zmn2irm7lpmujzczwebiw/envs/graphsaint_env/lib/python3.8/site-packages/tensorflow/python/framework/ops.py",line 1949,in __init__
    self._traceback = tf_stack.extract_stack()

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...