JAX梯度累积

问题描述

我编写了一个简单的脚本来尝试使用 JAX 进行梯度累积。这个想法是将大批量(例如 64)分成适合 GPU 内存的小块(例如 4)。对于每个 chunck,存储在 pytree 中的结果梯度将添加到当前批次梯度中。仅当计算出大批量的所有块时才进行更新。在这个特定的例子中,我们只是尝试将随机 512 维向量拟合到具有线性层的随机布尔值。这是脚本:

import jax
import jax.numpy as jnp
from jax import jit,random
from jax.experimental import optimizers
from functools import partial
from jax.nn.initializers import normal,zeros
from typing import Callable
from dataclasses import dataclass

@dataclass
class Jax_model:
    init_fun: Callable
    apply_fun: Callable


def Dense(input_size: int,output_size: int,init_kernel=normal(),init_bias=zeros):

    def init_fun(key):
        key,sub_key1,sub_key2 = jax.random.split(key,3)
        params = {
            'I': init_kernel(sub_key1,(input_size,output_size) ),'I_b': init_bias(sub_key2,(1,}
        return params

    def apply_fun(params,inputs):
        I,I_b,= params['I'],params['I_b']
        logits = inputs @ I + I_b
        return logits

    return Jax_model(init_fun,apply_fun)


def divide_pytree(pytree,div):
    for pt in jax.tree_util.tree_leaves(pytree):
        pt = pt / div
    return pytree


def add_pytrees(pytree1,pytree2):
    for pt1,pt2 in zip( jax.tree_util.tree_leaves(pytree1),jax.tree_util.tree_leaves(pytree2) ):
        pt1 = pt1 + pt2
    return pytree1


rng_key = random.PRNGKey(42)
batch_size = 64
accumulation_size = 4
model_dim = 512
n_iter = 50

model = Dense(model_dim,1)
rng_key,sub_key = random.split(rng_key)
init_params = model.init_fun(sub_key)
opt_init,opt_update,get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)

@jit
def update(i,current_opt_state,current_batch):
    N = current_batch[0].shape[0]
    K = accumulation_size
    num_gradients = N//K
    accumulation_batch = (current_batch[ib][0:K] for ib in range(len(current_batch)))
    value,grads = jax.value_and_grad(loss_func)(get_params(current_opt_state),accumulation_batch)
    value = value / num_gradients
    grads = divide_pytree(grads,num_gradients)
    for k in range(K,N,K):
        accumulation_batch = (current_batch[ib][k:k+K] for ib in range(len(current_batch)))
        new_value,new_grads = jax.value_and_grad(loss_func)(get_params(current_opt_state),accumulation_batch)
        value = value + (new_value / num_gradients)
        grads = add_pytrees(grads,divide_pytree(new_grads,num_gradients))
    return opt_update(i,grads,current_opt_state),value

def loss_func(current_params,current_batch):
    inputs,labels = current_batch
    predictions = model.apply_fun(current_params,inputs)
    loss = jnp.square(labels-predictions).sum()
    return loss

for i in range(n_iter):
    rng_key,sub_key2 = random.split(rng_key,3)
    inputs = jax.random.uniform(sub_key1,(batch_size,model_dim))
    labels = jax.random.uniform(sub_key2,1)) > 0.5
    batch = inputs,labels
    opt_state,batch_loss = update(i,opt_state,batch)
    print(i,batch_loss)

我对 divide_pytreeadd_pytrees 有疑问。它实际上是修改了当前的批处理梯度还是我遗漏了什么?此外,您是否看到此代码有任何速度问题?特别是,我应该使用 jax.lax.fori_loop 代替传统的 python for 循环吗?

相关链接

解决方法

关于 pytree 计算:如所写,您的函数返回未修改的输入。更好的方法是使用 jax.tree_util.tree_map;例如:

from jax.tree_util import tree_map

def divide_pytree(pytree,div):
  return tree_map(lambda pt: pt / div,pytree)

def add_pytrees(pytree1,pytree2):
  return tree_map(lambda pt1,pt2: pt1 + pt2,pytree1,pytree2)

关于性能:在 JIT 编译时,for 循环中的任何内容都将被展平,每次循环迭代都会重复一份所有 XLA 指令的副本。如果你有 5 次迭代,那真的不是问题。如果您有 5000 个,那将显着减慢编译时间(因为 XLA 需要分析和优化循环中指令的 5000 个显式副本)。

fori_loop 可以提供帮助,但不会产生最佳代码,尤其是在 CPU 和 GPU 上运行时。

最好使用广播或 vmapped 操作在可能的情况下表达循环的逻辑,而无需显式循环。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...