导入数据

tf.data API使你能够从简单的、可重用的碎片构建复杂的输入管道。 例如,图像模型的管道可能会聚合分布式文件系统中的文件中的数据、将随机扰动应用于每个图像、并将随机选择的图像合并为一个批次以进行训练。 文本模型的管道可能从原始文本数据中提取符号、将它们转换为带有查找表的嵌入标识符、以及将不同长度的序列一起进行批处理。 tf.data API可以轻松处理大量数据、不同数据格式和复杂的转换。

tf.data API为TensorFlow引入了两个新的抽象:

基本的机制

本指南的这一部分描述了创建不同类型的DatasetIterator对象的基本原理,以及如何从中提取数据。

要启动一个输入管道,你必须定义一个 对于样本,要从内存中的某些张量构造Dataset,可以使用tf.data.Dataset.from_tensors()tf.data.Dataset.from_tensor_slices() 或者,如果你的输入数据以建议的TFRecord格式存储在磁盘上,你可以构建一个tf.data.TFRecordDataset

在你拥有一个Dataset对象之后,你可以通过在tf.data.Dataset上链接地调用方法,将它转换成新的Dataset 例如,你可以应用每个元素的转换如Dataset.map()(将函数应用于每个元素)和多个元素的转换,如Dataset.batch() 有关转换的完整列表,请参阅tf.data.Dataset的文档。

Dataset中使用值的最常见方法是创建一个迭代器对象,该对象可以一次访问数据集的一个元素(例如,通过调用Dataset.make_one_shot_iterator())。 tf.data.Iterator提供两个操作:Iterator.initializer,它使你能够(重新)初始化迭代器的状态;Iterator.get_next(),它返回下一个符号元素对应的tf.Tensor对象。 根据你的使用情况,你可能会选择不同类型的迭代器,下文有列出具体的选项。

数据集的结构

数据集包含的每个元素具有相同的结构。 一个元素包含一个或多个tf.Tensor对象,称为组成部分 每个组成部分都有一个表示张量中元素类型的tf.DType,以及一个表示每个元素的静态形状的tf.TensorShape(可能部分指定)。 Dataset.output_typesDataset.output_shapes属性允许你检查数据集元素的每个组成部分的推断类型和形状。 这些属性的嵌套结构映​​射到元素的结构,该元素可以是单个张量、张量元组或张量的嵌套元组。 例如:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random_uniform([4]),
    tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"

通常给原始的每个组成部分起一个名字会带来方便,例如如果它们表示训练样本的不同特征。 除了元组之外,您还可以使用collections.namedtuple或将字符串映射到张量以表示Dataset的单个元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

Dataset的转换支持任何结构的数据集。 在使用Dataset.map()Dataset.flat_map()Dataset.filter()转换时,它们将函数应用于每个元素,元素结构决定函数的参数:

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

创建一个迭代器

一旦你建立了Dataset来表示你的输入数据,下一步就是创建一个迭代器来访问该数据集中的元素。 目前,tf.data API支持以下迭代器,其级别越来越高:

A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset, with no need for explicit initialization. 一次迭代器处理几乎所有现有的基于队列的输入流水线支持的情况,但它们不支持参数化。 使用Dataset.range()的样本:

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

注意:目前,一次迭代器是唯一可以轻松用于Estimator的类型。

一个可初始化的迭代器需要你在使用它之前运行一个显式的iterator.initializer操作。 In exchange for this inconvenience, it enables you to parameterize the definition of the dataset, using one or more tf.placeholder() tensors that can be fed when you initialize the iterator. 继续Dataset.range()样本:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

一个可重新初始化的迭代器可以从多个不同的Dataset对象初始化。 对于样本,你可能有一个训练输入管道,它使用输入图像的随机扰动来改进泛化,以及验证输入管道,用于评估对未修改数据的预测。 这些管道通常使用具有相同结构的不同Dataset对象(即每个组件具有相同类型和兼容形状)。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

