tf.collection
Tensorflow用graph collection来管理不同类型的对象。tf.GraphKeys中定义了默认的collection,tf通过调用各种各样的collection操作graph中的变量。比如tf.Optimizer只优化tf.GraphKeys.TRAINABLE_VARIABLES collection中的变量。常见的collection如下,它们其实都是字符串:
- GLOBAL_VARIABLES: 所有的Variable对象在创建的时候自动加入该colllection,且在分布式环境中共享(model variables是它的子集)。一般来说,TRAINABLE_VARIABLES包含在MODEL_VARIABLES中,MODEL_VARIABLES包含在GLOBAL_VARIABLES中。也就是说TRAINABLE_VARIABLES$\le$MODEL_VARIABLES$\le$GLOBAL_VARIABLES。一般tf.train.Saver()对应的是GLOBAL_VARIABLES的变量。
- LOCAL_VARIABLES: 它是GLOBAL_VARIABLES不同的是在本机器上的Variable子集。使用tf.contrib.framework.local_variable将变量添加到这个collection.
- MODEL_VARIABLES: 模型变量,在构建模型中,所有用于前向传播的Variable都将添加到这里。使用 tf.contrib.framework.model_variable向这个collection添加变量。
- TRAINALBEL_VARIABLES: 所有用于反向传播的Variable,可以被optimizer训练,进行参数更新的变量。tf.Variable对象同样会自动加入这个collection。
- SUMMARIES: graph创建的所有summary Tensor都会记录在这里面。
- QUEUE_RUNNERS:
- MOVING_AVERAGE_VARIABLES: 保持Movering average的变量子集。
- REGULARIZATION_LOSSES: 创建graph的regularization loss。
这里主要介绍三类collection,一种是GLOBAL_VARIABLES,一种是SUMMARIES,一种是自定义的collections。
下面的一些collection也被定义了,但是并不会自动添加
The following standard keys are defined, but their collections are not automatically populated as many of the others are:
- WEIGHTS
- BIASES
- ACTIVATIONS
GLOBAL_Variable collection
tf.Variable()对象在生成时会被默认添加到tf.GraphKeys中的GLOBAL_VARIABLES和TRAINABLE_VARIABLES collection中。
代码示例
1 | import tensorflow as tf |
Summary collection
Summary op产生的变量会被添加到tf.GraphKeys.SUMMARIES collection中。
点击查看关于tf.summary的详细介绍
代码示例
1 | import tensorflow as tf |
自定义collection
通过tf.add_collection()和tf.get_collection()可以添加和访问custom collection。
示例代码
1 | import tensorflow as tf |
疑问
collection是和graph绑定在一起的,那么如果定义了很多个图,如何获得非默认图的tf.GraphKeys中定义的collection??
参考文献
1.https://blog.csdn.net/shenxiaolu1984/article/details/52815641
2.https://blog.csdn.net/hustqb/article/details/80398934
3.https://www.tensorflow.org/api_docs/python/tf/GraphKeys?hl=zh_cn