tensorflow中batch_normalization的正确使用姿势

谁借莪1个温暖的怀抱¢ 2022-11-14 00:47 222阅读 0赞

原理

batch_normalization一般是用在进入网络之前,它的作用是可以将每层网络的输入的数据分布变成正态分布,有利于网络的稳定性,加快收敛。

具体的公式如下: γ ( x − μ ) σ 2 + ϵ + β \frac{\gamma(x-\mu)}{\sqrt{\sigma^2+\epsilon}}+\beta σ2+ϵ​γ(x−μ)​+β

其中 γ \gamma γ和 β \beta β是决定最终的正态分布,分别影响了方差和均值, ϵ \epsilon ϵ是为了避免出现分母为0的情况

tensorflow

在真实的使用中,均值 μ \mu μ和标准差 σ \sigma σ是由历史累计样本和当前批次样本来共同决定的:

μ = m o m e n t u m ∗ μ + ( 1 − m o m e n t u m ) ∗ μ b a t c h \mu=momentum*\mu+(1-momentum)*\mu_{batch} μ=momentum∗μ+(1−momentum)∗μbatch​

σ = m o m e n t u m ∗ σ + ( 1 − m o m e n t u m ) ∗ σ b a t c h \sigma=momentum*\sigma+(1-momentum)*\sigma_{batch} σ=momentum∗σ+(1−momentum)∗σbatch​

μ b a t c h \mu_{batch} μbatch​表示当前批次样本的均值

API

在tensorflow中,推荐的api是

  1. tf.layers.batch_normalization(
  2. inputs, axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
  3. beta_initializer=tf.zeros_initializer(),
  4. gamma_initializer=tf.ones_initializer(),
  5. moving_mean_initializer=tf.zeros_initializer(),
  6. moving_variance_initializer=tf.ones_initializer(), beta_regularizer=None,
  7. gamma_regularizer=None, beta_constraint=None, gamma_constraint=None,
  8. training=False, trainable=True, name=None, reuse=None, renorm=False,
  9. renorm_clipping=None, renorm_momentum=0.99, fused=None, virtual_batch_size=None,
  10. adjustment=None
  11. )

看几个关键的参数:

  1. momentum:对应上述公式,决定历史累计样本和当前批次样本的权重;
  2. epsilon: ϵ \epsilon ϵ是为了避免出现分母为0的情况
  3. center:是否加入 β \beta β
  4. scale:是否加入 γ \gamma γ
  5. training:当前是否为训练阶段,决定均值和方差是否固定
  6. trainable:是否将 γ \gamma γ和 β \beta β加到训练变量中

正确使用方式

γ \gamma γ和 β \beta β是可训练变量,存放于tf.GraphKeys.TRAINABLE_VARIABLES

而均值和方差则不是训练变量,只能在tf.GraphKeys.GLOBAL_VARIABLES中,并且更新过程存放于tf.GraphKeys.UPDATE_OPS

所以,最关键的点,也是最容易出问题的,就是:

  1. 训练阶段,要保证均值和方差的正确更新;
  2. 预测阶段,则要保证所有参数与训练阶段的一致,其实主要就4个: γ 、 β 、 μ 、 σ \gamma、\beta、\mu、\sigma γ、β、μ、σ

训练

那么,在训练的时候,需要将更新过程加入到train_op中:

  1. x_norm = tf.layers.batch_normalization(x, training=True)
  2. # ...
  3. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  4. train_op = optimizer.minimize(loss)
  5. train_op = tf.group([train_op, update_ops])

模型保存

由于均值和方差是GLOBAL_VARIABLES,但是tensorflow默认只保存TRAINABLE_VARIABLES,所以,我们需要设置将所有变量保存起来,即GLOBAL_VARIABLES

  1. sess = tf.Session()
  2. saver = tf.train.Saver(tf.global_variables())
  3. saver.save(sess, "your_path")

预测

如果,模型正确保存了全局变量GLOBAL_VARIABLES,那么预测阶段,即可加载已经训练有素的batch_normalzation相关的参数;

