如何优化用于 TensorRT 推理的 grid_sample 的自定义双线性采样替代方案?

问题描述

我正在尝试使用 torch.nn.functional.grid_sample 将模型从 Pytorch (1.6) 转换为 TensorRT (7) 通过 ONNX (opset 11)。 Opset 11 不支持 grid_sample 转换。 我发现的自定义替代方案 (https://github.com/pytorch/pytorch/issues/27212) 在 Pytorch 中运行时非常慢,并且在将主循环转换为 TRT 时存在问题。

我自己实现的双线性采样(不仅仅是 grid_sample,而是整个原始采样,基于 grid_sample)在 Pytorch 中执行得更快,并成功转换为 TRT。 但是我在 TRT 中的自定义双线性采样比 Pytorch 中的要慢(5.6 ms vs 2.0 ms)。事实证明,Pytorch image[:,ind,y0,x0] 索引产生了运行时间约为 0.97 ms 的 Gather 层。而这种双线性采样的TRT版本中有4个这样的层。

所以问题是:

  • 我应该如何优化 Pytorch 代码以获得有效的 TRT 模型?
  • 我应该怎么做才能使 Gather 层执行得更快?
  • 将此函数创建为自定义 TRT 插件是否有助于加快速度?

这里是双线性采样函数代码

def bilinear_sample_noloop(image,grid):
    """
    :param image: sampling source of shape [N,C,H,W]
    :param grid: integer sampling pixel coordinates of shape [N,grid_H,grid_W,2]
    :return: sampling result of shape [N,grid_W]
    """
    Nt,W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid,ygrid = grid.split([1,1],dim=-1)
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3,1,2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3,2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3,2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3,2)
    x0 = (x0 * mask).view(Nt,grid_W).long()
    y0 = (y0 * mask).view(Nt,grid_W).long()
    x1 = (x1 * mask).view(Nt,grid_W).long()
    y1 = (y1 * mask).view(Nt,grid_W).long()
    ind = torch.arange(Nt,device=image.device) #torch.linspace(0,Nt - 1,Nt,device=image.device)
    ind = ind.view(Nt,1).expand(-1,grid_H).view(Nt,-1,grid_W).long()
    image = image.permute(1,2,3)
    output_tensor = (image[:,x0] * wa + image[:,y1,x0] * wb + image[:,x1] * wc + \
                 image[:,x1] * wd).permute(1,3)
    output_tensor *= mask.permute(0,3,2).expand(-1,-1)
    image = image.permute(1,3)
    return output_tensor,mask

时间分析参数:

  • 时间分析实验在笔记本电脑 Dell G3 15(Core i7 8750H 2.2 GHz x12、16 Gb RAM (2666MHz)、NVidia GeForce GTX 1050 Ti)上进行。
  • 用于分析的 Pytorch 环境:Python 3.7 Anaconda 3 环境、Pytorch 1.6。 Pytorch 时间分析在每个时间戳之前通过 time.time() 和 torch.synchronize() 执行。
  • 用于分析的 TRT 环境:Docker 容器 http://nvcr.io/nvidia/tensorrt:20.06-py3。使用 trtexec 以及自定义 C++ 和 Python 代码执行分析。所有三个结果都很接近。

使用 trtexec 进行 TRT 模型分析的一部分:

     Layer   Time (ms)   Avg. Time (ms)   Time %
...
   Mul_146        5.82             0.03      0.5
   Add_147        8.50             0.04      0.7
Gather_148      214.39             0.97     17.3
Gather_174      214.25             0.97     17.3
Gather_201      213.88             0.97     17.3
Gather_228      214.48             0.97     17.3
 Add_237))       25.01             0.11      2.0
   Mul_251        7.84             0.04      0.6
     Total     1238.40             5.60    100.0

此外,我尝试将图像视为除 C 之外的所有维度上的线性数组,并创建线性索引以寻址形式 image[:,p0] 中的元素。对于这种情况,Gather 变得更慢(大约 1.07 毫秒)。然后我考虑了 C=1(它总是在原始模型中)并将张量元素处理为 image[p0]。这次 Gather 大约需要 0.92 毫秒(仍然太慢)。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)