数据集快速入门

tf.data模块包含一组类,可让你轻松加载数据、操作数据并将其传送到模型中。 本文档通过两个简单的示例介绍了API:

基本输入

从数组中读取切片是开始使用tf.data的最简单方法。

预置的Estimator一章讲述了iris_data.py中的train_input_fn来将数据传送到Estimator,如下:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

我们来仔细看看。

参数

这个函数需要三个参数。 期望“数组”的参数几乎可以接受任何可以使用numpy.array转换为数组的任何内容。 正如我们将看到的,tuple是一个例外,它对Datasets具有特殊含义。

premade_estimator.py中,我们使用iris_data.load_data()函数获取鸢尾花数据。 你可以运行它,然后解压结果,如下:

import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train

然后,我们将这些数据传递给输入函数,并使用类似如下的代码:

batch_size=100
iris_data.train_input_fn(features, labels, batch_size)

我们来看看train_input_fn()

切片

在开始的时候,该函数使用tf.data.Dataset.from_tensor_slices函数来创建表示数组切片的tf.data.Dataset 数组将在第一维上切片。 例如,包含mnist训练数据的数组的形状为(60000, 28, 28) 将它传递给from_tensor_slices将返回一个包含60000个切片的Dataset对象,每个对象都是一个28x28图像。

返回这个Dataset的代码如下所示:

train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

这将打印下面这一行,显示数据集中每个元素的形状类型 请注意,Dataset不知道它包含多少个元素。

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

上面的Dataset表示一个简单的数组集合,但数据集比这个更强大。 一个Dataset可以透明地处理字典或元组(或namedtuple)的任何嵌套组合。

例如,将鸢尾花特征转换为标准python字典后,你可以将数组字典转换为Dataset字典,如下所示:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

这里我们看到当Dataset包含结构元素时,Datasetshapestypes具有相同的结构。 这个数据集包含标量的字典,所有标量的类型都为tf.float64

鸢尾花train_input_fn的第一行使用相同的函数,但增加了另一层结构。 它创建一个包含(features_dict, label)对的数据集。

以下代码显示标签是类型为int64的标量:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操作

目前,Dataset将以固定顺序遍历数据一次,并且一次只产生一个元素。 它需要进一步处理才能用于训练。 幸运的是,tf.data.Dataset类提供了更好地为训练准备数据的方法。 输入函数的下一行利用了以下几种方法:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle方法使用一个固定大小的缓冲区在元素通过时对其进行洗乱。 在这个例子中,buffer_size大于Dataset中的样本数目,确保数据完全洗乱(鸢尾花数据集仅包含150个样本)。

repeat方法在到达结尾时重新启动数据集 要限制周期的数量,请设置count参数。

batch方法收集若干个样本并将它们堆叠起来,以创建批次。 这为它们的形状增加了一个维度。 新维度被添加为第一维度。 以下代码使用前面的MNIST Dataset上的batch方法。 它产生的Dataset包含的3D数组表示(28,28)图像的堆叠:

print(mnist_ds.batch(100))
<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

请注意,数据集具有未知的批次大小,因为最后一批的元素数量将少一些。

train_input_fn中,在批次处理之后,Dataset包含元素的一维向量,其中每个标量就是先前的:

print(dataset)
<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回

此时,Dataset包含(features_dict, labels)对。 这是trainevaluate方法预期的格式,所以input_fn返回数据集。

使用predict方法时,可以/应该省略labels

读取CSV文件

对于Dataset类,最常见的现实世界用例是从磁盘上的文件流式传输数据。 tf.data模块包含各种文件读取器。 我们来看看如何使用数据集解析csv文件中的鸢尾花数据集。

如果需要,以下对iris_data.maybe_download函数的调用会下载数据,并返回结果文件的路径名称:

import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn函数包含一个使用数据集解析csv文件的替代实现。

我们来看看如何构建一个从本地文件读取的兼容Estimator的输入函数。

建立数据集

我们首先构建一个TextLineDataset对象,一次读取一行文件。 然后,我们调用skip方法跳过包含标题的文件的第一行,而不是一个样本:

ds = tf.data.TextLineDataset(train_path).skip(1)

构建一个csv行解析器

我们将开始构建一个解析每一行的函数。

下面的iris_data.parse_line函数使用tf.decode_csv函数和一些简单的python代码完成这个任务:

我们必须解析数据集中的每一行以生成必要的(features, label)对。 下面的_parse_line函数调用tf.decode_csv来将一行解析为它的特征和标签。 由于Estimators要求特征表示为字典,因此我们依靠Python内置的dictzip函数来构建该字典。 特征名称是该字典的键。 然后我们调用字典的pop方法从特征字典中删除标签字段:

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS,fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label

解析行

数据集有很多方法用于在将数据传送到模型时对数据进行操作。 最常用的方法是map,它对Dataset的每个元素应用转换。

map方法使用map_func参数来描述应该如何转换Dataset中的每个元素。

map方法应用map_func转换数据集中的每个元素。

因此,为了解析流出csv文件的行,我们将_parse_line函数传递给map方法:

ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

现在,数据集包含(features, label) 对,不是简单的标量字符串。

iris_data.csv_input_fn函数的其余部分与基本输入部分中介绍的iris_data.train_input_fn完全相同。

试一试

该函数可以用来替代iris_data.train_input_fn 可以将它提供一个Estimator,如下:

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
                                    n_classes=3)
# Train the estimator
batch_size = 100
est.train(
    steps=1000,
    input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))

Estimator期望input_fn不接受任何参数。 要解决这个限制,我们使用lambda来捕获参数并提供预期的接口。

总结

tf.data模块提供一组用于轻松读取各种来源的数据的类和函数。 此外,tf.data具有简单强大的方法来应用各种标准和自定义转换。

现在你已经具有如何有效地将数据加载到Estimator中的基本概念。 接下来考虑下列文档: