如何通过 Pytorch 中的跟踪值有效地成对池化张量?

问题描述

我有一个形状为 T 的 pytorch 张量 (batch_size,window_size,filters,3,3),我想通过跟踪来汇集张量。具体来说,我想通过比较成对帧的轨迹来获得大小为 T_pooled 的张量 (batch_size,window_size//2,3)。例如,如果window_size=4,那么我们将比较T[i,k,3]T[i,1,3]的轨迹,并选择轨迹较小的子张量作为T_pooled[i,3]。同样,比较T[i,2,3]得到T_pooled[i,3]

这可以通过循环 ik 来完成,但这非常缓慢且效率低下。有没有办法对这个池化操作进行矢量化以加快速度?

编辑: 这是我迄今为止尝试过的。它使用列表理解和 for 循环。在大小为 (128,120,22,3) 的张量上运行大约需要 2.5 秒。

def TPL_Pairwise(x):
    x_pooled=torch.zeros(x.shape[0],x.shape[1]//2,x.shape[2],x.shape[3],x.shape[4])
    #compute tensorized trace
    trace=torch.einsum('ijkll->ijkl',x).sum(-1)  
    for i in range(x.shape[0]):  
        for j in range(x.shape[2]):
            keep=[ x[i,j] if trace[i,j] <= trace[i,k+1,j] else x[i,j] for k in range(0,x.shape[1],2)]
            x_pooled[i,:,j]=torch.stack(keep)
    return x_pooled

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)