tf-agent,QNetwork => DqnAgent,带有tfa.optimizers.CyclicalLearningRate

问题描述

是否有一种简单的本机方式来实现tfa.optimizers.CyclicalLearningRate w / QNetwork on DqnAgent

尝试避免编写自己的DqnAgent。

我想更好的问题可能是,在DqnAgent上实现回调的正确方法是什么?

解决方法

在您链接的教程中,设置优化程序的部分是

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),train_env.action_spec(),q_network=q_net,optimizer=optimizer,td_errors_loss_fn=common.element_wise_squared_loss,train_step_counter=train_step_counter)

agent.initialize()

因此,您可以使用您愿意使用的任何优化程序替换优化程序。基于documentation之类的

optimizer = tf.keras.optimizers.Adam(learning_rate=tfa.optimizers.CyclicalLearningRate)

应该有效,除非在教程中使用tf 1.0 adam引起任何潜在的兼容性问题。