将一段代码与 jax 跟踪隔离

问题描述

提前为这个问题的含糊程度致歉(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表达它),但是:有没有办法将函数代码块与 jax 跟踪完全隔离?

对于上下文,我有以下形式的函数

def f(x,y):
   z = h(y)
   return g(x,z)

本质上,我想调用 g(x,z),并在执行任何 jax 转换时将 z 视为常量。但是,设置参数z非常笨拙,因此使用辅助函数h将更易于指定的输入y转换为g所需的格式.我希望 jax 将 h 视为不可追踪的黑匣子,因此对特定 jit(lambda x: f(x,y0)) 执行 y0 与第一次计算 z0 = h(y0) 相同使用 numpy,然后执行 jit(lambda x: g(x,z0))(与 grad 或任何其他函数转换类似)。

在我的代码中,我已经编写了 h 只使用标准的 numpy(我认为这可能会导致黑盒行为),但是 jit(lambda x: f(x,y0)) 的编译时间是明显长于 jit(lambda x: g(x,z0))z0 = h(y0) 编译时间。我有一种感觉,编译时间可能与 jax 跟踪 h 中的许多循环有关,但我不确定。

一些附加说明:

  • 以对 jax 友好的方式编写 h 会很尴尬(输入格式参差不齐、大量循环/条件、依赖于输入值的输出形状等),并且最终会比函数的价值更麻烦执行起来非常便宜,而且我不需要区分它(输入数据是基于整数的)。

想法?

为清楚起见编辑添加:我知道可能有办法解决这个问题,例如f 是顶级函数在这种情况下,让用户首先调用 h 来“预编译”对 g 的 jax 友好输入,然后自由地执行他们想要的任何 jax 转换并不是什么大问题lambda x: g(x,z0)。但是,我想象的情况是,我们有许多要链接在一起的函数,它们具有与 f 相同的结构,其中存在一些对 jax 不友好的输入/计算,但这些输入将始终被处理作为计算的 jax 部分的常量。原则上,我们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个将相互调用的此类函数的重要集合,这似乎很困难。

是否有某种方法可以控制跟踪 f 的方式,以便在跟踪时它知道只评估 z=h(y)(而不是跟踪 h)然后继续跟踪 {{1} }?

解决方法

f_jitted = jax.jit(f,static_argnums=1)

static_argnums 参数可能有帮助

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

您可以使用诸如 static_argnums 之类的转换参数来代替 jit 来避免跟踪转换函数的特定参数,但代价是需要更多的重新编译。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...