问题描述
最近我遇到了一些从 N x B x H x W 形状的 RGB 图像(或它们的集合)中提取(滑动窗口样式)许多方形块的代码。他们这样做如下:
>patch_width = 3
patches = image.permute(0,2,3,1).unfold(dim = 1,size = patch_width,stride = patch_width) \
.unfold(dim = 2,stride = patch_width)
我了解 unfold()
方法“从维度 size
中的自张量返回所有大小为 dim
的切片”,通过阅读文档,但尽我所能,我只是无法很好地理解为什么堆叠两个 .unfold()
调用会产生方形补丁。我明白在张量上使用 unfold()
一次会发生什么。我不明白当你沿着两个不同的维度连续调用它两次时会发生什么。
我已经多次看到这种方法被使用过,但总是没有很好地解释为什么它有效(1、2),这让我发疯。为什么空间维度 H
和 W
排列为dims 1 和2,而通道dim 设置为3?为什么在暗淡 1 上以相同的方式展开,然后在暗淡 2 上导致 patch_width
方格 patch_width
补丁?
任何见解都将不胜感激,即使它只是我错过的文章的链接。我已经在谷歌上搜索了一个多小时,但收效甚微。谢谢!
[2]Another forum post doing the same thing
解决方法
我想,您的问题有两个不同的部分,第一个是您需要 permute
的原因,其次是两个 unfold
组合如何产生方形图像切片。
第一时刻相当技术性 - unfold
将生成的切片放置在张量的新维度中,“插入到形状的末尾”。此处需要 permute
将其放置在通道或深度维度附近,以便稍后使用 view
以自然方式合并它们。
现在是第二部分。考虑一副虚构的卡片,每张卡片都是一个图片通道。拿一张卡片,把它切成垂直的薄片,然后把薄片叠在一起。拿第二张牌做同样的事情,把结果放在第一张牌上,用所有的牌做。现在重复这个过程,水平切薄片。最后,你有更薄但更高的牌组,以前的牌变成了补丁的子牌组。
,让我们看一个简单的二维示例,看看为什么组合操作会产生“补丁”。
代码:
x = torch.tensor([[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15]])
>>> x.unfold(1,1)
tensor([[[ 1,2],[ 2,3],[ 3,4],[ 4,5]],[[ 6,7],[ 7,8],[ 8,9],[ 9,10]],[[11,12],[12,13],[13,14],[14,15]]])
>>> x.unfold(1,1).unfold(0,1)
tensor([[[[ 1,6],7]],[[ 2,8]],[[ 3,9]],[[ 4,[ 5,10]]],[[[ 6,11],12]],[[ 7,13]],[[ 8,14]],[[ 9,[10,15]]]])