How do I use the MSELoss function for Fashion_MNIST in PyTorch?

Problem description

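I want to run the Fashion_MNIST data through my network and look at the output gradient, which might be the mean squared sum between the first and second layers.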

My code is below:

# import the necessary libs
import numpy as np
import torch
import time

# Loading the Fashion-MNIST dataset
from torchvision import datasets,transforms

# Get GPU Device

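device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# get_device_name raises on CPU-only machines, so only call it when CUDA exists
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))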


# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('MNIST_data/',download = True,train = True,transform = transform)
testset = datasets.FashionMNIST('MNIST_data/',train = False,transform = transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size = 128,shuffle = True,num_workers=4)
testloader = torch.utils.data.DataLoader(testset,num_workers=4)

# examine a sample
dataiter = iter(trainloader)
images,labels = next(dataiter)

# Define the network architecture
from torch import nn,optim
import torch.nn.functional as F

model = nn.Sequential(nn.Linear(784,128),nn.ReLU(),nn.Linear(128,10),nn.LogSoftmax(dim = 1))
model.to(device)

# Define the loss
criterion = nn.MSELoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(),lr = 0.001)

# Define the epochs
epochs = 5
train_losses,test_losses = [],[]
squared_sum = []
# start = time.time()
for e in range(epochs):
    running_loss = 0
    

    for images,labels in trainloader:
        # Flatten Fashion-MNIST images into a 784 long vector
        images = images.to(device)
        labels = labels.to(device)
        images = images.view(images.shape[0],-1)

        optimizer.zero_grad()

        output = model[0].forward(images)
        loss = criterion(output[0],labels.float())

        loss.backward()

        optimizer.step()
        running_loss += loss.item()
    
    else:
        print(running_loss)
        test_loss = 0
        accuracy = 0

        # Turn off gradients for validation, saves memory and computation
        with torch.no_grad():
            # Set the model to evaluation mode
            model.eval()

            # Validation pass
            for images,labels in testloader:
                images = images.to(device)
                labels = labels.to(device)
                images = images.view(images.shape[0],-1)
                ps = model(images[0])
                test_loss += criterion(ps,labels)
                top_p,top_class = ps.topk(1,dim = 1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor))
    
    model.train()
    print("Epoch: {}/{}..".format(e+1,epochs),"Training loss: {:.3f}..".format(running_loss/len(trainloader)),"Test loss: {:.3f}..".format(test_loss/len(testloader)),"Test Accuracy: {:.3f}".format(accuracy/len(testloader)))

What I want:

for e in range(epochs):
    running_loss = 0
    

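    for images,labels in trainloader:
        images = images.to(device)
        labels = labels.to(device)
        images = images.view(images.shape[0],-1)
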
        optimizer.zero_grad()
    
        output = model[0].forward(images)
        loss = criterion(output[0],labels.float())
        
        loss.backward()
                
        optimizer.step()
        running_loss += loss.item()

Here, model[0] should be the first layer, nn.Linear(784,128). What I would like to get is the mean squared error between the first and second layers.

If I run this code, I get the following error:

RuntimeError: The size of tensor a (128) must match the size of tensor b (96) at non-singleton dimension 0

What do I need to change so that this code runs correctly and computes the MSELoss?

Solution

The error is caused by the number of samples in the dataset not being a multiple of the batch size.

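In more detail, the Fashion-MNIST training set contains 60,000 samples and your current batch_size is 128, so one epoch needs 60000 / 128 = 468.75 batches, i.e. 468 full batches plus one partial batch. That is where the problem comes from: for 468 iterations each batch holds 128 samples, but the last iteration contains only 60000 - 468*128 = 96 samples.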
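The batch arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the numbers come from the text, and the second check shows why a batch_size of 96 leaves no partial batch.

```python
import math

n_samples = 60_000   # size of the Fashion-MNIST training set
batch_size = 128

# divmod gives (number of full batches, size of the leftover batch)
full_batches, last_batch = divmod(n_samples, batch_size)
# the DataLoader still yields the leftover batch, so round up
n_iterations = math.ceil(n_samples / batch_size)

print(full_batches, last_batch, n_iterations)  # 468 96 469

# with batch_size = 96 the dataset splits evenly, so every batch has 96 samples
even_split = divmod(n_samples, 96)
print(even_split)  # (625, 0)
```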

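The loss compares output[0], which is the 128 hidden activations of the first sample in the batch, with labels, which holds one label per sample, so the two tensors only have the same length while a batch contains exactly 128 samples. To fix this, I think you need to choose the batch_size together with the number of neurons in the model's hidden layer.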
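To see where the two sizes in the error message come from, here is a minimal reproduction. It assumes random tensors as a stand-in for real Fashion-MNIST images (to avoid the download) and reuses the question's architecture. output[0] is the first sample's 128 hidden activations, while labels has one entry per sample, so the shapes agree for a batch of 128 and clash for the final batch of 96.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same architecture as in the question
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10), nn.LogSoftmax(dim=1))
criterion = nn.MSELoss()

def first_layer_loss(batch_size):
    # Random stand-in for one flattened Fashion-MNIST batch (assumption)
    images = torch.randn(batch_size, 784)
    labels = torch.randint(0, 10, (batch_size,))
    output = model[0](images)   # shape (batch_size, 128)
    first_row = output[0]       # shape (128,): hidden units of the FIRST sample only
    return first_row.shape, labels.shape, criterion(first_row, labels.float())

# Full batch: 128 hidden units vs 128 labels, the shapes match only by coincidence
row_shape, label_shape, loss = first_layer_loss(128)
print(row_shape, label_shape)  # torch.Size([128]) torch.Size([128])

# Last, partial batch: 128 hidden units vs 96 labels.
# PyTorch warns about the size mismatch, then raises the RuntimeError from the question.
try:
    first_layer_loss(96)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)
print(mismatch_error)
```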

Something like this should work for computing the loss:

trainloader = torch.utils.data.DataLoader(trainset,batch_size = 96,shuffle = True,num_workers=0)
testloader = torch.utils.data.DataLoader(testset,num_workers=0)
model = nn.Sequential(nn.Linear(784,96),nn.ReLU(),nn.Linear(96,10),nn.LogSoftmax(dim = 1)
                     )
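For comparison, a more conventional way to use nn.MSELoss for classification (a different technique from the fix above, not the answerer's method) is to compare the full (N, 10) model output with one-hot targets built by F.one_hot, so the shapes match for any batch size N. The sketch below uses random data as a hypothetical stand-in for Fashion-MNIST and applies softmax instead of LogSoftmax before the MSE. It also shows DataLoader's drop_last=True option, which discards the final partial batch; with matching (N, 10) shapes that option is not strictly required.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for Fashion-MNIST: 1,000 flattened 28x28 images, 10 classes
images = torch.randn(1000, 784)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(images, labels)

# drop_last=True discards the final partial batch (1000 = 7 * 128 + 104)
loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

# No LogSoftmax layer: the MSE is computed on softmax probabilities
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_sizes = []
for batch_images, batch_labels in loader:
    optimizer.zero_grad()
    probs = torch.softmax(model(batch_images), dim=1)          # (N, 10)
    targets = F.one_hot(batch_labels, num_classes=10).float()   # (N, 10)
    loss = criterion(probs, targets)  # shapes match whatever N is
    loss.backward()
    optimizer.step()
    batch_sizes.append(batch_images.shape[0])

print(len(batch_sizes), set(batch_sizes))  # 7 {128}
```

With this setup the hidden layer size and the batch size are independent, so either can be changed without breaking the loss.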
