从 ray.tune 中提取代理

问题描述

我一直在使用 azure 机器学习来训练使用 ray.tune 的强化学习代理。

我的训练函数如下:

    tune.run(
        run_or_experiment="PPO",config={
            "env": "Battery","num_gpus" : 1,"num_workers": 13,"num_cpus_per_worker": 1,"train_batch_size": 1024,"num_sgd_iter": 20,'explore': True,'exploration_config': {'type': 'stochasticSampling'},},stop={'episode_reward_mean': 0.15},checkpoint_freq = 200,local_dir = 'second_checkpoints'
        
    )

如何从检查点提取代理,以便我可以将我的健身房环境中的操作可视化,如下所示:

while not done:
    action,state,logits = agent.compute_action(obs,state)
    obs,reward,done,info = env.step(action)
    episode_reward += reward
    print('action: ' + str(action) + 'reward: ' + str(reward))


我知道我可以使用这样的东西:

analysis = tune.run('PPO",config={"max_iter": 10},restore=last_ckpt)

但我不确定如何从存在于 tune.run 中的代理中提取计算操作(和奖励)。

解决方法

tune run 用于训练模型。培训结束后,您应该有一些检查点文件。可以加载这些文件,然后在您的环境中播放。

agent = ppo.PPOTrainer(config=config,env=env_name)
agent.restore(checkpoint_file)
obs = env.reset()
action = agent.compute_action(obs)
obs,reward,done,info = env.step(action)