TensorFlow
Python Save and Restore
▌Introduction
Once we have trained a model and would like to apply it to prediction data, there are two ways to save and load it.
▋Checkpoint file
By default, Variables are saved to a checkpoint file under their Variable.name. Note that the checkpoint stores the graph structure separately from the variable values. Checkpoints are binary files in a proprietary format that map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it with a Saver.
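You can also peek at a checkpoint's variable names and values without rebuilding the graph. Here is a minimal sketch using tf.train.NewCheckpointReader; the checkpoint path ./ckpt_dir/model.ckpt-9 is a hypothetical example.
import tensorflow as tf
# Hypothetical checkpoint path; use one of your own saved checkpoints.
reader = tf.train.NewCheckpointReader('./ckpt_dir/model.ckpt-9')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)              # variable name and its shape
    print(reader.get_tensor(name))  # the stored value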
Savers can automatically number checkpoint filenames with a provided counter. This lets you keep multiple checkpoints at different steps while training a model. For example, you can number the checkpoint filenames with the training step number. To avoid filling up disks, savers manage checkpoint files automatically. For example, they can keep only the N most recent files, or one checkpoint for every N hours of training.
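For example, a Saver configured like the following sketch keeps the 5 most recent checkpoints plus one for every 2 hours of training (the specific values here are just an illustration):
import tensorflow as tf
# Keep the 5 most recent checkpoints, plus one every 2 hours of training.
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)
# Passing global_step appends the step number to the filename,
# e.g. model.ckpt-0, model.ckpt-1, ...
# saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)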
▋Graph proto file
However, a GraphDef cannot save Variables, so in the later sample code we will use tf.graph_util.convert_variables_to_constants to replace the variables in a graph with constants of the same values.
▌Environment
▋Python 3.6.2
▋TensorFlow 1.5.0
▋matplotlib 2.1.2
▌Implementation
▋Before we start
Here are some of the TensorFlow APIs we will use later:
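- tf.train.Saver / Saver.save / Saver.restore: save variables to, and restore them from, checkpoint files
- tf.train.get_checkpoint_state: read the latest checkpoint state from the "checkpoint" file
- tf.train.write_graph: write a GraphDef to a graph proto file
- tf.graph_util.convert_variables_to_constants: replace the variables in a graph with constants of the same values
- tf.import_graph_def: load a GraphDef into the current graph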
▋Checkpoint file
▋Save
import tensorflow as tf
import os

ckpt_dir = "./ckpt_dir"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# Set a global step with trainable=False
global_step = tf.Variable(0, name='global_step', trainable=False)

# Call this after declaring all tf.Variables.
saver = tf.train.Saver()

# This variable won't be stored, since it is declared after tf.train.Saver()
non_storable_variable = tf.Variable(777)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    start = global_step.eval()  # resume from the last saved step
    for i in range(start, 100):
        # set and update global_step with the index i
        global_step.assign(i).eval()
        saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)
▋Restore
with tf.Session() as sess:
    # You need to initialize all variables first
    tf.global_variables_initializer().run()

    # Load the last training state:
    # returns CheckpointState proto from the "checkpoint" file
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # restore all variables
▋Demo
When I ran 10 steps,

start = global_step.eval()
for i in range(start, 10):
    # …

the output was like the following.
TensorFlow generated three kinds of files (.data, .index, .meta), and by default only the latest 5 checkpoints are kept. You can change the default max_to_keep number like this:

saver = tf.train.Saver(max_to_keep=6)
The three kinds of checkpoint files are:
1. meta file: describes the saved graph structure and includes GraphDef, SaverDef, and so on; applying tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the Saver and Graph (see the sketch after this list).
2. index file: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of a tensor, the offset into that file, checksum, some auxiliary data, etc.
3. data file: a TensorBundle collection that saves the values of all variables.
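For example, here is a minimal sketch of restoring both the graph and the variable values from the meta file alone, reusing the /tmp/model.ckpt path from above:
import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the graph from the meta file; import_meta_graph returns a Saver
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    # Fill in the saved variable values
    saver.restore(sess, '/tmp/model.ckpt')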
After the first 10 steps, I updated the range as follows and ran the training again.

for i in range(start, 20):
    # …

Now the model starts training from step 10.
▋Graph proto file
# Launch the graph in a session
with tf.Session() as sess:
    sess.run(train_op, feed_dict={…})
    tf.train.write_graph(sess.graph_def, '/tmp/tfmodel', 'train.pbtxt')
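Note that tf.train.write_graph writes a human-readable text proto by default (as_text=True). Here is a minimal sketch of reading it back, assuming the train.pbtxt written above:
import tensorflow as tf
from google.protobuf import text_format

graph_def = tf.GraphDef()
# Parse the text-format GraphDef written by tf.train.write_graph
with open('/tmp/tfmodel/train.pbtxt') as f:
    text_format.Merge(f.read(), graph_def)
tf.import_graph_def(graph_def)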
Here is sample code for saving and restoring constants in a graph proto file.
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import os

a = tf.Variable([[1], [2]], dtype=tf.float32, name='a')
b = tf.Variable(3, dtype=tf.float32, name='b')
output = tf.add(a, b, name='out')  # Tensor must have a name

graph_dir = "./graph_dir"
if not os.path.exists(graph_dir):
    os.makedirs(graph_dir)

# Save graph file
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Convert Variables to constants; "out" is the name of the tensor
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, graph_dir, 'graph.pb', as_text=False)

# Restore graph file
with tf.Session() as sess:
    with tf.gfile.FastGFile(os.path.join(graph_dir, 'graph.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))
Output:

[array([[4.],
       [5.]], dtype=float32)]