问题描述
我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。
这将在带有 JIT 的 TensorFlow 或 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?
可以使用一百个 wheres,但这并不是最佳选择。或者将 tf.switch_case
与 tf.vectorize_map
(或类似的)一起使用。
有什么想法吗?
解决方法
如果我理解正确,我认为 jax.lax.switch
提供了您感兴趣的功能。例如:
import jax.numpy as jnp
from jax import vmap,lax
import matplotlib.pyplot as plt
def f1(x):
return 0.0 * x
def f2(x):
return (x - 1.0) ** 2
def f3(x):
return 2 * x - 3
branches = (f1,f2,f3)
bounds = jnp.array([1,2]) # boundaries between branches
x = jnp.linspace(0,3)
index = jnp.searchsorted(bounds,x) # index in branches for each value in x
result = vmap(lambda i,x: lax.switch(i,branches,x))(index,x)
plt.plot(x,result)