问题描述
一个简单的问题:我想尝试使用最简单的网络,但除非把 `data` 用 `.float()` 强制转换,否则就会一直遇到 `RuntimeError: expected scalar type Float but found Double`(请参见下面带注释的代码)。
我不明白的是,为什么需要这种转换?`data` 已经是 `torch.float64` 类型了,为什么还需要在 `output = model(data.float())` 这一行中进行显式转换?
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
from torch.optim.lr_scheduler import StepLR
from sklearn.datasets import make_classification
from torch.utils.data import TensorDataset,DataLoader
# =============================================================================
# Simplest Example
# =============================================================================
# make_classification returns NumPy arrays: float64 features and integer class
# labels. torch.tensor() keeps the NumPy dtype, so without an explicit dtype the
# features would be float64 while the nn.Linear layers are created as float32
# (PyTorch's default dtype), which triggers
# "RuntimeError: expected scalar type Float but found Double" in the forward pass.
# Converting once here makes a per-batch .float() cast unnecessary.
# CrossEntropyLoss expects class-index targets as int64 (torch.long); NumPy's
# default integer width is platform-dependent, so the label dtype is pinned too.
X, y = make_classification()
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)
print("X Shape :{}".format(X.shape))
print("y Shape :{}".format(y.shape))
class Net(nn.Module):
    """Small fully connected classifier: in_features -> 128 -> 10 -> num_classes.

    Args:
        in_features: Width of each input sample. When omitted, it falls back to
            the number of feature columns of the module-level ``X`` (the
            original behaviour).
        num_classes: Number of output logits. Defaults to 2.

    The layers are created with PyTorch's default dtype (float32 unless it was
    changed with ``torch.set_default_dtype``). Inputs passed to ``forward`` must
    use that same dtype; float64 input raises a dtype-mismatch RuntimeError
    (e.g. "expected scalar type Float but found Double").
    """

    def __init__(self, in_features=None, num_classes=2):
        super(Net, self).__init__()
        if in_features is None:
            # Backward compatibility: size the first layer from the global dataset.
            in_features = X.shape[1]
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 10)
        self.fc3 = nn.Linear(10, num_classes)

    def forward(self, x):
        """Return unnormalised class logits.

        For a ``(batch, in_features)`` input the output is
        ``(batch, num_classes)``; no softmax is applied because
        ``CrossEntropyLoss`` expects raw logits.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Fall back to the CPU when no CUDA device is present, instead of failing at
# the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Adam's documented default learning rate. The previous value of 1 comes from
# the PyTorch MNIST example, which uses Adadelta; it is far too large for Adam
# and typically makes training unstable.
lr = 1e-3
batch_size = 32
gamma = 0.7
epochs = 14
args = {'log_interval': 10, 'dry_run': False}
kwargs = {'batch_size': batch_size}
# Pinned host memory only speeds up host-to-GPU copies, so enable it only on CUDA.
kwargs.update({'num_workers': 1, 'pin_memory': device.type == 'cuda', 'shuffle': True})
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by `gamma` after every epoch (scheduler.step() is
# called once per epoch in the training loop).
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
my_dataset = TensorDataset(X, y)  # pairs each feature row with its label
train_loader = DataLoader(my_dataset, **kwargs)  # shuffled mini-batches
cross_entropy_loss = torch.nn.CrossEntropyLoss()
for epoch in range(1, epochs + 1):
    # ---- one training pass over the whole dataset ----
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # The Linear layers hold float32 weights (PyTorch's default dtype). A
        # float64 batch would raise "expected scalar type Float but found
        # Double", so the batch is cast to float32; for a float32 batch this
        # cast is a no-op.
        logits = model(inputs.float())
        loss = cross_entropy_loss(logits, labels)
        loss.backward()
        optimizer.step()
        if step % args['log_interval'] == 0:
            samples_seen = step * len(inputs)
            percent_done = 100. * step / len(train_loader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, len(train_loader.dataset), percent_done, loss.item()))
            if args['dry_run']:
                break
    # Decay the learning rate once per epoch.
    scheduler.step()
解决方法
在 PyTorch 中,64 位浮点数对应 `torch.float64`(别名 `torch.double`),而 32 位浮点数对应 `torch.float32`(别名 `torch.float`)。
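因此,`data` 已经是 `torch.float64` 类型,即 64 位浮点类型(`torch.double`)。问题在于,`nn.Linear` 等层的权重默认以 32 位浮点(`torch.float32`)创建。输入与权重的 dtype 不一致时,矩阵乘法就会报错 `expected scalar type Float but found Double`。
通过 `.float()` 可以把 `data` 转换为 32 位浮点数,使其与模型参数的类型一致。也可以在创建张量时直接指定类型,例如 `torch.tensor(X, dtype=torch.float32)`;这样训练循环中就不再需要 `.float()`。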
# A double-precision tensor: torch.float64 and torch.double are the same dtype.
a = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float64)
print(a.dtype)  # torch.float64
# .float() returns a single-precision (torch.float32) copy of the tensor.
print(a.float().dtype)  # torch.float32
更多数据类型请参阅 PyTorch 文档中关于 `torch.dtype` 的说明(Tensor Attributes 页面)。