【TensorFlow学习笔记(一)】变量作用域

阳光穿透心脏的1/2处 2022-01-23 04:57 339阅读 0赞
  • 更新时间:2019-06-02
    TensorFlow中有两个作用域,一个是name_scope,一个是variable_scopename_scope主要是给op_name加前缀,variable_scope主要是给variable_name加前缀。

variable_scope

variable_scope变量作用域机制主要由两部分组成:

  1. v = tf.get_variable_scope(name, shape, dtype, initializer) # 根据名字返回变量
  2. tf.variable_scope(<scope_name>) # 为变量指定命名空间

共享变量:

  1. # 创建变量作用域
  2. with tf.variable_scope("foo") as scope:
  3. v = tf.get_variable("v", [1])
  4. # 设置reuse参数为True时,共享变量。reuse的默认值为False。
  5. with tf.variable_scope("foo", reuse=True):
  6. v1 = tf.get_variable("v", [1])
  7. assert v1 == v

获取变量作用域

通过tf.variable_scope()获取变量作用域:

  1. with tf.variable_scope("foo") as foo_scope:
  2. v = tf.get_variable("v", [1])
  3. with tf.variable_scope(foo_scope):
  4. w = tf.get_variable("w", [1])

变量作用域的初始化

变量作用域默认携带一个初始化器,在这个作用域中的子作用域或变量都可以继承或重写父作用域初始化器中的值。

  1. with tf.variable_scope("foo", initializer=tf.constant_initializer(0.4)):
  2. v = tf.get_variable("v", [1])
  3. assert v.eval() == 0.4 # 被作用域初始化
  4. w = tf.get_variable("w", [1], initializer=tf.constant_initializer(0.3)):
  5. assert w.eval() == 0.3 # 重写初始化器的值
  6. with tf.variable_scope("bar"):
  7. v = tf.get_variable("v", [1])
  8. assert v.eval() == 0.4 # 继承默认的初始化器
  9. with tf.variable_scope("baz", initializer=tf.tf.constant_initializer(0.2)):
  10. v = tf.get_variable("v", [1])
  11. assert v.eval() == 0.2 # 重写父作用域的初始化器的值

name_scope

name_scope为变量划分范围,在可视化中,表示在计算图中的一个层级。name_scope会影响op_name,不会影响get_variable()创建的变量,而会影响通过Variable()创建的变量。

  1. with tf.variable_scope("foo"):
  2. with tf.name_scope("bar"):
  3. v = tf.get_variable("v", [1])
  4. b = tf.Variable(tf.zeros([1]), name='b')
  5. x = 1.0 +v
  6. assert v.name == "foo/v:0"
  7. assert b.name == "foo/bar/b:0"
  8. assert x.op.name == "foo/bar/add"

tf.name_scope()返回一个字符串。

发表评论

表情:
评论列表 (有 0 条评论,339人围观)

还没有评论,来说两句吧...

相关阅读