使用tensorflow 2.x

问题描述

给出类A[A() for _ in range(5)]的实例的列表,我想随机选择其中一个(请参见以下代码作为示例)

class A:
    def __init__(self,a):
        self.a = a
    def __call__(self):
        return self.a
def f():
    a_list = [A(i) for i in range(5)]
    a = a_list[random.randint(0,5)]()
    return a

f()

是否有一种方法可以用f装饰@tf.function而又不更改f的功能,也不需要调用a_list中的所有项目?

请注意,用f直接修饰@tf.function而不更改上面的代码是不可行的,因为它将始终返回相同的结果。另外,我知道可以通过首先调用a_list中的所有元素,然后使用tf.gather_nd对其进行索引来实现。但是,如果调用类型为A的对象涉及到深度神经网络,则会产生大量开销。

解决方法

此刻我正在做同样的事情。这是到目前为止我得到的。如果有人知道更好的方法,我也会有兴趣听的。当我在昂贵的电话上运行它时,它比计算并返回所有值的速度要适当地快。

@tf.function
def f2():
    a_list = [A(i) for i in range(5)]
    idx = tf.cast(tf.random.uniform(shape=[],maxval=4),tf.int32)
    return tf.switch_case(idx,a_list)

为了进行速度比较,我制作了A昂贵矩阵代数的调用方法。然后考虑一个替代函数,它调用每个函数:

@tf.function
def f3():
    a_list = [A(i) for i in range(40)]
    results = [a() for a in a_list]
    return results

以40个元素运行f2:0.42643秒

使用40个元素运行f3:14.9153秒

因此,对于仅选择一个分支的预期40倍加速,这似乎是正确的。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...