PyTorch: why is .float() needed here? RuntimeError: expected scalar type Float but found Double

Problem description

A simple question: I am trying to use the simplest possible network, but I keep getting RuntimeError: expected scalar type Float but found Double unless I cast data with .float() (see the commented line in the code below).

What I don't understand is why this cast is needed. data is already of type torch.float64. Why is the explicit re-cast on the line output = model(data.float()) necessary?

Code

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
from torch.optim.lr_scheduler import StepLR
from sklearn.datasets import make_classification
from torch.utils.data import TensorDataset,DataLoader

# =============================================================================
# Simplest Example
# =============================================================================
X,y = make_classification()
X,y = torch.tensor(X),torch.tensor(y)
print("X Shape :{}".format(X.shape))
print("y Shape :{}".format(y.shape))

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(X.shape[1],128)
        self.fc2 = nn.Linear(128,10)
        self.fc3 = nn.Linear(10,2)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
device = torch.device("cuda") 
lr = 1
batch_size = 32
gamma = 0.7
epochs = 14
args = {'log_interval': 10,'dry_run':False}

kwargs = {'batch_size': batch_size}
kwargs.update({'num_workers': 1,'pin_memory': True,'shuffle': True},)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(),lr=lr)
scheduler = StepLR(optimizer,step_size=1,gamma=gamma)

my_dataset = TensorDataset(X,y) # create dataset
train_loader = DataLoader(my_dataset,**kwargs) #generate DataLoader

cross_entropy_loss = torch.nn.CrossEntropyLoss()

for epoch in range(1,epochs + 1):
    ## Train step ##
    model.train()
    
    for batch_idx,(data,target) in enumerate(train_loader):
        data,target = data.to(device),target.to(device)
        optimizer.zero_grad()
        output = model(data.float()) #HERE: why is .float() needed here?
        loss = cross_entropy_loss(output,target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,batch_idx * len(data),len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.item()))
            if args['dry_run']:
                break
    
    scheduler.step()

Solution

In PyTorch, 64-bit floating point corresponds to torch.float64 (torch.double), while 32-bit floating point corresponds to torch.float32 (torch.float).

So, regarding

data is already of type torch.float64

data is a 64-bit floating-point tensor (torch.double): make_classification returns NumPy float64 arrays, and torch.tensor keeps that dtype. The model's nn.Linear layers, on the other hand, are created with float32 parameters by default, which is why the dtypes clash.

Converting the input with .float() casts it to a 32-bit float, so it matches the parameters:

a = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.double)
print(a.dtype)
# torch.float64
print(a.float().dtype)
# torch.float32
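
If you would rather not cast on every batch, one common alternative (a sketch based on the question's setup, not the only possible fix) is to convert the features to float32 once, when the tensors are built:

import torch
from sklearn.datasets import make_classification
from torch.utils.data import TensorDataset, DataLoader

X, y = make_classification()
# make_classification returns NumPy float64 arrays; convert the features once
# so they match nn.Linear's default float32 parameters.
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)  # CrossEntropyLoss expects int64 targets

my_dataset = TensorDataset(X, y)
train_loader = DataLoader(my_dataset, batch_size=32, shuffle=True)
# The training loop can now call model(data) without .float()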

See the other data types in the PyTorch documentation.

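The opposite approach also works: keep the data in double precision and cast the model's parameters to float64 with .double(). A minimal sketch (note that float64 is usually slower, especially on GPU):

import torch
import torch.nn as nn

net = nn.Linear(20, 2).double()             # parameters become torch.float64
x = torch.randn(4, 20, dtype=torch.double)  # double-precision input
print(net(x).dtype)                         # torch.float64 -- dtypes now match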