问题描述
我正在使用 Detectron2 来训练用于对象检测的 Faster R-CNN 模型,我想训练模型动物园给出的模型,输入范围为 [0 1] 而不是 [0 255],所以我使用了颜色变换调用我的函数 scale_transform
def scale_transform(img):
return img/255.
此函数正在接收一个 numpy 数组并返回它的缩放比例。但是,在火车时间出现此错误
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
有人知道我该如何解决这个问题吗?或另一种缩放detectron2图像的方法?
谢谢
解决方法
我认为这里的相关词是类型。
也许确保输入被定义为浮点数。尽管它在正确的范围 (0-1) 内,但它可能会发现数据类型不正确,因此在那里绊倒了。
以下可能对它 -
def scale_transform(img):
img = img/255
img = img.astype(np.float32)
return img