可以看到sklearn为机器学习提供了以下6大模块:
也可以三大块:
数据预处理,比如:
Imputation(对缺失值进行填充)模型
score对模型进行评估:比如:accuracy recall f1 roc_auc等squared_error r2等fm_score rand_score等sklearn的接口非常人性化,很容易就可以弄清楚各个函数的作用,并且其模型的训练/预测数据一般为(n_sample, n_feature)形状的二维数组, 模型得到的标签/回归数值则是(n_sample,)形状的一维数组:
大部分情况下,训练模型的函数长这样:
只需要提供一个表示样本特征的ndarray(n_samples, n_features),以及表示各样本标签的ndarray(n_samples, )
from sklearn.linear_model import Ridge
print(Ridge.fit.__doc__)
大部分情况下,使用训练好的模型预测新数据的函数长这样:
由于是预测,自然是返回一个预测好的标签数组啦。
from sklearn.svm import SVC
print(SVC.predict.__doc__)
大部分模型都是这么使用的,很多非sklearn内置的模型,也实现了这样的接口,比如xgboost, 可以与sklearn的交叉验证,模型评估等想结合,甚至你自己的模型都可以整合进来。
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause
from IPython.utils import io
import matplotlib.pyplot as plt
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split
# 数据导入, 绘图展示
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images, digits.target))
with io.capture_output() as captured:
for index, (image, label) in enumerate(images_and_labels[:4]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Training: %i' % label)
##############################################################################################
# 以上均不重要
# 数据预处理,划分训练集,测试集
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
train_x, test_x, train_y, test_y = train_test_split(data, digits.target)
# 使用SVC模型, 设置参数, 使用训练集训练模型
classifier = svm.SVC(gamma=0.001)
classifier.fit(train_x, train_y)
# 使用模型预测新数据
predicted_y = classifier.predict(test_x)
# 模型评估
print("Classification report for classifier %s:\n%s\n"
% (classifier, metrics.classification_report(test_y, predicted_y)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(test_y, predicted_y))
# 以下均不重要
###########################################################################################
images_and_predictions = list(zip(test_x, predicted_y, test_y))
with io.capture_output() as captured:
for index, (image, prediction, true_label) in enumerate(images_and_predictions[:4]):
plt.subplot(2, 4, index + 5)
plt.axis('off')
plt.imshow(image.reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Prediction: {} \nTrue Label: {}'.format(prediction, true_label))
plt.show()
看!大约10行(主体部分)的代码就能使用一个模型解决我们的问题, 并且准确率那么高(有些字我自己都认不清╯▽╰ )。