当然,如果这是一个真实的项目,您将遵循机器学习项目清单中的步骤(见附录B):探索数据准备选项,尝试多个模型,挑选最好的模型,并使用使用GridSearchCV微调它们的超参数来微调超参数 ,尽可能的自动化,正如你在上一张所做的那样。在这里,我们假设您已经找到了一个有效的模型,并且您希望找到改进它的方法。一种方法是分析它所犯的误差类型。
首先,你可以使用混淆矩阵来查看。您需要使用cross_val_predict()函数进行预测,然后调用 confusion_matrix()函数,就像您之前做的那样:
>>> y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
>>> conf_mx = confusion_matrix(y_train, y_train_pred)
>>> conf_mx
array([[5725, 3, 24, 9, 10, 49, 50, 10, 39, 4],
[ 2, 6493, 43, 25, 7, 40, 5, 10, 109, 8],
[ 51, 41, 5321, 104, 89, 26, 87, 60, 166, 13],
[ 47, 46, 141, 5342, 1, 231, 40, 50, 141, 92],
[ 19, 29, 41, 10, 5366, 9, 56, 37, 86, 189],
[ 73, 45, 36, 193, 64, 4582, 111, 30, 193, 94],
[ 29, 34, 44, 2, 42, 85, 5627, 10, 45, 0],
[ 25, 24, 74, 32, 54, 12, 6, 5787, 15, 236],
[ 52, 161, 73, 156, 10, 163, 61, 25, 5027, 123],
[ 43, 35, 26, 92, 178, 28, 2, 223, 82, 5240]])
这是很多数字。使用Matplotlib的matshow()函数查看混淆矩阵的图像表示。
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
这个混淆矩阵看起来相当不错,因为大多数图像都在主对角线上,这意味着它们被正确分类了。类别5看起来比其他数字稍暗,这可能意味着数据集中的5的图像更少,或者分类器在类别5上的表现不如其他数字。事实上,你可以验证两者都是正确的。
让我们把该图像表示集中在误差上。首先,您需要将混淆矩阵中的每个值除以相应类别的图像数量,这样您就可以比较误差率而不是误差的绝对数量(这会使大量的类别看起来不公平):
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
现在让我们用0来填充对角线,只保留误差,我们来画出结果。
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
现在您可以清楚地看到分类器所犯的误差。记住,行表示实际的类别,而列表示预期的类别。第8和第9类的列是相当亮的,这说明许多图像被错误分类为8或9。类似地,第8和第9类的行也很亮,告诉您8和9常常与其他数字混淆。相反,有些行是相当暗的,如第1行:这意味着大多数的类别1的图片是被正确分类的(有一些和8混淆了,但仅此而已)。注意,误差不是完全对称的;例如,5被错误的分类为8的误差要远远大于8被错误分类为5的误差。
分析混淆矩阵可以经常给你一些关于如何改进分类器的见解。我们来看看这个图,你的努力似乎应该花在改进8s和9s的分类上,以及解决特定的3/5混淆。例如,您可以尝试为这些数字收集更多的训练数据。或者你可以设计一些新的特性来帮助分类器——例如,编写一个算法来计算闭环的数量(例如,8有2个,6个有1个,5个没有)。或者你可以对图像进行预处理(例如,使用Scikit-Image,Pillow,或者OpenCV),使一些模式更加突出,比如闭环。
分析单独的误差也可以很好地了解你的分类器应该怎么做以及它为什么会失败,但是它更加困难和耗时。举个例子,我们来举3和5的例子:
cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
左边的两个5×5块显示被分类为3的图片,右侧的两个5×5块显示被分类为5的图像。分类器出错的一些数字(例如:左下角和右上的块)写得很糟糕,甚至人类也很难分类(例如,第8行第1列的5看起来像一个3)。然而,大多数错误分类的图像对我们来说似乎是明显的错误,很难理解为什么分类器会犯这样的错误。原因是我们使用了一个简单的SGDClassifier,它是一个线性模型。它所做的只是给每个像素分配一个权重,当它看到一个新的图像时,它只是对加权像素强度求和,得到它关于每个类别的分数。因此,由于3和5图片的差别只有几个像素,所以这个模型很容易混淆它们。
3和5之间的主要区别是连接顶部和底部弧线的小线的位置。如果你画一个3,这个连接点稍微向左平移,分类器可以把它归为5,反之亦然。换句话说,这个分类器对图像的移动和旋转非常敏感。所以减少3/5混淆的一种方法是对图像进行预处理,以确保他们很集中,不太旋转。这可能有助于减少其他错误。