关于每次调用jax.jit的函数重新编译

问题描述

我是 jax 的新手。当我阅读文档时,我对 jit 的缓存行为感到困惑。

caching section 中,它说“避免在循环内调用 jax.jit。这样做有效地在每次调用时创建一个新的 f,每次都会编译它而不是重用相同的缓存函数”。但是,运行以下代码只会产生一种打印副作用:

import jax
def unjitted_loop_body(prev_i):
  print("tracing...")
  return prev_i + 1

def g_inner_jitted_poorly(x,n):
  i = 0
  while i < n:
    # Don't do this!
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

g_inner_jitted_poorly(10,20)
# output:
WARNING:absl:No GPU/TPU found,falling back to cpu. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
tracing...
Out[1]: DeviceArray(30,dtype=int32)

字符串“tracing...”只打印一次,看来jit不会再次跟踪函数

这是故意的吗?感谢您的帮助!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...