问题描述
我有一个函数 apiVersion: apps/v1
kind: StatefulSet
Metadata:
name: MysqL-statefulset
spec:
serviceName: MysqL-service
replicas: 1
selector:
matchLabels:
app: MysqL-pod
template:
Metadata:
labels:
app: MysqL-pod
spec:
containers:
- name: MysqL
image: MysqL
ports:
- containerPort: 3306
volumeMounts:
- name: pvc-test
mountPath: /var/lib/MysqL
volumeClaimTemplates:
- Metadata:
name: pvc-test
spec:
storageClassName: gp2-retain
accessModes: [ "ReadWriteOnce" ]
resources:
requests:
storage: 1Gi
,其中 compute(x)
是一个 x
。现在,我想使用 jnp.ndarray
将其转换为一个函数,该函数接受一批数组 vmap
,然后使用 x[i]
来加速它。 jit
类似于:
compute(x)
然而,每个数组 def compute(x):
# ... some code
y = very_expensive_function(x)
return y
都有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,使它们都具有相同的长度 x[i]
并且 N
可以应用于形状为 vmap(compute)
的批次。
然而,这样做会导致 (batch_size,N)
在每个数组 very_expensive_function()
的尾随零上也被调用。有没有办法修改 x[i]
,使得 compute()
只在 very_expensive_function()
的切片上调用,而不干扰 x
和 vmap
?
解决方法
使用 JAX,当您想要 jit 函数以加快速度时,给定的批处理参数 x
必须是定义良好的 ndarray(即 x[i] 必须具有相同的形状)。无论您是否使用 vmap
,都是如此。
现在,通常的处理方法是填充这些数组。这意味着您在参数中添加了一个掩码,这样填充的值就不会影响您的结果。例如,如果我想计算形状 softmax
的填充值 x
的 (bath_size,max_length)
,我需要“禁用”填充值的效果。下面是一个例子:
import jax.numpy as jnp
import jax
PAD = 0
MINUS_INFINITY = -1e6
x = jnp.array([
[1,2,3,4],[1,PAD,PAD],PAD]
])
mask = jnp.array([
[1,1,1],0],0]
])
masked_sofmax = jax.nn.softmax(x + (1-mask)*MINUS_INFINITY)
它不像填充 x
那样微不足道。您需要在每一步实际更改计算以禁用填充效果。在 softmax 的情况下,您可以通过将填充值设置为接近负无穷大来实现。
最后,您无法真正提前知道在使用或不使用填充 + 掩码的情况下速度性能是否会更好。根据我的经验,它通常会导致 CPU 的良好改进,以及 GPU 的非常大的改进。特别是,批次大小的选择对性能有很大影响,因为较高的 batch_size
会在统计上导致较高的 max_length
,因此会导致更多的“无用”计算在填充的值。