pytorch:如何堆叠2张量

问题描述

我正在尝试将两个张量A.shape=(64,16,16)B.shape=(64,16)堆叠成形状为C.shape=(1,128,16)的张量

以及非我尝试过的功能所在的地方
torch.stack => C.shape=(2,64,16)torch.cat => C.shape=(128,16)

依诺能帮助我

解决方法

首先合并,然后使用unsqueeze0th位置添加单例尺寸

torch.cat([A,B]).unsqueeze(0)