tf.data
API使你能够从简单的、可重用的碎片构建复杂的输入管道。 例如,图像模型的管道可能会聚合分布式文件系统中的文件中的数据、将随机扰动应用于每个图像、并将随机选择的图像合并为一个批次以进行训练。 文本模型的管道可能从原始文本数据中提取符号、将它们转换为带有查找表的嵌入标识符、以及将不同长度的序列一起进行批处理。 tf.data
API可以轻松处理大量数据、不同数据格式和复杂的转换。
tf.data
API为TensorFlow引入了两个新的抽象:
一个tf.data.Dataset
表示一系列元素,其中每个元素包含一个或多个张量
对象。 例如,在图像管道中,一个元素可能是一个单一的训练样本,它包含一对张量表示图像数据和一个标签。 有两种不同的方法来创建数据集:
从一个或多个tf.Tensor
对象构建数据集创建一个源(例如Dataset.from_tensor_slices()
)。
对一个或多个tf.data.Dataset
对象应用变换(例如Dataset.batch()
)构造一个数据集。
一个tf.data.Iterator
提供从数据集中提取元素的主要方法。 由Iterator.get_next()
返回的操作在执行时产生Dataset
的下一个元素,并且通常充当输入管道代码和模型之间的接口。 最简单的迭代器是一个“单次迭代器”,它与一个特定的数据集
相关联并遍历一次。 对于更复杂的用途,使用Iterator.initializer
操作可以使用不同的数据集重新初始化和参数化迭代器,这样你就可以在同一个程序中多次遍历训练和验证数据。
本指南的这一部分描述了创建不同类型的Dataset
和Iterator
对象的基本原理,以及如何从中提取数据。
要启动一个输入管道,你必须定义一个源。 对于样本,要从内存中的某些张量构造Dataset
,可以使用tf.data.Dataset.from_tensors()
或tf.data.Dataset.from_tensor_slices()
。 或者,如果你的输入数据以建议的TFRecord格式存储在磁盘上,你可以构建一个tf.data.TFRecordDataset
。
在你拥有一个Dataset
对象之后,你可以通过在tf.data.Dataset
上链接地调用方法,将它转换成新的Dataset
。 例如,你可以应用每个元素的转换如Dataset.map()
(将函数应用于每个元素)和多个元素的转换,如Dataset.batch()
。 有关转换的完整列表,请参阅tf.data.Dataset
的文档。
从Dataset
中使用值的最常见方法是创建一个迭代器对象,该对象可以一次访问数据集的一个元素(例如,通过调用Dataset.make_one_shot_iterator()
)。 tf.data.Iterator
提供两个操作:Iterator.initializer
,它使你能够(重新)初始化迭代器的状态;Iterator.get_next()
,它返回下一个符号元素对应的tf.Tensor
对象。 根据你的使用情况,你可能会选择不同类型的迭代器,下文有列出具体的选项。
数据集包含的每个元素具有相同的结构。 一个元素包含一个或多个tf.Tensor
对象,称为组成部分。 每个组成部分都有一个表示张量中元素类型的tf.DType
,以及一个表示每个元素的静态形状的tf.TensorShape
(可能部分指定)。 Dataset.output_types
和Dataset.output_shapes
属性允许你检查数据集元素的每个组成部分的推断类型和形状。 这些属性的嵌套结构映射到元素的结构,该元素可以是单个张量、张量元组或张量的嵌套元组。 例如:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
通常给原始的每个组成部分起一个名字会带来方便,例如如果它们表示训练样本的不同特征。 除了元组之外,您还可以使用collections.namedtuple
或将字符串映射到张量以表示Dataset
的单个元素。
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
Dataset
的转换支持任何结构的数据集。 在使用Dataset.map()
,Dataset.flat_map()
和Dataset.filter()
转换时,它们将函数应用于每个元素,元素结构决定函数的参数:
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
一旦你建立了Dataset
来表示你的输入数据,下一步就是创建一个迭代器
来访问该数据集中的元素。 目前,tf.data
API支持以下迭代器,其级别越来越高:
A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset, with no need for explicit initialization. 一次迭代器处理几乎所有现有的基于队列的输入流水线支持的情况,但它们不支持参数化。 使用Dataset.range()
的样本:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
注意:目前,一次迭代器是唯一可以轻松用于Estimator
的类型。
一个可初始化的迭代器需要你在使用它之前运行一个显式的iterator.initializer
操作。 In exchange for this inconvenience, it enables you to parameterize the definition of the dataset, using one or more tf.placeholder()
tensors that can be fed when you initialize the iterator. 继续Dataset.range()
样本:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
一个可重新初始化的迭代器可以从多个不同的Dataset
对象初始化。 对于样本,你可能有一个训练输入管道,它使用输入图像的随机扰动来改进泛化,以及验证输入管道,用于评估对未修改数据的预测。 这些管道通常使用具有相同结构的不同Dataset
对象(即每个组件具有相同类型和兼容形状)。
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
A feedable iterator can be used together with tf.placeholder
to select what Iterator
to use in each call to tf.Session.run
, via the familiar feed_dict
mechanism. 它提供了与可重新初始化的迭代器相同的功能,但当您在迭代器之间切换时,它不需要从数据集的开头初始化迭代器。 对于样本,使用上面的训练和验证样本,你可以使用tf.data.Iterator.from_string_handle
来定义一个可馈送迭代器,它允许你在两个数据集之间切换:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
The Iterator.get_next()
method returns one or more tf.Tensor
objects that correspond to the symbolic next element of an iterator. 每次评估这些张量时,它们都会获取底层数据集中下一个元素的值。 (请注意,与TensorFlow中的其他有状态对象一样,调用Iterator.get_next()
不会立即推进迭代器。 相反,您必须使用TensorFlow表达式中返回的tf.Tensor
对象,并将该表达式的结果传递给tf.Session.run()
以获取下一个元素并推进迭代器。)
If the iterator reaches the end of the dataset, executing the Iterator.get_next()
operation will raise a tf.errors.OutOfRangeError
. 在这之后,迭代器将处于不可用状态,如果你想进一步使用它,你必须重新初始化它。
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
一个常见的模式是将“训练循环”封装在try
- 除
块之外:
sess.run(iterator.initializer)
while True:
try:
sess.run(result)
except tf.errors.OutOfRangeError:
break
If each element of the dataset has a nested structure, the return value of Iterator.get_next()
will be one or more tf.Tensor
objects in the same nested structure:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
注意next1
,next2
和next3
是由同一个op /节点产生的张量(由Iterator.get_next / T3>)。
Therefore, evaluating any of these tensors will advance the iterator for all components. 迭代器的典型使用者将在一个表达式中包含所有组件。
tf.contrib.data.make_saveable_from_iterator
函数从一个迭代器中创建一个SaveableObject
,它可以用来保存和恢复迭代器的当前状态(实际上,整个输入管道)。 这样创建的可保存对象可以添加到tf.train.Saver
变量列表或tf.GraphKeys.SAVEABLE_OBJECTS
集合中,以保存和恢复方式与tf.Variable
有关如何保存和恢复变量的详细信息,请参阅保存和恢复。
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
if should_checkpoint:
saver.save(path_to_checkpoint)
# Restore the iterator state.
with tf.Session() as sess:
saver.restore(sess, path_to_checkpoint)
If all of your input data fit in memory, the simplest way to create a Dataset
from them is to convert them to tf.Tensor
objects and use Dataset.from_tensor_slices()
.
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
请注意,上面的代码片段会将特征
和标签
数组作为tf.constant()
操作嵌入到TensorFlow图中。 这适用于小数据集,但浪费内存---因为数组的内容将被复制多次---并可以运行到tf.GraphDef
协议缓冲区的2GB限制内。
As an alternative, you can define the Dataset
in terms of tf.placeholder()
tensors, and feed the NumPy arrays when you initialize an Iterator
over the dataset.
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
tf.data
API支持多种文件格式,以便您可以处理不适合内存的大型数据集。 对于样本,TFRecord文件格式是一种简单的面向记录的二进制格式,许多TensorFlow应用程序用于训练数据。 使用tf.data.TFRecordDataset
类可以将一个或多个TFRecord文件的内容作为输入管道的一部分进行流式处理。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
TFRecordDataset
初始值设定项的文件名
参数既可以是字符串,也可以是字符串列表或tf.Tensor
字符串。 因此,如果您有两组文件用于训练和验证目的,您可以使用tf.placeholder(tf.string)
来表示文件名,并从适当的文件名初始化一个迭代器:
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
许多数据集都是作为一个或多个文本文件分发的。 tf.data.TextLineDataset
提供了一种从一个或多个文本文件中提取线条的简单方法。 给定一个或多个文件名,一个TextLineDataset
将为这些文件的每行生成一个字符串值元素。 Like a TFRecordDataset
, TextLineDataset
accepts filenames
as a tf.Tensor
, so you can parameterize it by passing a tf.placeholder(tf.string)
.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
By default, a TextLineDataset
yields every line of each file, which may not be desirable, for example if the file starts with a header line, or contains comments. 这些行可以使用Dataset.skip()
和Dataset.filter()
转换来删除。 要将这些转换分别应用于每个文件,我们使用Dataset.flat_map()
为每个文件创建一个嵌套的数据集
。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
Dataset.map()
预处理数据通过将给定函数f
应用于输入数据集的每个元素,Dataset.map(f)
变换产生一个新数据集。 它基于通常应用于函数式编程语言中的列表(和其他结构)的map()
函数。 函数f
接受表示输入中单个元素的tf.Tensor
对象,并返回tf.Tensor
对象,该对象将表示新数据集中的单个元素。 其实现使用标准的TensorFlow操作将一个元素转换为另一个元素。
本节介绍如何使用Dataset.map()
的常见样本。
tf.Example
协议缓冲区消息许多输入管道从TFRecord格式文件中提取tf.train.Example
协议缓冲区消息(使用tf.python_io.TFRecordWriter
写入)。 每个tf.train.Example
记录包含一个或多个“特征”,输入管道通常将这些特征转换为张量。
# 将一个标量字符串`example_proto`转换为一对标量字符串和
# 一个标量整数,分别表示一个图片和它的标签。
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
# 创建一个数据集,从两个文件中读取所有的样本,并抽取
# 图片和标签特征。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
当在真实世界的图像数据上训练神经网络时,经常需要将不同大小的图像转换成通用大小,以便它们可以批量化为固定大小。
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
tf.py_func()
应用任意Python逻辑出于性能原因,我们鼓励您尽可能使用TensorFlow操作预处理数据。 但是,在解析输入数据时,调用外部Python库有时很有用。 为此,请在Dataset.map()
转换中调用tf.py_func()
操作。
import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)
The simplest form of batching stacks n
consecutive elements of a dataset into a single element. The Dataset.batch()
transformation does exactly this, with the same constraints as the tf.stack()
operator, applied to each component of the elements: i.e. for each component i, all elements must have a tensor of the exact same shape.
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
上述配方适用于所有尺寸相同的张量。 然而,许多模型(例如序列模型)与可能具有不同大小的输入数据(例如,不同长度的序列)一起工作。 为了处理这种情况,通过Dataset.padded_batch()
转换,您可以通过指定填充它们的一个或多个尺寸来批量处理不同形状的张量。
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]
在 Dataset.padded_batch() T0>转换允许你设置不同的填充每个组件的每个维度,并且它可以是可变长度(由
也可以重写填充值,该值默认为0。无 T1>在样本上面的表示)或恒定长度。
tf.data
API提供了两种处理同一数据的多个时期的主要方式。
在多个时期迭代数据集的最简单方法是使用Dataset.repeat()
转换。 对于样本,要创建一个重复10个时期输入的数据集:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
应用不带参数的Dataset.repeat()
转换将无限期地重复输入。 Dataset.repeat()
转换将其参数连接起来,而不用指示一个周期的结束和下一个周期的开始。
如果你想在每个周期结束时收到一个信号,你可以编写一个训练循环,捕捉数据集末尾的tf.errors.OutOfRangeError
。 那时你可能会收集周期的一些统计数据(例如验证错误)。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# 计算100个周期。
for _ in range(100):
sess.run(iterator.initializer)
while True:
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
break
# [Perform end-of-epoch calculations here.]
在 Dataset.shuffle() T0>变换随机洗牌使用类似的算法应用于所述输入数据集
tf.RandomShuffleQueue
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
tf.train.MonitoredTrainingSession
API简化了在分布式设置中运行TensorFlow的许多方面。 MonitoredTrainingSession
uses the tf.errors.OutOfRangeError
to signal that training has completed, so to use it with the tf.data
API, we recommend using Dataset.make_one_shot_iterator()
. 例如:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:
while not sess.should_stop():
sess.run(training_op)
要在tf.estimator.Estimator
的input_fn
中使用Dataset
,我们还建议使用Dataset.make_one_shot_iterator() T4>。
例如:
def dataset_input_fn():
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
keys_to_features = {
"image_data": tf.FixedLenFeature((), tf.string, default_value=""),
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
# Perform additional preprocessing on the parsed data.
image = tf.image.decode_jpeg(parsed["image_data"])
image = tf.reshape(image, [299, 299, 1])
label = tf.cast(parsed["label"], tf.int32)
return {"image_data": image, "date_time": parsed["date_time"]}, label
# Use `Dataset.map()` to build a pair of a feature dictionary and a label
# tensor for each example.
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
# `features` is a dictionary in which each value is a batch of values for
# that feature; `labels` is a batch of labels.
features, labels = iterator.get_next()
return features, labels