问题描述
我创建了一个像 global_model = CNNMnist(args=args)
这样的 CNN 模型。然后我将它发送到设备,将其设置为训练。然后我训练我的本地模型,收集 local_weights 和它们的平均值以获得更新的 global_model。
现在我试图从 .parameters()
函数中获取项目,但我得到的只是 None
作为 item.grad
。当我对 local_models 做同样的事情时,我得到了想要的输出。我做错了什么?
global_model.to(device)
global_model.train()
...................
global_weights = average_weights(local_weights)
global_model.load_state_dict(global_weights)
last_update = []
for item in global_model.parameters():
last_update.append(copy.deepcopy(item.grad))
print(item.grad)
Output: None None None None None None None None
任何帮助将不胜感激。
解决方法
您正在查看从 state_dict
加载的值 - 渐变未保存在那里。尝试在调用 .grad
之后和 backward()
之前打印 zero_grad()
。