A feedable iterator can be used together with tf.placeholder to select what Iterator to use in each call to tf.Session.run, via the familiar feed_dict mechanism. 它提供了与可重新初始化的迭代器相同的功能,但当您在迭代器之间切换时,它不需要从数据集的开头初始化迭代器。 对于样本,使用上面的训练和验证样本,你可以使用tf.data.Iterator.from_string_handle来定义一个可馈送迭代器,它允许你在两个数据集之间切换:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})

从迭代器中消费值

The Iterator.get_next() method returns one or more tf.Tensor objects that correspond to the symbolic next element of an iterator. 每次评估这些张量时,它们都会获取底层数据集中下一个元素的值。 (请注意,与TensorFlow中的其他有状态对象一样,调用Iterator.get_next()不会立即推进迭代器。 相反,您必须使用TensorFlow表达式中返回的tf.Tensor对象,并将该表达式的结果传递给tf.Session.run()以获取下一个元素并推进迭代器。)

If the iterator reaches the end of the dataset, executing the Iterator.get_next() operation will raise a tf.errors.OutOfRangeError. 在这之后,迭代器将处于不可用状态,如果你想进一步使用它,你必须重新初始化它。

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)

sess.run(iterator.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

一个常见的模式是将“训练循环”封装在try - 块之外:

sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break

If each element of the dataset has a nested structure, the return value of Iterator.get_next() will be one or more tf.Tensor objects in the same nested structure:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()

注意next1next2next3是由同一个op /节点产生的张量(由Iterator.get_next / T3>)。 Therefore, evaluating any of these tensors will advance the iterator for all components. 迭代器的典型使用者将在一个表达式中包含所有组件。

保存迭代器状态

tf.contrib.data.make_saveable_from_iterator函数从一个迭代器中创建一个SaveableObject,它可以用来保存和恢复迭代器的当前状态(实际上,整个输入管道)。 这样创建的可保存对象可以添加到tf.train.Saver变量​​列表或tf.GraphKeys.SAVEABLE_OBJECTS集合中,以保存和恢复方式与tf.Variable 有关如何保存和恢复变量的详细信息,请参阅保存和恢复

# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)

# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:

  if should_checkpoint:
    saver.save(path_to_checkpoint)

# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)

读取输入数据

消费NumPy数组

If all of your input data fit in memory, the simplest way to create a Dataset from them is to convert them to tf.Tensor objects and use Dataset.from_tensor_slices().

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

请注意,上面的代码片段会将特征标签数组作为tf.constant()操作嵌入到TensorFlow图中。 这适用于小数据集,但浪费内存---因为数组的内容将被复制多次---并可以运行到tf.GraphDef协议缓冲区的2GB限制内。

As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

消费TFRecord数据

tf.data API支持多种文件格式,以便您可以处理不适合内存的大型数据集。 对于样本,TFRecord文件格式是一种简单的面向记录的二进制格式,许多TensorFlow应用程序用于训练数据。 使用tf.data.TFRecordDataset类可以将一个或多个TFRecord文件的内容作为输入管道的一部分进行流式处理。

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset初始值设定项的文件名参数既可以是字符串,也可以是字符串列表或tf.Tensor字符串。 因此,如果您有两组文件用于训练和验证目的,您可以使用tf.placeholder(tf.string)来表示文件名,并从适当的文件名初始化一个迭代器:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

消费文本数据

许多数据集都是作为一个或多个文本文件分发的。 tf.data.TextLineDataset提供了一种从一个或多个文本文件中提取线条的简单方法。 给定一个或多个文件名,一个TextLineDataset将为这些文件的每行生成一个字符串值元素。 Like a TFRecordDataset, TextLineDataset accepts filenames as a tf.Tensor, so you can parameterize it by passing a tf.placeholder(tf.string).

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)

By default, a TextLineDataset yields every line of each file, which may not be desirable, for example if the file starts with a header line, or contains comments. 这些行可以使用Dataset.skip()Dataset.filter()转换来删除。 要将这些转换分别应用于每个文件,我们使用Dataset.flat_map()为每个文件创建一个嵌套的数据集

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.data.Dataset.from_tensor_slices(filenames)

# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
    lambda filename: (
        tf.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))

使用Dataset.map()预处理数据

通过将给定函数f应用于输入数据集的每个元素,Dataset.map(f)变换产生一个新数据集。 它基于通常应用于函数式编程语言中的列表(和其他结构)的map()函数 函数f接受表示输入中单个元素的tf.Tensor对象,并返回tf.Tensor对象,该对象将表示新数据集中的单个元素。 其实现使用标准的TensorFlow操作将一个元素转换为另一个元素。

本节介绍如何使用Dataset.map()的常见样本。

解析tf.Example协议缓冲区消息

许多输入管道从TFRecord格式文件中提取tf.train.Example协议缓冲区消息(使用tf.python_io.TFRecordWriter写入)。 每个tf.train.Example记录包含一个或多个“特征”,输入管道通常将这些特征转换为张量。

# 将一个标量字符串`example_proto`转换为一对标量字符串和
# 一个标量整数,分别表示一个图片和它的标签。
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

# 创建一个数据集,从两个文件中读取所有的样本,并抽取
# 图片和标签特征。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

解码图像数据并调整其大小

当在真实世界的图像数据上训练神经网络时,经常需要将不同大小的图像转换成通用大小,以便它们可以批量化为固定大小。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

使用tf.py_func()应用任意Python逻辑

出于性能原因,我们鼓励您尽可能使用TensorFlow操作预处理数据。 但是,在解析输入数据时,调用外部Python库有时很有用。 为此,请在Dataset.map()转换中调用tf.py_func()操作。

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tuple(tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

批处理数据集元素

简单的批处理

The simplest form of batching stacks n consecutive elements of a dataset into a single element. The Dataset.batch() transformation does exactly this, with the same constraints as the tf.stack() operator, applied to each component of the elements: i.e. for each component i, all elements must have a tensor of the exact same shape.

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

带有填充的批处理张量

上述配方适用于所有尺寸相同的张量。 然而,许多模型(例如序列模型)与可能具有不同大小的输入数据(例如,不同长度的序列)一起工作。 为了处理这种情况,通过Dataset.padded_batch()转换,您可以通过指定填充它们的一个或多个尺寸来批量处理不同形状的张量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],
                               #      [5, 5, 5, 5, 5, 0, 0],
                               #      [6, 6, 6, 6, 6, 6, 0],
                               #      [7, 7, 7, 7, 7, 7, 7]]

Dataset.padded_batch() T0>转换允许你设置不同的填充每个组件的每个维度,并且它可以是可变长度(由无 T1>在样本上面的表示)或恒定长度。 也可以重写填充值,该值默认为0。

训练工作流程

处理多个周期

tf.data API提供了两种处理同一数据的多个时期的主要方式。

在多个时期迭代数据集的最简单方法是使用Dataset.repeat()转换。 对于样本,要创建一个重复10个时期输入的数据集:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

应用不带参数的Dataset.repeat()转换将无限期地重复输入。 Dataset.repeat()转换将其参数连接起来,而不用指示一个周期的结束和下一个周期的开始。

如果你想在每个周期结束时收到一个信号,你可以编写一个训练循环,捕捉数据集末尾的tf.errors.OutOfRangeError 那时你可能会收集周期的一些统计数据(例如验证错误)。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# 计算100个周期。
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]

随机洗乱输入数据

Dataset.shuffle() T0>变换随机洗牌使用类似的算法应用于所述输入数据集tf.RandomShuffleQueue

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

使用高级API

tf.train.MonitoredTrainingSession API简化了在分布式设置中运行TensorFlow的许多方面。 MonitoredTrainingSession uses the tf.errors.OutOfRangeError to signal that training has completed, so to use it with the tf.data API, we recommend using Dataset.make_one_shot_iterator(). 例如:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

要在tf.estimator.Estimatorinput_fn中使用Dataset,我们还建议使用Dataset.make_one_shot_iterator() T4>。 例如:

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels