TensorFlow RL agent ValueError: wrong input shape, but where?

Problem description

I am trying to train a reinforcement learning agent to play the game 2048, with an Env I designed myself. The error says the input must have shape (1, 16), but an array with shape (1, 4) is being passed in, and I cannot figure out where in my code an array of that shape is produced:

class CardGameEnv(Env):
    def __init__(self):
        self.action_space = discrete(3)
        
        self.observation_space = Box(low=np.array([0 for _ in range(16)]),high=np.array([np.inf for _ in range(16)]))
        
        self._state = [0 for _ in range(16)]
        
        self._episode_ended = False
        
        self._score = 0

    def action_spec(self):
        return self.action_space

    def observation_spec(self):
        return self.observation_space

    def reset(self):
        self._state = [0 for _ in range(16)]
        
        self._episode_ended = False
        
        restart()
        print(np.array([self._state]))
        return ts.restart(np.array([self._state],dtype=np.int32))

    def step(self,action):

        if self._episode_ended:
          # The last action ended the episode. Ignore the current action and start
          # a new episode.
            return self.reset()
        
        old_score = self._score
        print(action)
        make_move(action)
        
        if check_game_over():
            self._episode_ended = True
        
        self._score = get_score()
        
        self._state = list(get_board().flatten())
        
        score = get_score()
        
        if self._episode_ended:
            reward = - 10
            return self._state,reward,self._episode_ended,{}
        
        elif score == old_score:
            reward = -10
            return self._state,{} 
        
        else:
            reward = self._score - old_score
            return self._state,{}
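
A note on step() before the model code: keras-rl's fit() loop unpacks four values from every step() call, in the classic gym style, but two of the branches above return only (self._state, {}). A tiny illustration (unpack() is a hypothetical helper, simplified from what rl/core.py does):

# Hypothetical helper mirroring the unpacking keras-rl's fit() loop
# performs on each step() result:
def unpack(step_result):
    observation, reward, done, info = step_result  # needs exactly four items
    return observation, reward, done, info

unpack(([0] * 16, -10.0, False, {}))  # fine
unpack(([0] * 16, {}))                # ValueError: not enough values to unpack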

def build_model(states,actions):
    model = Sequential()
    model.add(Flatten(input_shape=states))
    model.add(Dense(24,activation='relu'))
    model.add(Dense(24,activation='relu'))
    model.add(Dense(actions,activation='softmax'))
    return model
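
A side note on build_model: a DQN's network outputs Q-values, which are unbounded, so the final layer normally uses a linear activation; softmax squashes the outputs into probabilities. A hedged variant with only the last activation changed (imports assume tf.keras, which the traceback below points at):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

def build_model(states, actions):
    model = Sequential()
    model.add(Flatten(input_shape=states))           # (1, 16) -> 16 inputs
    model.add(Dense(24, activation='relu'))
    model.add(Dense(24, activation='relu'))
    model.add(Dense(actions, activation='linear'))   # raw Q-values
    return model

With window_length=1 in the SequentialMemory below, the network's expected input per sample is (window_length, observation_dim) = (1, 16), which is exactly what build_model((1, 16), 4) declares.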

def build_agent(model,actions):
    policy = BoltzmannQPolicy()
    memory = SequentialMemory(limit=50000,window_length=1)
    dqn = DQNAgent(model=model,memory=memory,policy=policy,nb_actions=actions,nb_steps_warmup=10,target_model_update=1e-2)
    return dqn

env = CardGameEnv()
model = build_model((1,16),4) 

dqn = build_agent(model,4)
dqn.compile(Adam(lr=1e-3),metrics=['mae'])
dqn.fit(env,nb_steps=50000,visualize=False,verbose=1)

These are the relevant parts of my code. According to the error I am passing an array of shape (1, 4), but everywhere I look the state is a (1, 16) array:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-65-edd0afcc8057> in <module>
      8 dqn = build_agent(model,4)
      9 dqn.compile(Adam(lr=1e-3),metrics=['mae'])
---> 10 dqn.fit(env,verbose=1)

~\anaconda3\lib\site-packages\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
    166                 # This is were all of the work happens. We first perceive and compute the action
    167                 # (forward step) and then use the reward to improve (backward step).
--> 168                 action = self.forward(observation)
    169                 if self.processor is not None:
    170                     action = self.processor.process_action(action)

~\anaconda3\lib\site-packages\rl\agents\dqn.py in forward(self, observation)
    222         # Select an action.
    223         state = self.memory.get_recent_state(observation)
--> 224         q_values = self.compute_q_values(state)
    225         if self.training:
    226             action = self.policy.select_action(q_values=q_values)

~\anaconda3\lib\site-packages\rl\agents\dqn.py in compute_q_values(self, state)
     66 
     67     def compute_q_values(self, state):
---> 68         q_values = self.compute_batch_q_values([state]).flatten()
     69         assert q_values.shape == (self.nb_actions,)
     70         return q_values

~\anaconda3\lib\site-packages\rl\agents\dqn.py in compute_batch_q_values(self, state_batch)
     61     def compute_batch_q_values(self, state_batch):
     62         batch = self.process_state_batch(state_batch)
---> 63         q_values = self.model.predict_on_batch(batch)
     64         assert q_values.shape == (len(state_batch), self.nb_actions)
     65         return q_values

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in predict_on_batch(self, x)
   1203           ' tf.distribute.Strategy.')
   1204     # Validate and standardize user data.
-> 1205     inputs, _, _ = self._standardize_user_data(
   1206         x, extract_tensors_from_dataset=True)
   1207     # If `self._distribution_strategy` is True, then we are in a replica context

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2345       return [], [], None
   2346 
-> 2347     return self._standardize_tensors(
   2348         x,
   2349         run_eagerly=run_eagerly,

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_v1.py in _standardize_tensors(self, run_eagerly, dict_inputs, is_dataset, batch_size)
   2373     if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
   2374       # TODO(fchollet): run static checks with dataset output shape(s).
-> 2375       x = training_utils_v1.standardize_input_data(
   2376           x,
   2377           feed_input_names,

~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_utils_v1.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    661         for dim, ref_dim in zip(data_shape, shape):
    662           if ref_dim != dim and ref_dim is not None and dim is not None:
--> 663             raise ValueError('Error when checking ' + exception_prefix +
    664                              ': expected ' + names[i] + ' to have shape ' +
    665                              str(shape) + ' but got array with shape ' +

ValueError: Error when checking input: expected flatten_4_input to have shape (1, 16) but got array with shape (1, 4)

Solution

No confirmed fix for this problem has been reported yet.
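
That said, the traceback does narrow down where the (1, 4) comes from: the environment mixes two APIs. reset() returns ts.restart(...), which in tf_agents is a TimeStep, a namedtuple with four fields (step_type, reward, discount, observation), rather than the bare observation keras-rl expects. keras-rl stores whatever reset() hands back as the state, so each stored state is a 4-field tuple, and process_state_batch turns it into an array whose per-sample shape is (1, 4) instead of (1, 16). A quick check that shows the four fields (assuming ts is tf_agents.trajectories.time_step, as the call to ts.restart suggests):

import numpy as np
from tf_agents.trajectories import time_step as ts

step = ts.restart(np.array([[0] * 16], dtype=np.int32))
print(type(step).__name__, len(step))  # TimeStep 4 -> four fields, not 16 values

The likely fix is therefore to make the environment speak the gym dialect keras-rl is built against: reset() returns only the observation, and step() always returns the 4-tuple (observation, reward, done, info). A hedged sketch along those lines, reusing the game helpers from the question (restart, make_move, check_game_over, get_score, get_board) and untested against the original game code:

import numpy as np
from gym import Env
from gym.spaces import Discrete, Box

class CardGameEnv(Env):
    def __init__(self):
        # 2048 has four moves; Discrete(3) would also clash with the
        # nb_actions=4 the agent and model were built with.
        self.action_space = Discrete(4)
        self.observation_space = Box(low=0.0, high=np.inf, shape=(16,))
        self._state = np.zeros(16, dtype=np.int32)
        self._episode_ended = False
        self._score = 0

    def reset(self):
        restart()
        self._state = np.zeros(16, dtype=np.int32)
        self._episode_ended = False
        self._score = 0
        return self._state  # bare observation, not a TimeStep

    def step(self, action):
        old_score = self._score
        make_move(action)
        self._episode_ended = check_game_over()
        self._score = get_score()
        self._state = get_board().flatten()

        if self._episode_ended:
            reward = -10
        elif self._score == old_score:
            reward = -10  # the move changed nothing
        else:
            reward = self._score - old_score

        # Always the same 4-tuple; keras-rl unpacks all four every step.
        return self._state, reward, self._episode_ended, {}

keras-rl calls reset() itself between episodes, so step() no longer needs a self-reset branch. Whether this is the whole story is unverified, but it accounts for exactly the (1, 16)-vs-(1, 4) mismatch the error reports.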
