问题描述
我想检查模型是否在CUDA上。该怎么做?
import torch
import torchvision
model = torchvision.models.resnet18()
model.to('cuda')
表明model.is_cuda()
无效。
解决方法
此代码应执行以下操作:
import torch
import torchvision
model = torchvision.models.resnet18()
model.to('cuda')
next(model.parameters()).is_cuda
出局:
True
请注意,is_cuda()
中没有nn.Module
方法。
另外请注意,model.to('cuda')
与model.cuda()
相同,并且都在适当位置。
另一方面,移动data.to('cuda')
的位置不正确,您通常会调用:
data = data.to('cuda')
将数据移至CUDA。