但是,除此之外,还要将training设为False,将均值和方差固定住。

  1. x_norm = tf.layers.batch_normalization(x, training=False)
  2. # ...
  3. saver = tf.train.Saver(tf.global_variables())
  4. saver.restore(sess, "your_path")

estimator

如果你使用的是高阶API:estimator进行训练的话,那么就比较麻烦,因为它的session没有暴露出来,你没办法直接使用,需要换个方式:

  1. 幸好的是,estimator默认保存的是所有变量GLOBAL_VARIABLES;
  2. 关键在于保证eval、predict阶段要保证加载训练好的参数。在你model_fn函数中,增加一步模型的加载

    def model_fn_build(init_checkpoint=None, lr=0.001, model_dir=None):

    1. def _model_fn(features, labels, mode, params):
    2. x = features['inputs']
    3. y = features['labels']
    4. #####################在这里定义你自己的网络模型###################
    5. x_norm = tf.layers.batch_normalization(x, training=mode == tf.estimator.ModeKeys.TRAIN)
    6. pre = tf.layers.dense(x_norm, 1)
    7. loss = tf.reduce_mean(tf.pow(pre - y, 2), name='loss')
    8. ######################在这里定义你自己的网络模型###################
    9. lr = params['lr']
    10. ######################进入eval和predict之前,都经过这一步加载过程###################
    11. # 加载保存的模型
    12. # 为了加载batch_normalization的参数,需要global_variables
    13. tvars = tf.global_variables()
    14. initialized_variable_names = { }
    15. if params['init_checkpoint'] is not None or tf.train.latest_checkpoint(model_dir) is not None:
    16. checkpoint = params['init_checkpoint'] or tf.train.latest_checkpoint(model_dir)
    17. (assignment_map, initialized_variable_names
    18. ) = get_assignment_map_from_checkpoint(tvars, checkpoint)
    19. tf.train.init_from_checkpoint(checkpoint, assignment_map)
    20. # tf.logging.info("**** Trainable Variables ****")
    21. # for var in tvars:
    22. # init_string = ""
    23. # if var.name in initialized_variable_names:
    24. # init_string = ", *INIT_FROM_CKPT*"
    25. # tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
    26. # init_string)
    27. ######################进入eval和predict之前,都经过这一步加载过程###################
    28. if mode == tf.estimator.ModeKeys.TRAIN:
    29. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    30. train_op = optimizer.minimize(loss)
    31. train_op = tf.group([train_op, update_ops])
    32. return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    33. if mode == tf.estimator.ModeKeys.EVAL:
    34. metrics = { "accuracy": tf.metrics.accuracy(features['label'], pred)}
    35. return tf.estimator.EstimatorSpec(mode, eval_metric_ops=metrics, loss=loss)
    36. predictions = { 'predictions': pred}
    37. predictions.update({ k: v for k, v in features.items()})
    38. return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    39. return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config,
    40. params={ "lr": lr, "init_checkpoint": init_checkpoint})
  1. def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  2. """Compute the union of the current variables and checkpoint variables."""
  3. assignment_map = { }
  4. initialized_variable_names = { }
  5. name_to_variable = collections.OrderedDict()
  6. for var in tvars:
  7. name = var.name
  8. m = re.match("^(.*):\\d+$", name)
  9. if m is not None:
  10. name = m.group(1)
  11. name_to_variable[name] = var
  12. init_vars = tf.train.list_variables(init_checkpoint)
  13. assignment_map = collections.OrderedDict()
  14. for x in init_vars:
  15. (name, var) = (x[0], x[1])
  16. if name not in name_to_variable:
  17. continue
  18. assignment_map[name] = name
  19. initialized_variable_names[name] = 1
  20. initialized_variable_names[name + ":0"] = 1
  21. return (assignment_map, initialized_variable_names)

发表评论

表情:
评论列表 (有 0 条评论,222人围观)

还没有评论,来说两句吧...

相关阅读

    相关 webpack 正确使用姿势

    我想大部分人都用过 webpack 。作为现代前端开发中最火的模块打包工具,它只需要通过简单的配置,就能轻松完成模块的加载和打包,实属神器。 不过我发现身边很多朋友都觉得它不