检查点

本文考察如何保存和恢复使用Estimators构建的TensorFlow模型。 TensorFlow提供了两种模型格式:

本文件重点介绍检查点。 有关SavedModel的详细信息,请参阅TensorFlow程序员指南保存和恢复一章。

示例代码

本文档依赖于TensorFlow入门中详细介绍的鸢尾花分类示例 要下载和访问这个示例,请调用以下两个命令:

git clone https://github.com/tensorflow/models/
cd models/samples/core/get_started

本文档中的大部分代码片段都是premade_estimator.py上的细微变体。

保存部分训练的模型

Estimator自动将以下内容写入磁盘:

要指定Estimator存储其信息的顶级目录,请将值分配给任何Estimator的构造函数的可选model_dir参数。 例如,以下代码将model_dir参数设置为models/iris目录:

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

假设你调用Estimator的train方法。 例如:

classifier.train(
        input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
                steps=200)

如下图所示,第一次调用train将检查点和其他文件添加到model_dir目录中:

第一次调用train()。

在基于UNIX的系统上,要查看创建的model_dir目录中的对象,只需按如下方式调用ls

$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta

前面的ls命令显示,Estimator在步骤1(训练开始)和200(训练结束)创建了检查点。

默认检查点目录

If you don't specify model_dir in an Estimator's constructor, the Estimator writes checkpoint files to a temporary directory chosen by Python's tempfile.mkdtemp function. For example, the following Estimator constructor does not specify the model_dir argument:

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

print(classifier.model_dir)

tempfile.mkdtemp函数会为您的操作系统选择一个安全的临时目录。 对于样本,macOS上的典型临时目录可能如下所示:

/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa

检查点频率

默认情况下,Estimator会根据以下计划在model_dir中保存检查点

你可以通过以下步骤改变默认时间表:

  1. 创建一个RunConfig对象,它定义所需的时间表。
  2. 在实例化Estimator时,将该RunConfig对象传递给Estimator的config参数。

例如,以下代码将检查点时间表更改为每20分钟并保留最近的10个检查点:

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

恢复你的模型

第一次调用Estimator的train方法时,TensorFlow会将检查点保存到model_dir 随后每次调用Estimator的trainevaluatepredict方法会导致以下情况:

  1. Estimator通过运行​​model_fn()来构建模型的 (有关model_fn()的详细信息,请参阅 创建自定义估算器。)
  2. Estimator根据最近检查点中存储的数据初始化新模型的权重。

换句话说,如下图所示,一旦存在检查点,TensorFlow会在你每次调用train()evaluate()predict()时重建模型。

接下来train()、evaluate()或predict()的调用

避免不良的恢复

从检查点恢复模型的状态只适用于模型和检查点是兼容的。 例如,假设你训练了一个包含两个隐藏层的DNNClassifier Estimator,每个隐藏层有10个节点:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

训练完之后(因此,在models/iris中创建检查点后),想象一下,将每个隐藏层中的神经元数量从10更改为20,然后尝试重新训练模型:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由于检查点中的状态与classifier2中描述的模型不兼容,再次训练失败,并显示以下错误:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

要运行实验并对比稍微不同的模型版本,可以通过为每个版本创建单独的git分支来保存创建每个model_dir的代码的副本。 这种分离将保持你的检查点可恢复。

总结

检查点提供了一个简单的自动机制来保存和恢复由Estimators创建的模型。

有关详细信息,请参阅TensorFlow程序员指南保存和恢复一章。