tf.train.Saver保存和恢复模型
1 | saver = tf.train.Saver() |
调用上述代码之后会存存储以下几个文件:
1 | checkpoint |
其中checkpoint文件存储的是最近保存的文件的名字,meta文件存放的是计算图的定义,index和data文件存放的是权重文件。
下面介绍一下上述代码中出现的两个API,tf.train.Saver()和tf.train.Saver().save()。
tf.train.Saver()
Saver是类,不是函数。可以用来保存,恢复variable和model,Saver对象提供save()和restore()等函数,save()保存模型,restore()加载模型。
1 | __init__( |
tf.train.Saver.save()
1 | save( |
这里说一下save_path,如果不指定的话,文件名默认是空的,在linux下是以.开头的(即当前目录),所以会显示成隐藏文件。通常情况下我们指定checkpoint要保存的路径,以及名字,比如叫model.ckpt,在load的时候还使用这个名字就行。指定了global_step之后,tf会自动在路径后面加上step进行区分。
读取graph
读取图的定义
meta文件中存放了计算图的定义,可以直接使用API tf.train.import_meta_graph()函数调用:
1 | import tensorflow as tf |
这时计算图就已经定义在当前sess中了。上述代码会保留原始的device信息,如果迁移到其他设备时,可能由于没有指定设备出错,这个问题可以通过指定一个特殊的参数clear_devices解决:
1 | import tensorflow as tf |
这样子就和device无关了。
访问graph中的参数
通过collection访问计算图中collection的键
这里的键指的是graph中都有哪些collections。
-
1
print(sess.graph.get_all_collection_keys())
-
1
print(sess.graph.collections)
-
1
tf.get_default_graph().get_all_collection_keys()
访问collection
-
1
sess.graph.get_collection("summaries")
-
1
tf.get_collection("")
示例
1 | import tensorflow as tf |
通过operation访问
-
1
sess.graph.get_opeartions()
-
1
2for op in sess.graph.get_opeartions():
print(op.name, op.values()) -
1
2
3
4
5
6
7
8
9
10
11
12
13sess.graph.get_operation_by_name("op_name").node_def
```
## 保存和恢复variables
### 保存和恢复全部variables
- 恢复variable时,无需初始化。
- 恢复variable时,使用的是variable的name,不是op的name。只要知道variable的name即可。save和restore的op name不需要相同,只要variable name相同即可。
- 对于使用tf.Variable()创建的variable,如果没有指定variable名字的话,系统会为其生成默认名字,在恢复的时候,需要使用tf.get_variable()恢复variable,同时传variable name和shape。
#### 保存全部variables
``` python
saver = tf.train.Saver()
saver.save(sess, save_path) # 需要指定的是checkpoint的名字而不是目录
恢复全部variables
1 | saver = tf.train.Saver() |
保存和恢复部分variables
保存全部variable
1 | saver = tf.train.Saver({"variable_name1": op_name1,..., "variable_namen": op_namen}) |
恢复全部variable
1 | saver = tf.train.Saver({"variable_name1": op_name1,..., "variable_namen": op_namen}) |
保存和恢复模型
其实和保存恢复变量没有什么区别。只是把整个模型的variables都save和restore了。
代码示例
1 | import tensorflow as tf |
获取最新的checkpoint文件
tf.train.get_checkpoint_state()
给出checkpoint文件所在目录,可以使用get_checkpoint_state()获得最新的checkpoint文件:
1 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) |
使用inspect_checkpoint库
1 | # import the inspect_checkpoint library |
模型的冻结
模型的冻结是不在训练模型,只用于正向推导,所以把变量转换成常量后,和计算图一起保存在协议缓冲区文件(.pb)文件中,因此需要在计算图中预先定义输出节点的名称,示例如下:
1 | import tensorflow as tf |
模型的执行
从协议缓冲区文件(.pb)文件中读取模型,导入计算图
1 | # 读取模型并保存到序列化模型对象中 |
获取输入和输出的张量,然后将测试数据feed给输入张量,得到结果。
1 | x_tensor = graph.get_tensor_by_name("Test/input/image-input:0") |
参考文献
1.https://www.jarvis73.cn/2018/04/25/Tensorflow-Model-Save-Read/
2.https://www.tensorflow.org/guide/saved_model
3.https://www.tensorflow.org/api_docs/python/tf/train/Saver
4.https://www.bilibili.com/read/cv681031/
5.https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/