问题描述
我有一个形状为 A
的张量 (batch_size,width,height)
。假设它具有以下值:
A = torch.tensor([[[0,1],[1,0]]])
我还得到了一个数字 K
,它是一个正整数。在这种情况下让 K=2
。我想做一个类似于反向池和复制填充的过程。这是预期的输出:
B = torch.tensor([[[0,1,[0,0],0]]])
说明:对于A
中的每个元素,我们将其展开为形状为(K,K)
的矩阵,并将其放入结果张量中。我们继续对其他元素执行此操作,并让它们之间的步幅等于内核大小(即 K
)。
如何在 PyTorch 中执行此操作?目前,A
是一个二进制掩码,但如果我可以将其扩展为非二进制大小写会更好。
解决方法
平方扩展
您可以通过扩展两次来获得所需的输出:
def dilate(t,k):
x = t.squeeze()
x = x.unsqueeze(-1).expand([*x.shape,k])
x = x.unsqueeze(-1).expand([*x.shape,k])
x = torch.cat([*x],dim=1)
x = torch.cat([*x],dim=1)
x = x.unsqueeze(0)
return x
B = dilate(A,k)
调整大小/插值最近
如果您不介意较大扩展中的角可能会“流血”(因为它在确定要插入的“最近”点时使用欧几里得而不是曼哈顿距离),一个更简单的方法是resize
:
import torchvision.transforms.functional as F
B = F.resize(A,A.shape[-1]*k)
为了完整性:
MaxUnpool2d
将 MaxPool2d
的输出(包括最大值的索引)作为输入,并计算其中所有非最大值都设置为零的部分逆。
你可以试试这些:
注意:以下函数以二维张量作为输入。如果您的张量 A
的形状为 (1,N,N),即具有(冗余)批次/通道维度,请将 A.squeeze()
传递给 func()
。
方法一:
此方法广播乘法,然后进行转置和重塑操作以实现最终结果。
import torch
import torch.nn as nn
A = torch.tensor([[0,1,1],[1,0]])
K = 3
def func(A,K):
ones = torch.ones(K,K)
tmp = ones.unsqueeze(0) * A.view(-1,1)
tmp = tmp.reshape(A.shape[0],A.shape[1],K,K)
res = tmp.transpose(1,2).reshape(K * A.shape[0],K * A.shape[1])
return res
方法 2:
根据@Shai 在评论中的提示,此方法在通道维度中重复 (2D) 张量 K**2
次,然后使用 PixelShuffle() 将行和列放大 K
次。
def pixelshuffle(A,K):
pixel_shuffle = nn.PixelShuffle(K)
return pixel_shuffle(A.unsqueeze(0).repeat(K**2,1).unsqueeze(0)).squeeze(0).squeeze(0)
由于 nn.PixelShuffle()
仅采用 4D 张量作为输入,因此需要在 repeat()
之后解压。另请注意,由于从 nn.PixelShuffle()
返回的张量也是 4D,因此遵循两个 squeeze()
以确保我们获得 2D 张量作为输出。
一些示例输出:
A = torch.tensor([[0,0]])
func(A,2)
# tensor([[0.,0.,1.,1.],# [0.,# [1.,0.],0.]])
pixelshuffle(A,2)
# tensor([[0,# [0,# [1,0],0]])
请随时提出进一步的说明,并告诉我它是否适合您。
基准测试:
我针对上面@iacob 的 func()
函数对我的答案 pixel shuffle()
和 dilate()
进行了基准测试,发现我的略快。
A = torch.randint(3,100,(20,20))
assert (dilate(A,5) == func(A,5)).all()
assert (dilate(A,5) == pixelshuffle(A,5)).all()
%timeit dilate(A,5)
# 142 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs,10000 loops each)
%timeit func(A,5)
# 57.9 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs,10000 loops each)
%timeit pixelshuffle(A,5)
# 81.6 µs ± 970 ns per loop (mean ± std. dev. of 7 runs,10000 loops each)