如何在 XLA 上下文中使用 Tensorarray?

问题描述

TF 版本:2.4.0

我有一些在 XLA 中编译成本很高的代码(约 10 分钟)。我正在寻找一种方法来加速这个编译。我注意到我的代码库在其中一个 for 循环中生成了大量计算图节点。因此我打算用 tf.while_loop 替换这个 for 循环。为此,我需要使用 TensorArray。我创建了一个可重现的小玩具示例来展示我现在正在解决的问题。

我的问题是:如何在 XLA 上下文中正确构建 TensorArray? 我知道由于 XLA 的限制,我无法在 XLA 上下文之外构建这个 TensorArray。 如果有另一种方法可以避免 XLA 中的 for 循环,我很想知道。 对我来说,重要的是我使用的张量的某些轴会有所不同。

import tensorflow as tf

array = [
    tf.random.uniform((19,1,3,2)),tf.random.uniform((19,2,tf.random.uniform((3,tf.random.uniform((18,]

size = len(array)

@tf.function(experimental_compile=True) # The same happens with autograph=False
def add(array):
    
    tensor_array = tf.TensorArray(
        dtype=tf.float32,size=size,infer_shape=False,element_shape=tf.TensorShape([None,None,2]),)
    for i in range(size):
        tensor_array = tensor_array.write(i,array[i])
    # There would be a while_loop here

r = add(array)

当我运行这个简单的代码时,我得到一个 Internal Error(如下)。但是,当张量具有相同的形状时,代码运行流畅。因此我有两个假设:

1)。 infer_shape=False 被忽略

2)。 XLA 不支持尺寸中的 None

感谢任何帮助


InternalError: Invalid TensorList shape: element_type: TUPLE
tuple_shapes {
  element_type: F32
  dimensions: 4
  dimensions: 19
  dimensions: 1
  dimensions: 3
  dimensions: 2
  layout {
    minor_to_major: 4
    minor_to_major: 3
    minor_to_major: 2
    minor_to_major: 1
    minor_to_major: 0
    format: DENSE
  }
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
}
tuple_shapes {
  element_type: S32
  layout {
    format: DENSE
  }
},expected: element_type: TUPLE
tuple_shapes {
  element_type: F32
  dimensions: 4
  dimensions: 19
  dimensions: 1
  dimensions: 2
  dimensions: 2
  layout {
    minor_to_major: 4
    minor_to_major: 3
    minor_to_major: 2
    minor_to_major: 1
    minor_to_major: 0
    format: DENSE
  }
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
  is_dynamic_dimension: false
}
tuple_shapes {
  element_type: S32
  layout {
    format: DENSE
  }
}


     [[{{node TensorArrayV2Write_1/TensorListSetItem}}]] [Op:__inference_add_65]

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...