Loss becomes NaN after a few iterations

Problem description

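In my model, the input is graph data in the form of an edge index and node features. After a few training iterations on the graph data, the loss (EDIT: this is a combination of an MSE loss and a negative loss term, i.e. L1 + (-L2)) becomes NaN. After about 40 iterations, both L1 and -L2 are NaN.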

The learning rate is 0.00001. I also checked the input data for invalid values but found none.
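
For reference, one way such an input check can be written is sketched below. This is only an illustration, not the check actually used in the question: the helper name all_finite and the random features tensor are assumptions made up for the example.

```python
import torch

def all_finite(*tensors):
    """Return True only if no tensor contains a NaN or Inf entry."""
    return all(torch.isfinite(t).all().item() for t in tensors)

# Hypothetical stand-in for the real node features (shape [num_nodes, 1])
features = torch.rand(1000, 1)
print(all_finite(features))

# A copy with one corrupted entry, to show what a failing check looks like
bad = features.clone()
bad[3, 0] = float("nan")
print(all_finite(bad))
```

Running the same check on the model output and on the loss at every step narrows down where the first non-finite value appears.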

import time  # used for timing in train_model

import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import networkx as nx
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class Model(nn.Module):
    def __init__(self, nin, nhid1, nout, inp_l, hid_l, out_l=1):
        super(Model, self).__init__()

        # Two graph convolution layers that embed the nodes
        self.g1 = GCNConv(in_channels=nin, out_channels=nhid1)
        self.g2 = GCNConv(in_channels=nhid1, out_channels=nout)
        self.dropout = 0.5
        # Two linear layers that map each node embedding to the output
        self.lay1 = nn.Linear(inp_l, hid_l)
        self.lay2 = nn.Linear(hid_l, out_l)

    def forward(self, x, adj):
        # adj is the edge index of shape [2, num_edges]
        x = F.relu(self.g1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.g2(x, adj)

        x = self.lay1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = F.relu(x)  # final ReLU keeps every output >= 0

        return x

Inputs to the model:

x (Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features].

edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges].

Here num_nodes = 1000; num_node_features = 1; num_edges = 5000
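
To make these shapes concrete, here is a minimal sketch that builds placeholder tensors of the stated sizes using plain PyTorch. The values are random and purely illustrative (random edges may include self-loops and duplicates); they are not the data from the question.

```python
import torch

num_nodes, num_node_features, num_edges = 1000, 1, 5000

# Node feature matrix: one scalar feature per node
x = torch.rand(num_nodes, num_node_features)

# COO edge list: row 0 holds source node ids, row 1 holds target node ids
edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.long)

print(x.shape, edge_index.shape, edge_index.dtype)
```

Every entry of edge_index has to be a valid node id in the range [0, num_nodes), otherwise the graph layers would index outside the feature matrix.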

GCNConv is a graph embedding layer that returns a [num_nodes, dim] matrix. It takes the edge list and the node features as input.

Edit 2: Added how the loss is calculated.

def train_model(epoch):
    model= Model(nin = 1,nhid1=128,nout=128,inp_l=128,hid_l=64,out_l=1).to(device)
    optimizer = optim.Adam(model.parameters(),lr=0.00001)

    model.train()
    t = time.time()
    optimizer.zero_grad()
    Y = model(features,adjacency_list)

    Y1 = func(Y) #Y1 values are calculated from Y by passing through a function func to obtain a same sized vector as Y

    loss1 = ((Y1-Y)**2).mean()  #MSE Loss function
    
    loss2 = -Y.abs().mean() # This loss is implemented to prevent Y values going to 0. Notice the "-" sign
    
    loss_train = loss1 + loss2
    loss_train.backward(retain_graph=True)
    nn.utils.clip_grad_norm_(model.parameters(),0.5)

    optimizer.step()
    
    if epoch%20==0:
        print("MSE loss = ",loss1,"\t","Mean Loss = ",loss2)
        print('Epoch: {:04d}'.format(epoch+1),'loss_train: {:.4f}'.format(loss_train.item()),'time: {:.4f}s'.format(time.time() - t))
        print("\n\n")

    return Y
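
A possible way to find the first step at which values stop being finite is sketched below. It is a debugging aid under stated assumptions, not a fix: the tiny nn.Linear model, the zero target standing in for func(Y), and the helper nonfinite_grads are all hypothetical. It relies on clip_grad_norm_ returning the total gradient norm, which is NaN or Inf whenever any gradient is.

```python
import torch
import torch.nn as nn

def nonfinite_grads(model):
    """Names of parameters whose gradient contains a NaN or Inf entry."""
    return [name for name, p in model.named_parameters()
            if p.grad is not None and not torch.isfinite(p.grad).all()]

# Hypothetical tiny model and input, standing in for the GCN model and graph
model = nn.Linear(4, 1)
inputs = torch.rand(8, 4)

y = model(inputs)
loss1 = ((torch.zeros_like(y) - y) ** 2).mean()  # stand-in for the MSE against func(Y)
loss2 = -y.abs().mean()                          # same negative term as in the question
loss = loss1 + loss2
loss.backward()

# clip_grad_norm_ returns the total norm computed before clipping
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 0.5)

print("loss finite:", torch.isfinite(loss).item())
print("grad norm finite:", torch.isfinite(grad_norm).item())
print("parameters with bad gradients:", nonfinite_grads(model))
```

Logging these three values on every iteration, rather than every 20th, shows whether the loss, the gradients, or a specific layer goes non-finite first. torch.autograd.set_detect_anomaly(True) can additionally report the backward operation that produced a NaN, at the cost of slower training.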
