![机器学习实战:模型构建与应用](https://wfqqreader-1252317822.image.myqcloud.com/cover/359/44389359/b_44389359.jpg)
4.2 在Keras模型中使用TFDS
在第2章中你看到了如何使用TensorFlow和Keras创建一个简单的计算机视觉模型,其中使用了Keras内置的数据集(包括Fashion MNIST),代码如下:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/085-1.jpg?sign=1739684010-AuGLVhzQPOsAn1RobhfetoPv79A7sRia-0-e9c4e63d1dbb1c87f20e34ea2d3de757)
使用TFDS时代码是非常相似的,但需要一些小的改动。Keras数据集提供的是ndarray
类型,可以直接在model.fit
中使用,但是使用TFDS我们需要做一点转换工作:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-1.jpg?sign=1739684010-nT1Zs6ozjTSs3TlUADRVueuClzkDg1N4-0-cc24bc6119b1fe78d540ade4392b51da)
在这个例子中我们使用了tfds.load
,把fashion_mnist
传给它作为想要的数据集。我们知道它包含train
和test
的分割,因此把这些以数组的形式传送过去会返回一个数据集适配器数组(其中包含图像和标签)。在调用tfds.load
的命令中使用tfds.as_numpy
会导致它们返回Numpy数组。指定batch_size=1
会给我们提供所有的数据,指定as_supervised=True
确保我们得到返回的(输入,标签)的元组。
做完这些,我们就有了Keras数据集中几乎同样格式的数据,只有一个改动—TFDS中的形状是(28,28,1),而Keras数据集中的形状是(28,28)。
这意味着代码需要做一些改动来指定输入数据的形状是(28,28,1)而不是(28,28):
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-2.jpg?sign=1739684010-IOHNYAwJjrLXIAWxRKbuB5J5mQhbe4Qk-0-11f7b16be45cb0f40a6a82965a9c5746)
对于更复杂的例子,你可以查看第3章中使用的Horses or Humans数据集。它同样可以在TFDS中找到。下面是用它来训练一个模型的完整代码:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/086-3.jpg?sign=1739684010-7LpfXdpxCi76UPCGK1bJlJ0MmGu3yBXx-0-ac013083bb6dd3e8cb6d6ba4754df6e4)
可以看到,它非常直接:只需要调用tfds.load
,传送给它你想要的分割(在这个例子中是train
),并在模型中使用它。数据被分批处理和重组,以使训练更加有效。
Horses or Humans数据集被分为训练集和测试集,因此如果你在训练过程中想对模型进行验证,可以从TFDS加载一个独立的验证集,代码如下:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-1.jpg?sign=1739684010-vzOKIziHK8dDy62678Wega61IGdhmhgC-0-cc6c17c00e67dc2f7cf7025d3f9f7057)
你将需要对它进行分批,就像你对训练集所做的一样。例如:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-2.jpg?sign=1739684010-NcIRjz7hPjkUXN5WXza5G2ooaVdSWT5J-0-5af723d1e42251ca1da05ba2affece70)
在训练的时候,你指定训练数据是这些批次。你还需要明确地设置每一个回合使用的验证步数,否则TensorFlow会抛出一个错误。如果你不确定,可以把它设置为1
:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/087-3.jpg?sign=1739684010-H3HeCqN1XG2kMTNifeFIpt6H9RrYzW2e-0-27f4babceb48b40b377507bfaa9fa812)
加载具体的版本
所有存储在TFDS中的数据集都使用MAJOR.MINOR.PATCH编号系统。该系统保证了以下规则。如果PATCH被更新,那么调用返回的数据是相同的,但是底层组织可能已经改变。任何改变对于开发者而言应该是不可见的。如果MINOR被更新,那么数据仍然没有变化,除了在每个记录中有额外的特征(非破坏性改变)。同样,对于任何特定的切片(见4.4节)数据也是相同的,因此记录不会被重新排序。如果MAJOR被更新,那么记录的格式和它们的位置可能会有变化,因此特定的片段可能会返回不同的结果。
当检查数据集时,你会发现有不同的版本可以使用。例如,cnn_dailymail
数据集(https://oreil.ly/673CJ)。如果你不想使用默认版本(3.0.0),而想使用更早的版本(例如1.0.0),可以像这样加载它:
![](https://epubservercos.yuewen.com/A03276/23627497809544006/epubprivate/OEBPS/Images/088-1.jpg?sign=1739684010-LCk3jkrbLm1fpfeuqlbM0zt9yNgfN5dn-0-8397f4775252d678f7ee27872ea51030)
注意,如果你正在使用Colab,那么检查TFDS使用的版本总是一个好主意。在写作本书时,Cload被预先设置为TFDS 2.0,但是TFDS 2.1和之后的版本解决了一些加载数据集的错误(包括cnn_dailymail
),因此确保使用这些版本的其中一个,或者最起码将它们安装到Colab中,而不是依赖默认的版本。