[TensorFlow] Save and Restore model

Once we have trained a model and want to apply it to new data for prediction, we need a way to save and load it. Here are two ways to do that.

Checkpoint file

The tf.train.Saver class generates .ckpt files that map variable names to their tensor values.
By default, each Variable is saved in the checkpoint under its name without the ':0' suffix (Variable.op.name).
Note that the graph structure is stored separately from the variable values.

The tf.train.Saver class adds ops to save and restore variables to and from checkpoints.
Checkpoints are binary files in a proprietary format that map variable names to tensor values.
The best way to examine the contents of a checkpoint is to load it using a Saver.

Savers can automatically number checkpoint filenames with a provided counter. This lets you keep multiple checkpoints at different steps while training a model. For example, you can number the checkpoint filenames with the training step number. To avoid filling up disks, savers manage checkpoint files automatically. For example, they can keep only the N most recent files, or one checkpoint for every N hours of training.
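To make the retention rule concrete, here is a small pure-Python simulation (no TensorFlow needed) of "keep only the N most recent" applied to step-numbered names. The helper simulate_checkpoint_rotation is hypothetical and only mimics the bookkeeping; in TensorFlow the policy is set through the max_to_keep and keep_checkpoint_every_n_hours arguments of tf.train.Saver, and the hours-based rule is not modelled here.

```python
from collections import deque


def simulate_checkpoint_rotation(steps, prefix="model.ckpt", max_to_keep=5):
    """Return the checkpoint names left after saving once per step.

    Only the "keep the N most recent" rule is simulated; each name is
    "<prefix>-<step>", the pattern Saver.save uses when given global_step.
    """
    kept = deque()
    for step in steps:
        kept.append("%s-%d" % (prefix, step))  # number the file with the step
        if len(kept) > max_to_keep:
            kept.popleft()  # delete the oldest checkpoint
    return list(kept)


print(simulate_checkpoint_rotation(range(10)))
# ['model.ckpt-5', 'model.ckpt-6', 'model.ckpt-7', 'model.ckpt-8', 'model.ckpt-9']
```

Saving once per step for steps 0 to 9 with the default of 5 leaves only the checkpoints for steps 5 to 9.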

Graph proto file

Use tf.train.write_graph to write the GraphDef into a *.pb binary file (or a *.pbtxt text file).
However, a GraphDef does not store the values of Variables, so in the sample code later we use tf.graph_util.convert_variables_to_constants to replace the variables in the graph with constants holding the same values.
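The idea behind this "freezing" step can be sketched without TensorFlow: treat a graph as a dictionary of nodes and replace every variable node with a constant node holding the variable's current value, after which the output can be computed from the graph alone. The node layout and the helpers freeze and evaluate below are hypothetical and far simpler than a real GraphDef.

```python
def freeze(graph, variable_values):
    """Copy `graph`, turning every Variable node into a Const with its value."""
    frozen = {}
    for name, node in graph.items():
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "value": variable_values[name]}
        else:
            frozen[name] = dict(node)  # other nodes are copied unchanged
    return frozen


def evaluate(graph, name):
    """Compute node `name` using only what is stored in the graph itself."""
    node = graph[name]
    if node["op"] == "Const":
        return node["value"]
    if node["op"] == "Add":
        x, y = (evaluate(graph, i) for i in node["inputs"])
        return x + y
    # A Variable's value lives outside the graph, so it cannot be computed here
    raise ValueError("cannot evaluate %s node %r from the graph alone" % (node["op"], name))


graph = {
    "a": {"op": "Variable"},
    "b": {"op": "Variable"},
    "out": {"op": "Add", "inputs": ["a", "b"]},
}
frozen = freeze(graph, {"a": 1.0, "b": 3.0})  # values a session would hold
print(evaluate(frozen, "out"))  # 4.0
```

Evaluating the original, unfrozen graph raises ValueError, which mirrors why a plain GraphDef is not enough to restore a trained model.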


Python 3.6.2
TensorFlow 1.5.0
matplotlib  2.1.2


Before we start

Here are some of TensorFlow’s APIs we will use later.

3.  tf.argmax

Checkpoint file

import os
import tensorflow as tf

ckpt_dir = "./ckpt_dir"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# Set a global step with trainable=False so that optimizers never update it
global_step = tf.Variable(0, name='global_step', trainable=False)

# Call this after declaring all tf.Variables you want to save.
saver = tf.train.Saver()

# This variable won't be stored, since it is declared after tf.train.Saver()
non_storable_variable = tf.Variable(777)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    start = global_step.eval()
    for i in range(start, 100):
        global_step.assign(i).eval()  # set and update global_step with index
        saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)


with tf.Session() as sess:
    # Initialize all variables first; restore() then overwrites the saved ones
    sess.run(tf.global_variables_initializer())

    # Load last train state
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)  # Returns CheckpointState from the "checkpoint" file
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # restore all saved variables
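The "checkpoint" file that tf.train.get_checkpoint_state reads is a small text file in ckpt_dir that lists the latest checkpoint and the ones still kept. The sketch below parses such text with the standard library to show what model_checkpoint_path and all_model_checkpoint_paths hold. parse_checkpoint_state and the sample text are hypothetical illustrations; in real code, keep using tf.train.get_checkpoint_state (or tf.train.latest_checkpoint).

```python
import re

# Matches lines such as:  model_checkpoint_path: "model.ckpt-9"
_LINE = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')


def parse_checkpoint_state(text):
    """Return (latest_path, all_paths) from checkpoint-state text."""
    latest, all_paths = None, []
    for line in text.splitlines():
        match = _LINE.match(line)
        if match is None:
            continue  # skip blank or unrecognised lines
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


# Hand-written example of what the file may contain after steps 0..9
sample = """model_checkpoint_path: "model.ckpt-9"
all_model_checkpoint_paths: "model.ckpt-5"
all_model_checkpoint_paths: "model.ckpt-6"
all_model_checkpoint_paths: "model.ckpt-7"
all_model_checkpoint_paths: "model.ckpt-8"
all_model_checkpoint_paths: "model.ckpt-9"
"""

latest, all_paths = parse_checkpoint_state(sample)
print(latest)          # model.ckpt-9
print(len(all_paths))  # 5
```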


I first ran only 10 steps:
start = global_step.eval()
for i in range(start, 10):
    # …

After this run, ckpt_dir contains a text file named checkpoint plus three kinds of files per checkpoint (.data, .index, .meta), and by default only the 5 most recent checkpoints are kept.
You can change the default with the max_to_keep argument:
saver = tf.train.Saver(max_to_keep=6)

The three kinds of checkpoint files are:

1.  meta file:
describes the saved graph structure, including the GraphDef, SaverDef, and so on. Calling tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the graph and returns a Saver built from the saved SaverDef.

2.  index file:
an immutable string-to-string table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains its content, the offset into that file, a checksum, some auxiliary data, etc.

3.  data file:
a TensorBundle collection that stores the values of all saved variables.
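The split between the index file and the data file can be illustrated with a toy version in plain Python: all tensor bytes are appended to one data blob, and an index maps each tensor name to the shard, offset, size and checksum needed to read it back. This is only an analogy under simplifying assumptions (float32 lists instead of typed tensors, a dict instead of tensorflow::table::Table, zlib.crc32 as a stand-in checksum); write_bundle and read_tensor are hypothetical helpers, not the real TensorBundle format.

```python
import struct
import zlib


def write_bundle(tensors):
    """Pack each float32 list into one data blob and build a name -> entry index."""
    data = bytearray()
    index = {}
    for name in sorted(tensors):  # iterate names in sorted order
        values = tensors[name]
        payload = struct.pack("<%df" % len(values), *values)
        index[name] = {
            "shard": 0,                    # which data file holds the bytes
            "offset": len(data),           # where this tensor starts
            "size": len(payload),          # how many bytes to read
            "crc32": zlib.crc32(payload),  # checksum to detect corruption
        }
        data += payload
    return index, bytes(data)


def read_tensor(index, data, name):
    """Look up `name` in the index, then slice and verify its bytes."""
    entry = index[name]
    payload = data[entry["offset"]:entry["offset"] + entry["size"]]
    if zlib.crc32(payload) != entry["crc32"]:
        raise ValueError("checksum mismatch for %r" % name)
    return list(struct.unpack("<%df" % (entry["size"] // 4), payload))


index, data = write_bundle({"global_step": [9.0], "a": [1.0, 2.0]})
print(read_tensor(index, data, "a"))  # [1.0, 2.0]
```

Reading a tensor never scans the whole data blob: the index alone says where its bytes are, which is the role the .index file plays for the .data files.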

After the first 10 steps, I updated the range as follows and ran the training again:

for i in range(start, 20):
    # …

Now the model resumes training from the step restored from the last checkpoint instead of starting over from step 0.
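The resume pattern itself does not depend on TensorFlow. In this hypothetical stdlib sketch, a JSON file stands in for the checkpoint and a next_step counter stands in for global_step; the function name train and the file name state.json are illustrative only.

```python
import json
import os
import tempfile


def train(state_path, end_step):
    """Run steps [start, end_step), resuming from state_path when it exists."""
    start = 0
    if os.path.exists(state_path):
        with open(state_path) as f:
            start = json.load(f)["next_step"]  # restore the saved counter
    ran = []
    for step in range(start, end_step):
        ran.append(step)  # stand-in for one real training step
        with open(state_path, "w") as f:
            json.dump({"next_step": step + 1}, f)  # "checkpoint" after every step
    return ran


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.json")
    first = train(path, 10)   # fresh start: steps 0..9
    second = train(path, 20)  # resumed run: steps 10..19
print(first[0], first[-1], second[0], second[-1])  # 0 9 10 19
```

Storing the index of the next step (step + 1) rather than the last finished step means a resumed run never repeats the step it already completed.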

Graph proto file

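If you only want to save the GraphDef from the previous example, call tf.train.write_graph. With the default as_text=True it writes a text file; pass as_text=False to write a binary file.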

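# Launch the graph in a session
with tf.Session() as sess:
    sess.run(train_op, feed_dict={})  # train_op: your model's training op
    # as_text defaults to True, so this writes a human-readable .pbtxt file
    tf.train.write_graph(sess.graph_def, '/tmp/tfmodel', 'train.pbtxt')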

Here is sample code that converts the variables to constants, saves the graph in a Graph proto file, and restores it.

import os
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.Variable([[1], [2]], dtype=tf.float32, name='a')
b = tf.Variable(3, dtype=tf.float32, name='b')

output = tf.add(a, b, name='out')  # the output tensor must have a name

graph_dir = "./graph_dir"
if not os.path.exists(graph_dir):
    os.makedirs(graph_dir)

# Save graph file
with tf.Session() as sess:
    # Variables must hold values before they can be frozen into constants
    sess.run(tf.global_variables_initializer())
    # Convert Variables to constants; "out" is the name of the output node
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, graph_dir, 'graph.pb', as_text=False)

# Restore graph file
with tf.Session() as sess:
    with tf.gfile.FastGFile(os.path.join(graph_dir, 'graph.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Returns a list holding one tensor per requested element
    output = tf.import_graph_def(graph_def, return_elements=['out:0'])
    print(sess.run(output))  # a list holding [[4.], [5.]]



