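On using batch normalization

Question

I am trying to make sure that batch normalization layers are incorporated into my model correctly.

The code snippet below illustrates what I am doing.

  1. Is this an appropriate use of batch normalization?
  2. At inference time, how can I access the moving averages in each batch normalization layer to make sure they have been loaded?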


import os
import tensorflow.compat.v1 as tf
from model import Model

# Sample batch normalization layer in the Model class
x_preBN = ...
x_postBN = tf.layers.batch_normalization(
    inputs=x_preBN, center=True, scale=True, momentum=0.9,
    training=(self.mode == 'train'))

# During training:
model = Model(mode='train')
# Ops that update the BN moving mean and variance; they are not run by train_step alone
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
saver = tf.train.Saver()

with tf.Session() as sess:
  for it in range(max_iterations):
    # Training step + update of BN moving statistics
    sess.run([train_step, extra_update_ops], feed_dict=...)

    # Store checkpoint
    if it % num_checkpoint_steps == 0:
      saver.save(sess, os.path.join(model_dir, 'checkpoint'), global_step=it)

# During inference:
model = Model(mode='eval')
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, 'checkpoint-???')
  acc = sess.run(model.accuracy, feed_dict=...)

Answer

Once the model has been instantiated, you can get a list of all global variables:

model = Model(mode='eval')
saver = tf.train.Saver()
print(tf.global_variables())

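The batch normalization variables for a particular layer look like the listing below. gamma and beta are trainable, while the moving statistics are not (which is why extra_update_ops has to be run during training).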

<tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/gamma:0' shape=(16,) dtype=float32>,
<tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/beta:0' shape=(16,) dtype=float32>,
<tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/moving_mean:0' shape=(16,) dtype=float32>,
<tf.Variable 'unit_1_1/residual_only_activation/batch_normalization/moving_variance:0' shape=(16,) dtype=float32>
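The moving statistics are not trainable because they are updated by assignment (an exponential moving average), not by gradient descent. Here is a minimal pure-Python sketch of that arithmetic for a single feature, assuming the usual update rule moving = momentum * moving + (1 - momentum) * batch_statistic. The function names, the epsilon value, and the sample batches are illustrative and are not TensorFlow's implementation.

```python
import math

MOMENTUM = 0.9   # same value as in the question
EPSILON = 1e-3   # small constant for numerical stability (assumed value)


def batch_stats(batch):
    """Mean and biased variance of one feature over a batch."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return mean, var


def update_moving(moving, batch_value, momentum=MOMENTUM):
    """One exponential-moving-average step; no gradient is involved."""
    return momentum * moving + (1.0 - momentum) * batch_value


def normalize(batch, mean, var, gamma=1.0, beta=0.0):
    """Normalize with the given statistics, then scale (gamma) and shift (beta)."""
    return [gamma * (x - mean) / math.sqrt(var + EPSILON) + beta for x in batch]


moving_mean, moving_var = 0.0, 1.0  # assumed initial values
for batch in ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]):
    mean, var = batch_stats(batch)
    # Training mode: normalize with the statistics of the current batch
    train_out = normalize(batch, mean, var)
    # Counterpart of the extra update ops: refresh the moving statistics
    moving_mean = update_moving(moving_mean, mean)
    moving_var = update_moving(moving_var, var)

# Inference mode: normalize with the stored moving statistics instead
eval_out = normalize([2.0], moving_mean, moving_var)
print(moving_mean, moving_var)
```

If the update step is skipped during training, the moving statistics stay at their initial values, and inference then normalizes with statistics that do not match the data.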

These variables can be accessed as usual:

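# Look up the moving-mean variable by name among the global variables
ma = [v for v in tf.global_variables()
      if v.name == 'unit_1_1/residual_only_activation/batch_normalization/moving_mean:0'][0]
with tf.Session() as sess:
  # restore() expects a checkpoint path prefix, not a directory
  saver.restore(sess, tf.train.latest_checkpoint(model_dir))
  print(sess.run(ma))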
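To check that the moving averages were actually loaded, one option is to compare the values in the session with the values stored in the checkpoint file, which tf.train.load_checkpoint can read directly. The sketch below is self-contained and does not use the real Model class: build_graph, check_restored, the variable name, and the 'model.ckpt' prefix are hypothetical stand-ins, and a single hand-made non-trainable variable plays the role of a batch normalization moving mean.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical variable name, modeled on the names printed above
BN_MEAN = 'unit_1_1/batch_normalization/moving_mean'


def build_graph(init_value):
    """Stand-in for Model(...): creates one non-trainable BN statistic."""
    return tf.get_variable(BN_MEAN, shape=(16,), trainable=False,
                           initializer=tf.constant_initializer(init_value))


def check_restored(sess, ckpt_path, keyword='moving_'):
    """Compare each 'moving_*' variable in the session with the checkpoint."""
    reader = tf.train.load_checkpoint(ckpt_path)
    report = {}
    for var in tf.global_variables():
        if keyword in var.name:
            name = var.name.split(':')[0]  # checkpoint keys have no ':0' suffix
            saved = reader.get_tensor(name)
            report[name] = bool(np.allclose(sess.run(var), saved))
    return report


ckpt_dir = tempfile.mkdtemp()

# "Training": the statistic holds a non-default value (3.0) when it is saved
with tf.Graph().as_default():
    build_graph(3.0)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=0)

# "Inference": rebuild the graph with a different default, restore, and verify
with tf.Graph().as_default():
    moving_mean = build_graph(0.0)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
        saver.restore(sess, ckpt_path)
        report = check_restored(sess, ckpt_path)
        restored = sess.run(moving_mean)
        print(report)
```

In a real model the same check_restored call would run right after saver.restore, and any entry mapped to False would point at a variable whose value in the session differs from the checkpoint.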
