How do you stack a tensor of shape (n, k) with a tensor of shape (k) in libtorch?

Problem description

torch::stack accepts a c10::TensorList and works fine when all the tensors it is given have the same shape. However, when you try to pass it the output of a previous torch::stack call together with an unstacked tensor, it fails with a memory access violation.

More concretely, suppose we have these tensors, each of shape {4}:

torch::Tensor x1 = torch::randn({4});
torch::Tensor x2 = torch::randn({4});
torch::Tensor x3 = torch::randn({4});
torch::Tensor y = torch::randn({4});

The first round of stacking is straightforward:

torch::Tensor stacked_xs = torch::stack({x1,x2,x3});

However, attempting this:

torch::Tensor stacked_result = torch::stack({y,stacked_xs});

fails. I am looking for the same behavior as np.vstack in Python, where this is allowed and just works. What should I do?

Solution

You can add a dimension to y with torch::unsqueeze, then concatenate with torch::cat (not torch::stack; this differs from numpy, but the result is what you asked for):

torch::Tensor x1 = torch::randn({4});
torch::Tensor x2 = torch::randn({4});
torch::Tensor x3 = torch::randn({4});
torch::Tensor y = torch::randn({4});

torch::Tensor stacked_xs = torch::stack({x1,x2,x3});
torch::Tensor stacked_result = torch::cat({y.unsqueeze(0),stacked_xs});

Alternatively, if you prefer, you can flatten your first stack and then reshape:

torch::Tensor stacked_xs = torch::stack({x1,x2,x3});
torch::Tensor stacked_result = torch::cat({y,stacked_xs.view({-1})}).view({4,4});