DQN algorithm in PyTorch does not converge

Problem description

I am new to deep reinforcement learning and implemented the algorithm myself, but the values do not converge. Could anyone take a look, tell me what is wrong with my algorithm, and what I can do to make it better? Here is the code:

import gym
import torch
import numpy as np
import random
from collections import deque
from itertools import count
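# experience replay buffer: stores (state, action, reward, next_state, done) transitions and samples mini-batches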
class ReplayBuffer:
    def __init__(self):
        self.buffer=deque(maxlen=50000)
    def push(self,state,action,reward,next_state,done):
        # the deque's maxlen already bounds the buffer, so every transition is stored
        self.buffer.append((state,action,reward,next_state,done))
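    # continuous=True returns a contiguous slice of stored transitions; continuous=False samples indices uniformly at random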
    def sample(self,batch_size: int,continuous: bool = True):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if continuous:
            rand = random.randint(0,len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand,rand + batch_size)]
        else:
            indexes = np.random.choice(np.arange(len(self.buffer)),size=batch_size,replace=False)
            return [self.buffer[i] for i in indexes]
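# Q-network: maps the 4-dimensional CartPole observation to Q-values for the 2 discrete actions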
class NNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1=torch.nn.Linear(4,128)
        self.l2=torch.nn.Linear(128,128)
        self.l3=torch.nn.Linear(128,2)
        
        self.optimizer=torch.optim.Adam(params=self.parameters(),lr=0.001)
        self.criterion=torch.nn.MSELoss()
    def forward(self,x):
        al1=torch.nn.ReLU()(self.l1(x))
        al2=torch.nn.ReLU()(self.l2(al1))
        al3=self.l3(al2)
        return al3
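# DQN agent: q_local is trained every step, q_target is synced to it periodically to give stable TD targets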
class Agent():
    def __init__(self):
        
        self.env=gym.make('CartPole-v0')
        self.mem=ReplayBuffer()
        self.q_local=NNetwork()
        self.q_target=NNetwork()
        self.q_target.load_state_dict(self.q_local.state_dict())
        self.epsilon=1.0
        self.e_decay=0.0995
        self.e_min=0.1
        self.update=4
        self.score=0
        self.gamma=0.99

    def predict(self,state):
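        # epsilon-greedy action selection: random action with probability epsilon, otherwise greedy w.r.t. q_local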
        # rand() draws uniformly from [0,1) for the epsilon test; randn() would sample a normal distribution
        if (np.random.rand()<self.epsilon):
            return random.randint(0,1)
        else:
            index=self.q_local.forward(torch.Tensor(state).unsqueeze(0))
            return torch.argmax(index,dim=1).item()
    
    def step(self):
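        # run one episode: act, store the transition, learn, and every 10 steps decay epsilon and sync the target network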
        state=self.env.reset()
        done=False
        i=0
        while not done:
            action=self.predict(state)
            # old Gym step API: (observation, reward, done, info)
            n_state,reward,done,_=self.env.step(action)
            self.mem.push(state,action,reward,n_state,done)
            self.score+=reward
            self.learn()
            state=n_state
            i+=1
            if(i%10==0):
                if(self.epsilon>self.e_min):
                    self.epsilon=self.epsilon-self.e_decay
                else:
                    self.epsilon=self.e_min
                self.q_target.load_state_dict(self.q_local.state_dict())
          
        print(self.score)
        self.score=0
    def learn(self):
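        # sample a mini-batch from the replay buffer and take one gradient step on the TD error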
        # wait until at least one full mini-batch of transitions has been collected
        if(len(self.mem.buffer)<32):
            return
        batch =self.mem.sample(32)
        state,action,reward,n_state,done= zip(*batch)
        state=torch.Tensor(state)
        action=torch.Tensor(action).unsqueeze(1)
        n_state=torch.Tensor(n_state)
        reward=torch.Tensor(reward).unsqueeze(1)
        done=torch.Tensor(done).unsqueeze(1)

        self.q_local.optimizer.zero_grad()
        
        q_N=self.q_local.forward(state).gather(1,action.long())
        # detach the target-network output so no gradient flows into q_target
        q_t=self.q_target.forward(n_state).detach()
        # TD target: y = r + gamma * max_a' Q_target(s', a'), with the bootstrap term masked out when done
        y=reward+(1-done)*self.gamma*torch.max(q_t,dim=1,keepdim=True)[0]
        
        loss=self.q_local.criterion(q_N,y)
        loss.backward()
        self.q_local.optimizer.step()
agent=Agent()
for t in count():
    print("EP ",t)
    agent.step()

Well, I would be glad to score even a few points, but it does not converge.

Solution

No effective solution to this problem has been found yet; we are still looking for one.
