创建Variable
Tensorflow有两种方式创建Variable:tf.Variable()和tf.get_variable(),这两种方式获得的都是tensorflow.python.ops.variables.Variable类型的对象,但是他们的输入参数还有些不一样。
tf.Variable() | tf.get_variable() | |
---|---|---|
name | 不需要,已存在的变量名,会在后面加上递增的数值用来区分 | 必须,已存在的会报错 |
shape | 不需要,或者说已经包含在初值中了 | 需要 |
初值 | 必须 | 不需要 |
复用 | 不可以 | 可以 |
两种方法事实上都可以指定name和初值。而tf.Variable()的初值中已经包含了shape,所以不需要再显示传入shape了。这里的需要和不需要指的是必要不必要,如果没有传入需要的参数,就会报错,不需要的参数则不会影响。
tf.Variable()
一句话介绍
创建一个类操作全局变量。在TensorFlow内部,tf.Variable会存储持久性张量,允许各种op读取和修改它的值。这些修改在多个Session之间是可见的,因此对于一个tf.Variable,多个工作器可以看到相同的值。
和tf.Tensor对比
tf.Variable 表示可通过对其运行op来改变其值的张量。与 tf.Tensor对象不同,tf.Variable 存在于单个session.run调用的上下文之外。tf.Tensor的值是不可以改变的,tf.Tensor没有assign函数。
API
1 | tf.Variable.__init__( |
代码示例
1 | tensor1 = tf.Variable([[1,2], [3,5]]) |
初始化
tf.Variable()生成的变量必须初始化,tf.constant()可以不用初始化。
- 使用全局初始化
sess.run(tf.global_variables_initializer()) - 使用checkpoint
- 使用tf.assign赋值
tf.get_variable()
一句话介绍
获取一个已经存在的变量或者创建一个新的变量。主要目的,变量复用。
API
1 | tf.get_variable( |
代码示例
1 | with tf.variable_scope("model") as scope: |
Variable和collection
点击查看关于collecion的详细介绍
默认情况下,每个tf.Variable()都会添加到以下两个collection中:
- tf.GraphKeys.GLOBAL_VARIABLES - 可以在多台设备间共享的变量,
- tf.GraphKeys.TRAINABLE_VARIABLES - TensorFlow 将计算其梯度的变量。
如果不希望变量是可训练的,可以在创建时指定其collection为 tf.GraphKeys.LOCAL_VARIABLES collection中。
1 | my_local = tf.get_variable("my_local", shape=(), collections=[tf.GraphKeys.LOCAL_VARIABLES]) |
或者可以指定 trainable=False:
1 | my_non_trainable = tf.get_variable("my_non_trainable", |
获取collection
要检索放在某个collection中的所有变量的列表,可以使用:
代码示例
1 | import tensorflow as tf |
自定义collection
添加自定义collection
可以使用自定义的collection。collection名称可为任何字符串,且无需显式创建。创建对象(包括Variable和其他)后调用 tf.add_to_collection将其添加到相应collection中。以下代码将 my_local 变量添加到名为 my_collection_name 的collection中:
1 | tf.add_to_collection("my_collection_name", my_local) |
初始化变量
初始化所有变量
调用 tf.global_variables_initializer()在训练开始前一次性初始化所有可训练变量。此函数会返回一个op,负责初始化 tf.GraphKeys.GLOBAL_VARIABLES collection中的所有变量。运行此op会初始化所有变量。
1 | sess.run(tf.global_variables_initializer()) |
初始化单个变量
运行变量的初始化器op。
1 | sess.run(my_variable.initializer) |
查询未初始化变量
1 | print(sess.run(tf.report_uninitialized_variables())) |
共享变量
TensorFlow 支持两种共享变量的方式:
- 显式传递 tf.Variable 对象。
- 将 tf.Variable 对象隐式封装在 tf.variable_scope 对象内。
variable_scope
代码示例1
使用variable_scope区分weights和biases。
1 | def conv_relu(input, kernel_shape, bias_shape): |
代码示例2
使用variable_scope声明不同作用域
1 | def my_image_filter(input_images): |
共享方式1
设置reuse=True
1 | with tf.variable_scope("model"): |
共享方式2
调用scope.reuse_variables触发重用
1 | with tf.variable_scope("model") as scope: |
参考文献
1.https://blog.csdn.net/MrR1ght/article/details/81228087
2.https://www.tensorflow.org/guide/variables?hl=zh_cn
3.https://www.tensorflow.org/api_docs/python/tf/get_variable?hl=zh_cn