在这一章中,我们将使用MNIST数据集,它是由高中生和美国人口普查局的雇员手写的7万余张数字小图像。每个图像都标有它所代表的数字。这个数据集被研究得如此之多,以至于它被称为机器学习的“Hello World”:每当人们提出一个新的分类算法时,他们都很好奇它将如何在MNIST上执行。每当有人学习机器学习的时候,他们迟早都要对付MNIST。

Scikit-Learn提供了许多辅助函数来下载流行的数据集。MNIST是其中之一。以下代码就是获取MNIST数据集:

>>> from sklearn.datasets import fetch_mldata
>>> mnist = fetch_mldata('MNIST original')
>>> mnist
{'COL_NAMES': ['label', 'data'],
    'DESCR': 'mldata.org dataset: mnist-original',
    'data': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
    'target': array([ 0., 0., 0., ..., 9., 9., 9.])}

由Scikit-Learn加载的数据集通常有一个类似的字典结构,包括:

  • 一个DESCR键:描述数据集的。

  • 一个data键,包含一个数组,每个样本一行,每一列表示一个特征。

  • 一个target键:包含有标签的数组。

让我们看看这些数组:

>>> X, y = mnist["data"], mnist["target"]
>>> X.shape
(70000, 784)
>>> y.shape
(70000,)

有7万张图片,每张图片有784个特征。这是因为每个图像28×28个像素,每个特征仅仅代表一个像素的强度,从0(白色)到255(黑色)。让我们看一下数据集中的一个数字。所有你需要做的是获取一个实例的特征向量,重塑为一个28×28的数组,并使用Matplotlib的imshow()函数将其显示出来:

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,
        interpolation="nearest")
plt.axis("off")
plt.show()

这看起来像5,实际上这就是标签告诉我们的:

>>> y[36000]
5.0

图3-1显示了来自MNIST数据集的一些图片,以让您了解分类任务的复杂性。

图3 - 1。来自MNIST数据集的几个数字。

但是等等!在仔细检查数据之前,您应该创建一个测试集并将其设置在一旁。MNIST数据集实际上已经被分割成一个训练集(前6万个图像)和一个测试集(最后10,000个图像):

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

让我们来对训练集重新洗牌;这将保证所有交叉验证折叠都是相似的(您不需要一个折叠来丢失一些数字)。此外,一些学习算法对训练样本的顺序非常敏感,如果它们连续得到许多类似的样本,它们的性能就会很差。对数据集重洗可以确保不会发生这种情况:

import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

results matching ""

    No results matching ""