我在哪里可以找到 PyTorch 的 Tensor.unfold() 用于获取图像补丁的直观解释?

问题描述

最近我遇到了一些从 N x B x H x W 形状的 RGB 图像(或它们的集合)中提取(滑动窗口样式)许多方形块的代码。他们这样做如下:

>
patch_width = 3
patches = image.permute(0,2,3,1).unfold(dim = 1,size = patch_width,stride = patch_width) \
        .unfold(dim = 2,stride = patch_width)

我了解 unfold() 方法“从维度 size 中的自张量返回所有大小为 dim 的切片”,通过阅读文档,但尽我所能,我只是无法很好地理解为什么堆叠两个 .unfold() 调用会产生方形补丁。我明白在张量上使用 unfold() 一次会发生什么。我不明白当你沿着两个不同的维度连续调用它两次时会发生什么。

我已经多次看到这种方法被使用过,但总是没有很好地解释为什么它有效(1、2),这让我发疯。为什么空间维度 HW 排列为dims 1 和2,而通道dim 设置为3?为什么在暗淡 1 上以相同的方式展开,然后在暗淡 2 上导致 patch_width 方格 patch_width 补丁?

任何见解都将不胜感激,即使它只是我错过的文章链接。我已经在谷歌上搜索一个多小时,但收效甚微。谢谢!

[1]PyTorch forum post

[2]Another forum post doing the same thing

解决方法

我想,您的问题有两个不同的部分,第一个是您需要 permute 的原因,其次是两个 unfold 组合如何产生方形图像切片。

第一时刻相当技术性 - unfold 将生成的切片放置在张量的新维度中,“插入到形状的末尾”。此处需要 permute 将其放置在通道或深度维度附近,以便稍后使用 view 以自然方式合并它们。

现在是第二部分。考虑一副虚构的卡片,每张卡片都是一个图片通道。拿一张卡片,把它切成垂直的薄片,然后把薄片叠在一起。拿第二张牌做同样的事情,把结果放在第一张牌上,用所有的牌做。现在重复这个过程,水平切薄片。最后,你有更薄但更高的牌组,以前的牌变成了补丁的子牌组。

,

让我们看一个简单的二维示例,看看为什么组合操作会产生“补丁”。

enter image description here


代码:

x = torch.tensor([[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15]])
>>> x.unfold(1,1)
tensor([[[ 1,2],[ 2,3],[ 3,4],[ 4,5]],[[ 6,7],[ 7,8],[ 8,9],[ 9,10]],[[11,12],[12,13],[13,14],[14,15]]])
>>> x.unfold(1,1).unfold(0,1)
tensor([[[[ 1,6],7]],[[ 2,8]],[[ 3,9]],[[ 4,[ 5,10]]],[[[ 6,11],12]],[[ 7,13]],[[ 8,14]],[[ 9,[10,15]]]])