Best way to wrap an optimizer in `CrossShardOptimizer`

Problem description

Suppose I have this code:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] * 0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate, weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"], beta_2=params["beta2"], epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"], variable_dtype=variable_dtype
        )
    else:
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"], decay_rate=params["weight_decay"],
            beta1=params["beta1"], epsilon1=params["ada_epsilon1"], epsilon2=params["ada_epsilon2"]
        )

    if params["use_tpu"]:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)


    if params["gradient_clipping"] is not None:
        (var_grads_fp,_) = clip_by_global_norm(var_grads_fp,clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp


class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
  """A basic Adam optimizer that includes "correct" L2 weight decay."""

  def __init__(self, learning_rate, weight_decay_rate=0.0, beta_1=0.9,
               beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=None,
               variable_dtype=None):
    """Constructs an AdamWeightDecayOptimizer."""

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay
    self.variable_dtype = variable_dtype

  def apply_grad(self, grad, var):
    """See base class."""
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []
    
    grad = mtf.to_float(grad)

    assignments = []

    m = mtf.get_variable(
        var.mesh, var.name + "/adam_m", var.shape,
        initializer=tf.zeros_initializer(),
        # master_dtype / slice_dtype / activation_dtype kwargs left commented out
        trainable=False)

    v = mtf.get_variable(
        var.mesh, var.name + "/adam_v", var.shape,
        initializer=tf.zeros_initializer(), trainable=False)

    # Standard Adam update.
    next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
    next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

    update = next_m / (mtf.sqrt(next_v) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    # _do_use_weight_decay (not shown here) skips decay for variables whose
    # names match exclude_from_weight_decay (e.g. "norm", "bias").
    if self._do_use_weight_decay(var.name):
      update += mtf.to_float(var.value) * self.weight_decay_rate

    update_with_lr = self.learning_rate * update

    var_update = mtf.assign_sub(var, update_with_lr)

    assignments.extend(
        [var_update, mtf.assign(m, next_m), mtf.assign(v, next_v)])
    return assignments

Running this code raises the following error:

TypeError: CrossShardOptimizer only works with tf.training.Optimizer and not Optimizer_v2. If you are using TPUStrategy, OptimizerV2 will sum gradients across replicas. If you want to average your gradients, rescale your loss: loss /= global_batch_size

So I'm wondering what the best way to handle this is. Is there another CrossShardOptimizer-style wrapper that is compatible with v2 optimizers? Should I rewrite the Mesh TensorFlow optimizer? Or does some TensorFlow submodule already provide optimizers that run on TPUs?

Solution

Wrapping the optimizer in CrossShardOptimizer is required when porting an Estimator model to TPUEstimator, because that wrapper takes care of averaging the gradients across TPU shards.
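
For reference, this is roughly the classic (non-Mesh) TPUEstimator pattern that CrossShardOptimizer was designed for, with a v1 `tf.train` optimizer underneath. It is only an illustrative sketch; the model, feature names, and `params` keys are made up.

```python
import tensorflow.compat.v1 as tf


def model_fn(features, labels, mode, params):
    """Illustrative TPUEstimator model_fn (plain TF, not Mesh TensorFlow)."""
    logits = tf.layers.dense(features["x"], 10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # CrossShardOptimizer wraps a v1 tf.train.Optimizer and inserts a
    # cross-replica sum so gradients are aggregated across TPU shards.
    optimizer = tf.train.AdamOptimizer(learning_rate=params["lr"])
    if params["use_tpu"]:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```

The TypeError in the question is raised because the wrapper checks that it has been given a `tf.train.Optimizer` subclass, which an `mtf.optimize.Optimizer` is not.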

For Mesh TensorFlow this is a bit different, because its TPU implementation follows a SIMD (single instruction, multiple devices) approach. As a result, you won't see any MTF implementation wrapped in CrossShardOptimizer; mtf.optimize.Optimizer is already supported on TPUs. The change is made at the SIMD mesh level rather than at the optimizer level.
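
Concretely, that means you can drop the `if params["use_tpu"]` wrap in `get_optimizer` above and express the TPU parallelism when the mtf graph is lowered. Below is a rough sketch of that lowering step, loosely following the style of the Mesh TF TPU examples; `mesh_devices`, `device_assignment`, and the `params` keys used here are assumptions, not part of the original code.

```python
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def build_train_op(graph, mesh, loss, params, variable_dtype,
                   mesh_devices, device_assignment):
    """Sketch: TPU handling lives in the SIMD mesh impl, not the optimizer."""
    # Plain mtf optimizer ops -- no CrossShardOptimizer involved.
    _, update_ops, _ = get_optimizer(mesh, loss, params, variable_dtype)

    # This is where the SIMD/TPU layout is expressed.
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, device_assignment)

    # Lower the mtf graph to plain TensorFlow ops and group the updates.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(tf.train.get_or_create_global_step(), 1))
    return lowering, tf.group(tf_update_ops)
```

Inside a TPUEstimator `model_fn`, `device_assignment` would typically come from the TPU context the estimator provides; on CPU/GPU you would use a placement mesh implementation instead of `SimdMeshImpl`.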

If you haven't already seen it, here is a Mesh TF example that runs on MNIST, which should help.