问题描述
torch::stack
接受c10::TensorList
,并且在给出相同形状的张量时可以很好地工作。但是,当您尝试发送先前torch::stack
版本的Tensor的输出时,它会失败并出现内存访问冲突。
更具体地讲,假设我们有3个形状为4的张量:
torch::Tensor x1 = torch::randn({4});
torch::Tensor x2 = torch::randn({4});
torch::Tensor x3 = torch::randn({4});
torch::Tensor y = torch::randn({4});
第一轮堆叠很简单:
torch::Tensor stacked_xs = torch::stack({x1,x2,x3});
但是,尝试这样做:
torch::Tensor stacked_result = torch::stack({y,stacked_xs});
将失败。
我正在寻找与Python中np.vstack
中相同的行为,在此情况下这是允许的并且可以工作。
我应该怎么做?
解决方法
您可以使用y
向torch::unsqueeze
添加尺寸。然后与cat
(不是stack
并置,它与numpy不同,但结果将是您所要求的):
torch::Tensor x1 = torch::randn({4});
torch::Tensor x2 = torch::randn({4});
torch::Tensor x3 = torch::randn({4});
torch::Tensor y = torch::randn({4});
torch::Tensor stacked_xs = torch::stack({x1,x2,x3});
torch::Tensor stacked_result = torch::cat({y.unsqueeze(0),stacked_xs});
还可以根据您的喜好将您的第一个堆栈弄平,然后重塑形状:
torch::Tensor stacked_xs = torch::stack({x1,x3});
torch::Tensor stacked_result = torch::cat({y,stacked_xs.view({-1}}).view({4,4});