tf.agent策略可以返回所有操作的概率向量吗?

问题描述

我正在尝试使用TF-Agent TF-Agent DQN Tutorial训练强化学习代理。在我的应用程序中,我有1个动作,其中包含9个可能的离散值(标记为0到8)。以下是env.action_spec()

输出
BoundedTensorSpec(shape=(),dtype=tf.int64,name='action',minimum=array(0,dtype=int64),maximum=array(8,dtype=int64))

我想获得包含经过训练的策略计算出的所有动作的概率向量,并在其他应用程序环境中进行进一步处理。但是,该策略仅返回具有单个值的log_probability而不是所有操作的向量。反正有得到概率矢量吗?

from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent

q_net = q_network.QNetwork(
            env.observation_spec(),env.action_spec(),fc_layer_params=(32,)
        )

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)

my_agent = dqn_agent.DqnAgent(
    env.time_step_spec(),q_network=q_net,epsilon_greedy=epsilon,optimizer=optimizer,emit_log_probability=True,td_errors_loss_fn=common.element_wise_squared_loss,train_step_counter=global_step)

my_agent.initialize()

...  # training

tf_policy_saver = policy_saver.PolicySaver(my_agent.policy)
tf_policy_saver.save('./policy_dir/')

# making decision using the trained policy
action_step = my_agent.policy.action(time_step)

dqn_agent.DqnAgent() DQNAgent中,我设置了emit_log_probability=True,它应该定义Whether policies emit log probabilities or not.

但是,当我运行action_step = my_agent.policy.action(time_step)时,它会返回

PolicyStep(action=<tf.Tensor: shape=(1,),dtype=int64,numpy=array([1],dtype=int64)>,state=(),info=PolicyInfo(log_probability=<tf.Tensor: shape=(1,dtype=float32,numpy=array([0.],dtype=float32)>))

我还尝试运行action_distribution = saved_policy.distribution(time_step),它返回

PolicyStep(action=<tfp.distributions.DeterministicWithLogProbCT 'Deterministic' batch_shape=[1] event_shape=[] dtype=int64>,info=PolicyInfo(log_probability=<tf.Tensor: shape=(),numpy=0.0>))

如果TF.Agent中没有这样的API,有没有办法获得这样的概率矢量?谢谢。


后续问题:

如果我理解正确,则深层Q网络应该获取state的输入并从状态输出每个动作的Q值。我可以将此Q值向量传递给softmax函数,然后计算相应的概率向量。实际上,我已经使用自己的自定义DQN脚本(没有TF-Agent)进行了这种计算。那么问题就变成了:如何从TF-Agent返回Q值向量?

解决方法

在 TF-Agents 框架中执行此操作的唯一方法是调用 Policy.distribution() 方法而不是 action 方法。这将返回从网络的 Q 值计算出来的原始分布。 emit_log_probability=True 只影响 info 返回的 PolicyStep 命名元组的 Policy.action() 属性。请注意,此分布可能会受到您通过的操作约束(如果您这样做)的影响;因此,非法行为将被标记为概率为 0(即使原始 Q 值可能很高)。

此外,如果您想查看实际的 Q 值而不是它们生成的分布,那么如果不直接对您的代理随附的 Q 网络采取行动(和它也附加到代理生成的 Policy 对象)。如果您想了解如何正确调用 Q-network,我建议您查看 QPolicy._distribution() 方法如何here

请注意,使用预先实现的驱动程序无法完成这些操作。您必须显式地构建自己的收集循环或实现自己的 Driver 对象(这基本上是等效的)。