问题描述
我有一个包含 9 个波段的多光谱数据集。由于数据非常大,我将每个波段拆分为 256 x 256 个样本。所以我为每个乐队准备了 16 个这样的样本,并将它们保存到不同的文件夹中。现在如何连接 9 个波段的每个样本?
例如,我想将第 1 个波段数据中的第一个样本与第 2 个波段的第一个样本、第 3 个波段的第一个样本连接起来,直到第 9 个波段。然后是第 1、第 2...第 9 个频段的第 2 个样本。依此类推,直到第 16 个样本。
解决方法
您可以使用 (torch.stack)[https://pytorch.org/docs/stable/generated/torch.stack.html] 或 (torch.cat)[https://pytorch.org/docs /stable/generated/torch.cat.html]
例如,让我们生成一些随机的 256x256 矩阵:
import torch
a = torch.FloatTensor(256,256).uniform_(0,1)
b = torch.FloatTensor(256,1)
第一种方法:您可以使用 torch.stack 连接张量。
c = torch.stack((a,b),axis=2)
第二种方式:或者torch.cat
c = torch.cat((a.reshape(256,256,-1),b.reshape(256,-1)),axis=2)
主要区别在于,torch.stack 沿新轴连接,而 torch.cat 只能连接现有轴(这就是需要 reshape 命令的原因)。