The correct way to use batch_normalization in TensorFlow
Principle
Batch normalization is usually applied in front of a layer. It reshapes the distribution of each layer's inputs into a normal distribution, which stabilizes the network and speeds up convergence.
The exact formula is:

$$\frac{\gamma(x-\mu)}{\sqrt{\sigma^2+\epsilon}}+\beta$$

where $\gamma$ and $\beta$ determine the final distribution, controlling the variance and the mean respectively, and $\epsilon$ prevents the denominator from being 0.
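A quick NumPy sketch of this formula (the $\gamma$, $\beta$ and $\epsilon$ values are illustrative and match the API defaults shown below):

import numpy as np

def batch_norm(x, gamma=1.0, beta=0.0, epsilon=0.001):
    mu = x.mean(axis=0)    # per-feature batch mean
    var = x.var(axis=0)    # per-feature batch variance, i.e. sigma^2
    return gamma * (x - mu) / np.sqrt(var + epsilon) + beta

x = np.random.randn(32, 8) * 3.0 + 10.0   # toy batch with mean ~10, std ~3
y = batch_norm(x)
print(y.mean(axis=0), y.std(axis=0))      # ~0 and ~1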
TensorFlow
In real use, the mean $\mu$ and standard deviation $\sigma$ are determined jointly by the accumulated history and the current batch of samples:

$$\mu = momentum*\mu + (1-momentum)*\mu_{batch}$$

$$\sigma = momentum*\sigma + (1-momentum)*\sigma_{batch}$$

where $\mu_{batch}$ is the mean of the current batch.
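A small simulation of this moving-average update (the momentum matches the API default; the toy data is illustrative; note that TensorFlow actually tracks the moving variance rather than the standard deviation):

import numpy as np

momentum = 0.99
moving_mean, moving_var = 0.0, 1.0           # TensorFlow's initial values

for step in range(2000):
    batch = np.random.randn(32) * 2.0 + 5.0  # toy batches with mean 5, std 2
    moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean()
    moving_var = momentum * moving_var + (1 - momentum) * batch.var()

print(moving_mean, moving_var)               # drifts toward ~5 and ~4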
API
In TensorFlow, the recommended API is:
tf.layers.batch_normalization(
    inputs, axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(), beta_regularizer=None,
    gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
    training=False, trainable=True, name=None, reuse=None, renorm=False,
    renorm_clipping=None, renorm_momentum=0.99, fused=None, virtual_batch_size=None,
    adjustment=None
)
A few key parameters (a minimal usage sketch follows the list):
- momentum: the momentum in the formulas above, weighting the accumulated statistics against the current batch's statistics;
- epsilon: the $\epsilon$ above, which prevents the denominator from being 0;
- center: whether to include the offset $\beta$;
- scale: whether to include the scale $\gamma$;
- training: whether this is the training phase, which decides whether the mean and variance are updated or held fixed;
- trainable: whether $\gamma$ and $\beta$ are added to the trainable variables.
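The sketch below wires the training flag through a placeholder so one graph can serve both phases; the name is_training and the shapes are illustrative:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
# `training` accepts a Python bool or a boolean scalar tensor; a
# placeholder with a False default covers inference automatically
is_training = tf.placeholder_with_default(False, shape=(), name='is_training')

x_norm = tf.layers.batch_normalization(
    x, momentum=0.99, epsilon=0.001, center=True, scale=True,
    training=is_training)
# feed {is_training: True} in sess.run() during training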
Correct usage
$\gamma$ and $\beta$ are trainable variables and live in tf.GraphKeys.TRAINABLE_VARIABLES. The mean and variance are not trainable: they appear only in tf.GraphKeys.GLOBAL_VARIABLES, and their update ops are stored in tf.GraphKeys.UPDATE_OPS (the snippet after this list shows how to inspect all three collections).
So the crucial points, and the ones most likely to go wrong, are:
- in the training phase, make sure the mean and variance are actually updated;
- in the prediction phase, make sure every parameter matches the training phase; in essence just four of them: $\gamma$, $\beta$, $\mu$, $\sigma$.
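A quick way to see where the four parameters live (the exact variable names depend on the layer's name scope):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
x_norm = tf.layers.batch_normalization(x, training=True)

print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))  # gamma, beta
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))     # plus moving_mean, moving_variance
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))           # the ops updating the moving statistics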
Training
So during training, the update ops have to be folded into train_op:
x_norm = tf.layers.batch_normalization(x, training=True)
# ...
# minimize() does NOT run the moving-average updates; they live in
# UPDATE_OPS and must be attached to train_op explicitly
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = optimizer.minimize(loss)
train_op = tf.group(train_op, *update_ops)
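An equivalent, and perhaps more common, idiom is to make the update ops a control dependency of the training step:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # the train step now only runs after the moving statistics are updated
    train_op = optimizer.minimize(loss)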
Saving the model
Since the moving mean and variance are GLOBAL_VARIABLES but not TRAINABLE_VARIABLES, a Saver built from the trainable variables only would silently drop them. Build the Saver over all variables, i.e. GLOBAL_VARIABLES:
sess = tf.Session()
# save ALL global variables, including the BN moving mean/variance
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, "your_path")
Prediction
If the model correctly saved the global variables (GLOBAL_VARIABLES), then at prediction time the trained batch_normalization parameters can simply be loaded. Beyond that, training must also be set to False so the mean and variance are held fixed.
# training=False: normalize with the saved moving mean/variance
x_norm = tf.layers.batch_normalization(x, training=False)
# ...
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, "your_path")
Estimator
If you train with the high-level Estimator API, things get trickier: its session is not exposed, so you cannot drive a Saver directly and need another route:
- fortunately, Estimator saves all variables (GLOBAL_VARIABLES) by default;
The key is then to guarantee that the eval and predict phases load the trained parameters. In your model_fn, add a model-loading step:
import collections
import re

import tensorflow as tf


def model_fn_build(init_checkpoint=None, lr=0.001, model_dir=None, config=None):
    def _model_fn(features, labels, mode, params):
        x = features['inputs']
        y = features['labels']
        ##################### define your own network here ###################
        x_norm = tf.layers.batch_normalization(
            x, training=(mode == tf.estimator.ModeKeys.TRAIN))
        pred = tf.layers.dense(x_norm, 1)
        loss = tf.reduce_mean(tf.pow(pred - y, 2), name='loss')
        ##################### define your own network here ###################
        lr = params['lr']
        ########## before eval and predict, go through this loading step ##########
        # Load the saved model.
        # global_variables (not trainable_variables) is needed so that the
        # batch_normalization moving mean/variance are restored as well.
        tvars = tf.global_variables()
        initialized_variable_names = {}
        checkpoint = params['init_checkpoint'] or (
            tf.train.latest_checkpoint(model_dir) if model_dir else None)
        if checkpoint is not None:
            (assignment_map, initialized_variable_names
             ) = get_assignment_map_from_checkpoint(tvars, checkpoint)
            tf.train.init_from_checkpoint(checkpoint, assignment_map)
            # tf.logging.info("**** Trainable Variables ****")
            # for var in tvars:
            #     init_string = ""
            #     if var.name in initialized_variable_names:
            #         init_string = ", *INIT_FROM_CKPT*"
            #     tf.logging.info("  name = %s, shape = %s%s",
            #                     var.name, var.shape, init_string)
        ########## before eval and predict, go through this loading step ##########
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = tf.train.AdamOptimizer(lr)  # any optimizer works here
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = optimizer.minimize(
                loss, global_step=tf.train.get_global_step())
            train_op = tf.group(train_op, *update_ops)
            return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
        if mode == tf.estimator.ModeKeys.EVAL:
            metrics = {"accuracy": tf.metrics.accuracy(features['labels'], pred)}
            return tf.estimator.EstimatorSpec(mode, eval_metric_ops=metrics, loss=loss)
        predictions = {'predictions': pred}
        predictions.update({k: v for k, v in features.items()})
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config,
                                  params={"lr": lr, "init_checkpoint": init_checkpoint})

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return (assignment_map, initialized_variable_names)
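A usage sketch for the builder above; the input_fn, shapes and model_dir are illustrative:

import numpy as np

def input_fn():
    # toy data with the 'inputs' / 'labels' keys the model_fn expects
    data = {'inputs': np.random.randn(320, 8).astype(np.float32),
            'labels': np.random.randn(320, 1).astype(np.float32)}
    return tf.data.Dataset.from_tensor_slices(data).batch(32).repeat()

estimator = model_fn_build(lr=0.001, model_dir="your_model_dir")
estimator.train(input_fn, steps=1000)          # TRAIN: moving statistics updated
print(estimator.evaluate(input_fn, steps=10))  # EVAL: moving statistics frozen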