本文考察如何保存和恢复使用Estimators构建的TensorFlow模型。 TensorFlow提供了两种模型格式:
本文件重点介绍检查点。 有关SavedModel的详细信息,请参阅TensorFlow程序员指南的保存和恢复一章。
本文档依赖于TensorFlow入门中详细介绍的鸢尾花分类示例。 要下载和访问这个示例,请调用以下两个命令:
git clone https://github.com/tensorflow/models/
cd models/samples/core/get_started
本文档中的大部分代码片段都是premade_estimator.py
上的细微变体。
Estimator自动将以下内容写入磁盘:
要指定Estimator存储其信息的顶级目录,请将值分配给任何Estimator的构造函数的可选model_dir
参数。 例如,以下代码将model_dir
参数设置为models/iris
目录:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
假设你调用Estimator的train
方法。 例如:
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下图所示,第一次调用train
将检查点和其他文件添加到model_dir
目录中:
在基于UNIX的系统上,要查看创建的model_dir
目录中的对象,只需按如下方式调用ls
:
$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
前面的ls
命令显示,Estimator在步骤1(训练开始)和200(训练结束)创建了检查点。
If you don't specify model_dir
in an Estimator's constructor, the Estimator writes checkpoint files to a temporary directory chosen by Python's tempfile.mkdtemp function. For example, the following Estimator constructor does not specify the model_dir
argument:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
print(classifier.model_dir)
tempfile.mkdtemp
函数会为您的操作系统选择一个安全的临时目录。 对于样本,macOS上的典型临时目录可能如下所示:
/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa
默认情况下,Estimator会根据以下计划在model_dir
中保存检查点:
训练
方法开始(第一次迭代)并完成(最终迭代)时,写入一个检查点。你可以通过以下步骤改变默认时间表:
RunConfig
对象,它定义所需的时间表。RunConfig
对象传递给Estimator的config
参数。例如,以下代码将检查点时间表更改为每20分钟并保留最近的10个检查点:
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes.
keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
第一次调用Estimator的train
方法时,TensorFlow会将检查点保存到model_dir
。 随后每次调用Estimator的train
、evaluate
或predict
方法会导致以下情况:
model_fn()
来构建模型的图。 (有关model_fn()
的详细信息,请参阅 创建自定义估算器。)换句话说,如下图所示,一旦存在检查点,TensorFlow会在你每次调用train()
、evaluate()
或predict()
时重建模型。
从检查点恢复模型的状态只适用于模型和检查点是兼容的。 例如,假设你训练了一个包含两个隐藏层的DNNClassifier
Estimator,每个隐藏层有10个节点:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
训练完之后(因此,在models/iris
中创建检查点后),想象一下,将每个隐藏层中的神经元数量从10更改为20,然后尝试重新训练模型:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
由于检查点中的状态与classifier2
中描述的模型不兼容,再次训练失败,并显示以下错误:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
要运行实验并对比稍微不同的模型版本,可以通过为每个版本创建单独的git分支来保存创建每个model_dir
的代码的副本。 这种分离将保持你的检查点可恢复。
检查点提供了一个简单的自动机制来保存和恢复由Estimators创建的模型。
有关详细信息,请参阅TensorFlow程序员指南的保存和恢复一章。