Pytorch:如何从扁平化的网络中解压/取回网络?

问题描述

我正在使用以下函数来扁平化网络:

#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    flatNet = []
    shapes = []
    for param in net.parameters():
        #if its WEIGHTS
        curr_shape = param.cpu().data.numpy().shape
        shapes.append(curr_shape)
        if len(curr_shape) == 2:
            param = param.cpu().data.numpy().reshape(curr_shape[0]*curr_shape[1])
            flatNet.append(param)
        elif len(curr_shape) == 4:
            param = param.cpu().data.numpy().reshape(curr_shape[0]*curr_shape[1]*curr_shape[2]*curr_shape[3])
            flatNet.append(param)
        else:
            param = param.cpu().data.numpy().reshape(curr_shape[0])
            flatNet.append(param)
    finalNet = []
    for obj in flatNet:
        for x in obj:
            finalNet.append(x)
    finalNet = np.array(finalNet)
    return finalNet,shapes

上述函数将所有权重作为网络的numpy列向量finalNetshapes(列表)返回。我想看看权重修改对预测精度的影响。所以,我改变了权重。如何将此修改后的权重向量复制回原始网络?请帮忙。谢谢。

解决方法

模型定义(它的 forward 函数)和参数配置(所谓的模型状态,使用 state_dict 作为字典很容易访问)之间存在差异。

您可以获得模型的状态,就像您在实现 flattenNetwork 时所做的那样。然而,对于几乎所有模型,恢复此操作(i.e.,如果您只有权重和图层形状)是不可能的。

现在,假设您 - 仍然 - 可以访问 net。我的建议是直接使用 net.state_dict(),修改它,然后用 load_state_dict 加载回权重字典。这样,您就不必自己处理序列化模型的参数。