如何在pytorch图像处理模型中处理具有多个图像的样本?

问题描述

我的模型训练涉及对同一图像的多个变体进行编码,然后对图像的所有变体所产生的表示求和。

数据加载器产生张量批次[batch_size,num_variants,1,height,width]1对应于图像颜色通道。

如何在pytorch中使用迷你批次训练模型? 我正在寻找通过网络转发所有batch_size×num_variant图像并汇总所有变体组的结果的正确方法

我当前的解决方包括展平前两个维度并进行for循环以汇总表示形式,但是我觉得应该有更好的方法,并且我不确定渐变是否会记住所有内容

解决方法

不确定我是否正确理解了您,但是我想这就是您想要的(比如说批处理图像张量称为image):

Nb,Nv,inC,inH,inW = image.shape

# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv,inW)

output = model(image)
_,outC,outH,outW = output.shape[1]

# reshapes the output such that dim==1 indicates variants
output = output.reshape(Nb,outW)

# summing over the variants and lose the dimension of summation,[Nb,outW]
output = output.sum(dim=1,keepdim=False)

在输入和输出通道/大小不同的情况下,我使用了inCoutCinH等。