Problem description
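I wrote a simple script to try out gradient accumulation with JAX. The idea is to split a large batch (e.g. 64) into small chunks (e.g. 4) that fit in GPU memory. For each chunk, the resulting gradient, stored in a pytree, is added to the current batch gradient. The update is performed only once all chunks of the large batch have been computed. In this particular example, we simply try to fit random 512-dimensional vectors to random boolean values with a linear layer. Here is the script: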
import jax
import jax.numpy as jnp
from jax import jit, random
# In newer JAX releases this module lives at jax.example_libraries.optimizers.
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal, zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable

def Dense(input_size: int, output_size: int, init_kernel=normal(), init_bias=zeros):
    def init_fun(key):
        key, sub_key1, sub_key2 = jax.random.split(key, 3)
        params = {
            'I': init_kernel(sub_key1, (input_size, output_size)),
            'I_b': init_bias(sub_key2, (1, output_size)),
        }
        return params

    def apply_fun(params, inputs):
        I, I_b = params['I'], params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun, apply_fun)

def divide_pytree(pytree, div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree

def add_pytrees(pytree1, pytree2):
    for pt1, pt2 in zip(jax.tree_util.tree_leaves(pytree1), jax.tree_util.tree_leaves(pytree2)):
        pt1 = pt1 + pt2
    return pytree1

rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50
model = Dense(model_dim,1)
rng_key,sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init,opt_update,get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i, current_opt_state, current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N // K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value, grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads, num_gradients)
    for k in range(K, N, K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value, new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state), accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads, divide_pytree(new_grads, num_gradients))
    return opt_update(i, grads, current_opt_state), value

def loss_func(current_params, current_batch):
    inputs, labels = current_batch
    predictions = model.apply_fun(current_params, inputs)
    loss = jnp.square(labels - predictions).sum()
    return loss

for i in range(n_iter):
    rng_key, sub_key1, sub_key2 = random.split(rng_key, 3)
    inputs = jax.random.uniform(sub_key1, (batch_size, model_dim))
    labels = jax.random.uniform(sub_key2, (batch_size, 1)) > 0.5
    batch = inputs, labels
    opt_state, batch_loss = update(i, opt_state, batch)
    print(i, batch_loss)
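I have doubts about divide_pytree and add_pytrees. Do they actually modify the current batch gradient, or am I missing something? Also, do you see any speed issues with this code? In particular, should I use jax.lax.fori_loop instead of a traditional Python for loop?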
Related links:
- https://github.com/google/jax/issues/1488
- https://github.com/google-research/long-range-arena/issues/4
Solution
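Regarding the pytree computations: as written, your functions return the input unmodified. Inside the loops, an assignment such as pt = pt / div only rebinds the local name pt to a new array; the result is never written back into the pytree, and JAX arrays are immutable in any case. A better approach is to use jax.tree_util.tree_map; for example: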
from jax.tree_util import tree_map

def divide_pytree(pytree, div):
    return tree_map(lambda pt: pt / div, pytree)

def add_pytrees(pytree1, pytree2):
    return tree_map(lambda pt1, pt2: pt1 + pt2, pytree1, pytree2)
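To make the difference concrete, here is a minimal sketch that runs the loop-based version next to the tree_map version on a small, made-up gradient dict. The name divide_pytree_rebinding and the example values are assumptions for illustration only.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

def divide_pytree_rebinding(pytree, div):
    # Mirrors the loop-based version: `pt` is only a local name, so
    # rebinding it leaves `pytree` untouched.
    for pt in tree_leaves(pytree):
        pt = pt / div
    return pytree

def divide_pytree(pytree, div):
    # tree_map returns a new pytree with the same structure.
    return tree_map(lambda pt: pt / div, pytree)

# Hypothetical gradients, shaped like a small parameter dict.
grads = {"I": jnp.full((3, 1), 8.0), "I_b": jnp.full((1, 1), 4.0)}

unchanged = divide_pytree_rebinding(grads, 4)
divided = divide_pytree(grads, 4)

print(unchanged["I_b"])  # values are still 4.0
print(divided["I_b"])    # values are now 1.0
```

Because tree_map returns a new pytree, the caller has to use the returned value, as in grads = divide_pytree(grads, num_gradients).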
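Regarding performance: under JIT compilation, anything inside a Python for loop is unrolled, with a copy of all XLA instructions for each loop iteration. With 5 iterations, that is not really a problem. With 5000, it will significantly slow down compilation (because XLA needs to analyze and optimize 5000 explicit copies of the instructions in the loop).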
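One way to observe the unrolling is to count the equations in the traced program with jax.make_jaxpr. The sketch below uses a hypothetical toy loop body (make_unrolled is not part of the original script) and only illustrates that the traced program grows with the number of Python loop iterations.

```python
import jax
import jax.numpy as jnp

def make_unrolled(n_steps):
    # The Python loop runs at trace time, so every iteration adds
    # its own operations to the traced program.
    def f(x):
        for _ in range(n_steps):
            x = x * 2.0 + 1.0
        return x
    return f

x = jnp.ones(4)
n_eqns_5 = len(jax.make_jaxpr(make_unrolled(5))(x).jaxpr.eqns)
n_eqns_50 = len(jax.make_jaxpr(make_unrolled(50))(x).jaxpr.eqns)

# The 50-step program contains many more equations than the 5-step one.
print(n_eqns_5, n_eqns_50)
```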
fori_loop can help, but it does not produce optimal code, especially when running on CPU and GPU.
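For reference, a fori_loop version of the accumulation might look like the sketch below. It is a sketch under assumptions, not the original script: loss_func is rewritten as a standalone linear model, accumulated_grads is a hypothetical helper, and the batch size is assumed to be divisible by chunk_size. Since the loop index is traced, the chunk is taken with jax.lax.dynamic_slice_in_dim rather than Python slicing.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.tree_util import tree_map

def loss_func(params, batch):
    # Sum-of-squares loss for a linear model (standalone stand-in).
    inputs, labels = batch
    predictions = inputs @ params["I"] + params["I_b"]
    return jnp.square(labels - predictions).sum()

def accumulated_grads(params, batch, chunk_size):
    inputs, labels = batch
    num_chunks = inputs.shape[0] // chunk_size  # static Python int

    def body(k, carry):
        value, grads = carry
        # The start index may be traced; the slice size must be static.
        chunk = (
            lax.dynamic_slice_in_dim(inputs, k * chunk_size, chunk_size),
            lax.dynamic_slice_in_dim(labels, k * chunk_size, chunk_size),
        )
        new_value, new_grads = jax.value_and_grad(loss_func)(params, chunk)
        grads = tree_map(lambda g, n: g + n / num_chunks, grads, new_grads)
        return value + new_value / num_chunks, grads

    # Start from a zero loss and zero gradients with the params' structure.
    init = (jnp.zeros(()), tree_map(jnp.zeros_like, params))
    return lax.fori_loop(0, num_chunks, body, init)

key_x, key_y, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
inputs = jax.random.uniform(key_x, (64, 8))
labels = (jax.random.uniform(key_y, (64, 1)) > 0.5).astype(jnp.float32)
params = {"I": 0.01 * jax.random.normal(key_p, (8, 1)), "I_b": jnp.zeros((1, 1))}

value, grads = jax.jit(accumulated_grads, static_argnums=2)(params, (inputs, labels), 4)
```

Here the loop body is traced once regardless of the number of chunks, so the compiled program does not grow with the batch size.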
Where possible, it is best to express the logic of the loop with broadcasting or vmapped operations, without an explicit loop.
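As a rough illustration of the vmapped approach, the sketch below reshapes the batch into chunks and maps jax.value_and_grad over the chunk axis. The helper chunked_mean_grads and the standalone loss_func are assumptions for illustration, and the batch size is assumed to be divisible by chunk_size. Note that vmap evaluates all chunks as one batched computation, so it removes the loop but is unlikely to lower peak memory the way sequential accumulation does.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def loss_func(params, batch):
    # Sum-of-squares loss for a linear model (standalone stand-in).
    inputs, labels = batch
    predictions = inputs @ params["I"] + params["I_b"]
    return jnp.square(labels - predictions).sum()

def chunked_mean_grads(params, batch, chunk_size):
    # Reshape each array from (N, ...) to (N // chunk_size, chunk_size, ...).
    chunks = tree_map(lambda a: a.reshape((-1, chunk_size) + a.shape[1:]), batch)
    # Map over the leading chunk axis of the batch; params are shared.
    per_chunk = jax.vmap(jax.value_and_grad(loss_func), in_axes=(None, 0))
    values, grads = per_chunk(params, chunks)
    # Average the per-chunk losses and gradients over the chunk axis.
    return values.mean(), tree_map(lambda g: g.mean(axis=0), grads)

key_x, key_y, key_p = jax.random.split(jax.random.PRNGKey(0), 3)
inputs = jax.random.uniform(key_x, (64, 8))
labels = (jax.random.uniform(key_y, (64, 1)) > 0.5).astype(jnp.float32)
params = {"I": 0.01 * jax.random.normal(key_p, (8, 1)), "I_b": jnp.zeros((1, 1))}

value, grads = jax.jit(chunked_mean_grads, static_argnums=2)(params, (inputs, labels), 4)
```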