为了理解这个权衡,让我们看看SGDClassifier是如何做出它的分类决策的。对于每个实例,它基于决策函数( decision function)计算出一个分数,如果这个分数大于一个阈值,它将该样本分配到正类中,否则的话把它赋给负类。图3-3显示了从左边的最低分数值到右边的最高分数值的几个数字。假设决策阈值(decision threshold)被定位于中间箭头(即在图中那两个5之间):您将在该阈值的右侧找到4个真正[true positives]的样本(实际的5s),以及一个假正[false positive](实际上它是6)。因此,在该阈值条件下,精度为80%(4/5),但实际一共有6个5,分类器只检测成功4个,所以召回率是67%(4/6)。现在如果你提高阈值(移动到右边的箭头),之前的假正(指的是这个6)变成为一个真负,从而增加了精度(在这种情况下高达100%),但一个原来的真正[ true positive]现在变成为了假负[false negative],召回率降至50%。相反,降低阈值增加了召回率,降低了精度。
图3 - 3。决策阈值和精度/召回权衡
Scikitt -Learn不会让你直接设置阈值,但它允许你访问用来做预测的决策数。您可以调用它的decision_function()方法,而不是调用分类器的predict()方法,它返回每个样本的分数,然后根据您想要的任何阈值,基于这些分数进行预测:
>>> y_scores = sgd_clf.decision_function([some_digit])
>>> y_scores
array([ 161855.74572176])
>>> threshold = 0
>>> y_some_digit_pred = (y_scores > threshold)
array([ True], dtype=bool)
SGDClassifier使用一个等于0的阈值,因此前一个代码返回的结果与 predict()方法相同(即:True)。让我们提高阈值:
>>> threshold = 200000
>>> y_some_digit_pred = (y_scores > threshold)
>>> y_some_digit_pred
array([False], dtype=bool)
这证实了提高阈值降低了召回率。图像实际上表示的是5,当阈值为0时,分类器会检测出它,但是当阈值增加到200,000时,它会忽略它。
那么,如何决定使用哪一个阈值呢?为此,您首先需要使用cross_val_predict()函数来获得训练集中所有样本的得分,但这次指定您希望它返回的决策值,而不是预测:
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
method="decision_function")
现在,通过这些分数,您可以使用 precision_recall_curve()函数计算出所有可能的阈值,并对所有可能的阈值进行计算:
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
最后,您可以使用Matplotlib(图3-4)画出精度和召回与阈值的函数图像(图3-4):
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.xlabel("Threshold")
plt.legend(loc="upper left")
plt.ylim([0, 1])
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
图3 - 4 精度和召回与决策阈值的对应关系
您可能想知道为什么图3-4中的精度曲线比的召回率曲线更大。原因是当你提高阈值时,精度有时会下降(尽管总体上它会上升)。要理解其中的原因,请回顾图3-3,并注意当从中心阈值开始向右仅移动一个数字时的情况:精度从4/5(80%)下降到3/4(75%)。另一方面,当阈值增加时,召回率只能下降,这就解释了为什么它的曲线看起来很平滑。
现在,您可以简单地选择为您的任务提供了最佳的精度/召回权衡的阈值。选择一个好的精度/召回权衡的另一种方法是直接在召回中绘制精度的图,如图3-5所示。
图3 - 5 精度和召回
你可以看到,在80%以上的的召回率范围内,精度开始大幅下降。您可能想要选择一个精度/召回的折衷方案,例如,在大约60%的召回点上。当然,选择取决于你的项目。
假设你的目标是90%的精度。你查找第一个图(放大一点),发现你需要使用大约70000的阈值。要做出预测(目前的训练集),而不是调用分类器的 predict()方法,您只需运行以下代码:
y_train_pred_90 = (y_scores > 70000)
让我们检查一下这些预测的精度和召回率:
>>> precision_score(y_train_5, y_train_pred_90)
0.8998702983138781
>>> recall_score(y_train_5, y_train_pred_90)
0.63991883416343853
很好,你有了一个90%精度的分类器(或足够近)!正如您所看到的,创建一个几乎任意精确的分类器是相当容易的:只要设置足够高的阈值,就可以完成了。嗯,没那么快。一个高精度的分类器,如果它的召回率太低,就不是很有用了!
如果有人说“让我们达到99%的精确度”,你应该问,“那么要多的召回率?”