问题描述
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op.

    Args:
        mesh: Mesh TensorFlow mesh; its graph supplies the trainable variables.
        loss: Loss tensor that is differentiated when `inp_var_grads` is None.
        params: Dict of hyperparameters. Keys read here: "lr",
            "gradient_clipping", "train_steps", "lr_decay", "warmup_steps",
            "opt_name", "weight_decay", "beta1", "use_tpu", optionally
            "lr_decay_end", plus "beta2"/"epsilon" (Adam branch) or
            "ada_epsilon1"/"ada_epsilon2" (Adafactor branch).
        variable_dtype: Object exposing `slice_dtype`; gradients and the
            learning rate are expressed in that dtype.
        inp_var_grads: Optional precomputed gradients, one per trainable
            variable. When None they are computed from `loss`.

    Returns:
        Tuple `(learning_rate, update_ops, var_grads_fp)`: the imported
        learning-rate tensor, the ops returned by `optimizer.apply_grads`,
        and the (possibly clipped) gradients cast to `slice_dtype`.
    """
    global_step = tf.train.get_or_create_global_step()
    slice_dtype = variable_dtype.slice_dtype
    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast gradients to the slice dtype before clipping / applying them.
    var_grads_fp = [mtf.cast(g, slice_dtype) for g in var_grads]

    # Decrease LR to final lr (lr * 0.1) by this step - defaults to train_steps.
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            # Decrease to 10% of initial LR according to GPT-3 paper.
            end_learning_rate=params["lr"] * 0.1,
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        # cosine_decay requires the current step and the decay horizon; without
        # them the call fails with a TypeError.
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1)  # Alpha is min lr value as a fraction of init lr.

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, slice_dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, slice_dtype)

        # Scale the scheduled LR by the fraction of warmup completed.
        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        # is_warmup is 1.0 during warmup and 0.0 afterwards, so this selects
        # between the two rates without a tf.cond.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, slice_dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(
        mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype)
    else:
        # NOTE(review): Adafactor receives the raw params["lr"], not the
        # decayed/warmed-up `learning_rate` built above - confirm intended.
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"])

    if params["use_tpu"]:
        # NOTE(review): tf.tpu.CrossShardOptimizer rejects optimizers that are
        # not tf.train.Optimizer instances, so wrapping a Mesh TensorFlow
        # optimizer here raises a TypeError.
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if params["gradient_clipping"] is not None:
        # Build the clip constant only when clipping is enabled, so a None
        # setting never reaches mtf.constant.
        clip_value = mtf.constant(
            mesh, params["gradient_clipping"], dtype=slice_dtype)
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""

    def __init__(self, learning_rate, weight_decay_rate=0.0, beta_1=0.9,
                 beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=None,
                 variable_dtype=None):
        """Constructs a AdamWeightDecayOptimizer.

        Args:
            learning_rate: Step size multiplied into every update.
            weight_decay_rate: Coefficient of the decoupled weight-decay term.
            beta_1: Decay rate of the first-moment (m) accumulator.
            beta_2: Decay rate of the second-moment (v) accumulator.
            epsilon: Constant added to sqrt(v) in the update denominator.
            exclude_from_weight_decay: Optional list of regex patterns;
                variables whose name matches any of them are not decayed.
            variable_dtype: Stored on the instance; not used by the update.
        """
        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.variable_dtype = variable_dtype

    def apply_grad(self, grad, var):
        """See base class.

        Returns:
            A list with the assignment ops for `var` and its two Adam slots,
            or an empty list when `grad` is None.
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)

        assignments = []

        # Both Adam slots mirror the variable's shape and start at zero.
        m = mtf.get_variable(
            var.mesh, var.name + "/adam_m", var.shape,
            initializer=tf.zeros_initializer(),
            # master_dtype=self.variable_dtype.master_dtype,
            # slice_dtype=self.variable_dtype.slice_dtype,
            # activation_dtype=self.variable_dtype.activation_dtype,
            trainable=False)

        v = mtf.get_variable(
            var.mesh, var.name + "/adam_v", var.shape,
            initializer=tf.zeros_initializer(),
            trainable=False)

        # Standard Adam update.
        next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
        next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

        update = next_m / (mtf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += mtf.to_float(var.value) * self.weight_decay_rate

        update_with_lr = self.learning_rate * update

        var_update = mtf.assign_sub(var, update_with_lr)

        assignments.extend(
            [var_update, mtf.assign(m, next_m), mtf.assign(v, next_v)])
        return assignments

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for pattern in self.exclude_from_weight_decay:
                if re.search(pattern, param_name) is not None:
                    return False
        return True
运行上述代码时,`tf.tpu.CrossShardOptimizer(optimizer)` 会抛出如下错误:
TypeError: CrossShardOptimizer 仅适用于 tf.training.Optimizer,不适用于 Optimizer_v2。如果您使用 TPUStrategy,OptimizerV2 会对各副本的梯度求和;如果想对梯度取平均,请按如下方式缩放损失:`loss /= global_batch_size`
所以我想知道处理这个问题的最佳方法是什么:是否有另一个与 v2 优化器兼容的 `CrossShardOptimizer` 包装器?我应该重写 Mesh TensorFlow 的优化器吗?还是说 TensorFlow 的某个子模块已经实现了可以在 TPU 上运行的优化器?
解决方法
将 `Estimator` 模型移植为 `TPUEstimator` 模型时,必须用 `CrossShardOptimizer` 包装优化器,因为它负责对各 TPU 分片上的梯度取平均。
对于 Mesh TensorFlow,情况有所不同,因为它的 TPU 实现采用 SIMD(单指令,多设备)理念。因此,您不会看到任何使用 `CrossShardOptimizer` 的 MTF 实现;实际上 `mtf.optimize.Optimizer` 本身就可以在 TPU 上使用。所需的改动在 SIMD 网格实现这一层,而不在优化器这一层。
如果您还没看过,这里(原文为链接)有一个在 MNIST 上运行的 Mesh TensorFlow 示例,应该会有所帮助。