tf.data
模块包含一组类,可让你轻松加载数据、操作数据并将其传送到模型中。 本文档通过两个简单的示例介绍了API:
从数组中读取切片是开始使用tf.data
的最简单方法。
预置的Estimator一章讲述了iris_data.py
中的train_input_fn
来将数据传送到Estimator,如下:
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Return the dataset.
return dataset
我们来仔细看看。
这个函数需要三个参数。 期望“数组”的参数几乎可以接受任何可以使用numpy.array
转换为数组的任何内容。 正如我们将看到的,tuple
是一个例外,它对Datasets
具有特殊含义。
在premade_estimator.py
中,我们使用iris_data.load_data()
函数获取鸢尾花数据。 你可以运行它,然后解压结果,如下:
import iris_data
# Fetch the data
train, test = iris_data.load_data()
features, labels = train
然后,我们将这些数据传递给输入函数,并使用类似如下的代码:
batch_size=100
iris_data.train_input_fn(features, labels, batch_size)
我们来看看train_input_fn()
。
在开始的时候,该函数使用tf.data.Dataset.from_tensor_slices
函数来创建表示数组切片的tf.data.Dataset
。 数组将在第一维上切片。 例如,包含mnist训练数据的数组的形状为(60000, 28, 28)
。 将它传递给from_tensor_slices
将返回一个包含60000个切片的Dataset
对象,每个对象都是一个28x28图像。
返回这个Dataset
的代码如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
这将打印下面这一行,显示数据集中每个元素的形状和类型。 请注意,Dataset
不知道它包含多少个元素。
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
上面的Dataset
表示一个简单的数组集合,但数据集比这个更强大。 一个Dataset
可以透明地处理字典或元组(或namedtuple
)的任何嵌套组合。
例如,将鸢尾花特征
转换为标准python字典后,你可以将数组字典转换为Dataset
字典,如下所示:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
这里我们看到当Dataset
包含结构元素时,Dataset
的shapes
和types
具有相同的结构。 这个数据集包含标量的字典,所有标量的类型都为tf.float64
。
鸢尾花train_input_fn
的第一行使用相同的函数,但增加了另一层结构。 它创建一个包含(features_dict, label)
对的数据集。
以下代码显示标签是类型为int64
的标量:
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
目前,Dataset
将以固定顺序遍历数据一次,并且一次只产生一个元素。 它需要进一步处理才能用于训练。 幸运的是,tf.data.Dataset
类提供了更好地为训练准备数据的方法。 输入函数的下一行利用了以下几种方法:
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
shuffle
方法使用一个固定大小的缓冲区在元素通过时对其进行洗乱。 在这个例子中,buffer_size
大于Dataset
中的样本数目,确保数据完全洗乱(鸢尾花数据集仅包含150个样本)。
repeat
方法在到达结尾时重新启动数据集
。 要限制周期的数量,请设置count
参数。
batch
方法收集若干个样本并将它们堆叠起来,以创建批次。 这为它们的形状增加了一个维度。 新维度被添加为第一维度。 以下代码使用前面的MNIST Dataset
上的batch
方法。 它产生的Dataset
包含的3D数组表示(28,28)
图像的堆叠:
print(mnist_ds.batch(100))
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
请注意,数据集具有未知的批次大小,因为最后一批的元素数量将少一些。
在train_input_fn
中,在批次处理之后,Dataset
包含元素的一维向量,其中每个标量就是先前的:
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
此时,Dataset
包含(features_dict, labels)
对。 这是train
和evaluate
方法预期的格式,所以input_fn
返回数据集。
使用predict
方法时,可以/应该省略labels
。
对于Dataset
类,最常见的现实世界用例是从磁盘上的文件流式传输数据。 tf.data
模块包含各种文件读取器。 我们来看看如何使用数据集
解析csv文件中的鸢尾花数据集。
如果需要,以下对iris_data.maybe_download
函数的调用会下载数据,并返回结果文件的路径名称:
import iris_data
train_path, test_path = iris_data.maybe_download()
iris_data.csv_input_fn
函数包含一个使用数据集
解析csv文件的替代实现。
我们来看看如何构建一个从本地文件读取的兼容Estimator的输入函数。
数据集
我们首先构建一个TextLineDataset
对象,一次读取一行文件。 然后,我们调用skip
方法跳过包含标题的文件的第一行,而不是一个样本:
ds = tf.data.TextLineDataset(train_path).skip(1)
我们将开始构建一个解析每一行的函数。
下面的iris_data.parse_line
函数使用tf.decode_csv
函数和一些简单的python代码完成这个任务:
我们必须解析数据集中的每一行以生成必要的(features, label)
对。 下面的_parse_line
函数调用tf.decode_csv
来将一行解析为它的特征和标签。 由于Estimators要求特征表示为字典,因此我们依靠Python内置的dict
和zip
函数来构建该字典。 特征名称是该字典的键。 然后我们调用字典的pop
方法从特征字典中删除标签字段:
# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
'PetalLength', 'PetalWidth',
'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
# Decode the line into its fields
fields = tf.decode_csv(line, FIELD_DEFAULTS)
# Pack the result into a dictionary
features = dict(zip(COLUMNS,fields))
# Separate the label from the features
label = features.pop('label')
return features, label
数据集有很多方法用于在将数据传送到模型时对数据进行操作。 最常用的方法是map
,它对Dataset
的每个元素应用转换。
map
方法使用map_func
参数来描述应该如何转换Dataset
中的每个元素。
map
方法应用map_func
转换数据集
中的每个元素。
因此,为了解析流出csv文件的行,我们将_parse_line
函数传递给map
方法:
ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: (
{SepalLength: (), PetalWidth: (), ...},
()),
types: (
{SepalLength: tf.float32, PetalWidth: tf.float32, ...},
tf.int32)>
现在,数据集包含(features, label)
对,不是简单的标量字符串。
iris_data.csv_input_fn
函数的其余部分与基本输入部分中介绍的iris_data.train_input_fn
完全相同。
该函数可以用来替代iris_data.train_input_fn
。 可以将它提供一个Estimator,如下:
train_path, test_path = iris_data.maybe_download()
# All the inputs are numeric
feature_columns = [
tf.feature_column.numeric_column(name)
for name in iris_data.CSV_COLUMN_NAMES[:-1]]
# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
n_classes=3)
# Train the estimator
batch_size = 100
est.train(
steps=1000,
input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))
Estimator期望input_fn
不接受任何参数。 要解决这个限制,我们使用lambda
来捕获参数并提供预期的接口。
tf.data
模块提供一组用于轻松读取各种来源的数据的类和函数。 此外,tf.data
具有简单强大的方法来应用各种标准和自定义转换。
现在你已经具有如何有效地将数据加载到Estimator中的基本概念。 接下来考虑下列文档:
Estimator
模型。tf.data.Datasets
。Datasets
的其他功能。