具有不同系数的“基于图”多项式评估的有效方法

问题描述

我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。

这将在带有 JIT 的 TensorFlow 或 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?

可以使用一百个 wheres,但这并不是最佳选择。或者将 tf.switch_casetf.vectorize_map(或类似的)一起使用。

有什么想法吗?

解决方法

如果我理解正确,我认为 jax.lax.switch 提供了您感兴趣的功能。例如:

import jax.numpy as jnp
from jax import vmap,lax
import matplotlib.pyplot as plt

def f1(x):
  return 0.0 * x

def f2(x):
  return (x - 1.0) ** 2

def f3(x):
  return 2 * x - 3

branches = (f1,f2,f3)
bounds = jnp.array([1,2])  # boundaries between branches

x = jnp.linspace(0,3)
index = jnp.searchsorted(bounds,x)  # index in branches for each value in x

result = vmap(lambda i,x: lax.switch(i,branches,x))(index,x)
plt.plot(x,result)

enter image description here

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...