如何使用TensorFlow 1.14从3D张量获取最大2D张量?

问题描述

我正在寻找最佳和优化的方式(无循环),使用TensorFlow 1.14根据最大的一个值从3D张量中获取2D最大张量。假设我们有这个Tensor和这个函数(为了理解,它不起作用):

def get_Max(inputs):
    max_indices = [0,0]
    for i in range(16):
        for j in range(2048):
            for k in range(10):
                if(inputs[max_indices[0],max_indices[1],max_indices[2]]<inputs[i,j,k]):
                   max_indices = [i,k]
    return inputs[:][j]
inputs = tf.random.uniform(shape=[16,2048,10],dtype=tf.dtypes.float32)
output = get_Max(inputs)

因此,输出张量必须具有[16,10]的形状,这是从2048开始的16个最大值。 那么,如何实现一个无需循环即可执行此功能功能

我使用了tf.math.reduce_max,但这并不是我要找的东西,如下图所示:

enter image description here

解决方法

inp = tf.random.uniform(shape=[4,6,2],maxval=20,dtype=tf.int32)
print(inp)

array([[[14,8],[18,10],[ 6,14],[ 8,9],[11,11],[14,13]],[[ 7,18],[ 4,[15,6],[19,[10,4]],[[ 8,1],[ 1,3],17],7],[ 0,0],[[ 5,12],16],[ 3,[ 2,18]]],dtype=int32)>

因此,如果我理解正确,对于每个inp[i,:,:],例如:

    [[14,13]]

您要保留包含最大数量的项目,在本例中为第二行:[18,10]。我要做的是首先沿着最后一个轴找到最大数量:

am = tf.math.reduce_max(inp,axis=2)
am[0,:]
[14,18,14,9,11,14]

,然后找到包含最大数量的行的索引:

am = tf.math.argmax(am,axis=1)

这些将是您想要的j,然后您可以使用tf.gather_nd并枚举以获得这些值:

# [*enumerate(am)] = [(0,am[0]),(1,am[1]),...]
tf.gather_nd(inp,[*enumerate(am)])

<tf.Tensor: shape=(4,2),dtype=int32,numpy=
array([[18,18]],dtype=int32)>