网格搜索
网格搜索 (Grid Search) 是一种通过遍历给定的超参数组合来优化模型表现的方法。
机器学习算法中有两类参数:一类是从训练集中学习到的参数,比如逻辑回归中的权重参数和偏差参数,另一类是超参数 (Hyperparameter),也就是需要人工设定的参数,比如正则项系数或者决策树的深度。
网格搜索通过寻找最佳的超参数组合,可以进一步帮助提高模型的性能。
算法思想
网格搜索的思想非常简单,属于暴力 (Brute Force) 算法:穷尽每个超参数的组合来评估对应的模型性能,然后挑选模型性能的超参数。
数据导入和预处理
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
df = pd.read_csv(
"https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data",
header=None,
)
X = df.iloc[:, 2:]
y = df.iloc[:, 1]
le = LabelEncoder()
y = le.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=1
)
基于 Scikit-learn 的网格搜索
创建管道对象:
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
pipe_svc = make_pipeline(StandardScaler(), SVC(random_state=1))
通过网格搜索所有超参数组合 (下面是 8 + 64 = 72 个) 中的准确率最高的一个:
from sklearn.model_selection import GridSearchCV
param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
param_grid = [
{"svc__C": param_range, "svc__kernel": ["linear"]},
{"svc__C": param_range, "svc__gamma": param_range, "svc__kernel": ["rbf"]},
]
gs = GridSearchCV(estimator=pipe_svc, param_grid=param_grid, scoring="accuracy", cv=10)
gs.fit(X_train, y_train)
网格搜索得到的最优模型的准确率、参数、以及模型调用如下:
print(gs.best_score_)
# Output: 0.984615384615
print(gs.best_params_)
# Output: {'svc__C': 100.0, 'svc__gamma': 0.001, 'svc__kernel': 'rbf'}
clf = gs.best_estimator_
clf.fit(X_train, y_train)
print("Test accuracy: {:.3f}".format(clf.score(X_test, y_test)))
# Output: Test accuracy: 0.974