20,000 Stars! The Tool That Keeps Winning Kaggle Competitions
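1. Introduction
XGBoost (eXtreme Gradient Boosting) is an optimized distributed gradient boosting library designed to be efficient, flexible, and portable. It supports Python, Scala, R, Julia, C++, and other languages.
The algorithm behind XGBoost is an improved version of the machine learning algorithm GBDT (gradient boosting decision tree), and it can be used for both classification and regression problems.
XGBoost provides parallel tree boosting (also known as GBDT or GBM) that solves many data science problems quickly and accurately. The same code runs on the major distributed environments (Kubernetes, Hadoop, SGE, MPI, Dask) and can handle problems with billions of examples.
XGBoost handles regression, classification, ranking, and other tasks. Thanks to its strong predictive performance and fast training, it has repeatedly taken first place in Kaggle competitions.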
2. Open-Source Homepage
XGBoost has earned 20.4k stars on GitHub.
https://github.com/dmlc/xgboost
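3. Installation and Usage Example
3.1 Download (Windows 10 64-bit, Python 3.7.3)
Open a cmd command prompt and run pip install xgboost
On my network the download only worked through a VPN (my connection may simply be poor); without one, the installation failed with a network error.
With the VPN on, the download was slower but the installation succeeded.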
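Alternatively, if you do not have a VPN, go to https://www.lfd.uci.edu/~gohlke/pythonlibs/#xgboost and download the xgboost wheel that matches your Python version.
Then run pip install xgboost-0.90-cp37-cp37m-win_amd64.whl on the downloaded file, and the installation succeeds.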
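Picking the right file comes down to reading the tags in the wheel's filename. The sketch below uses only the standard library to split a filename like the one above into its Python, ABI, and platform tags, and to compare the Python tag with the running interpreter. The helper names (wheel_tags, current_python_tag, is_installed) are made up for this illustration, and the tag comparison assumes CPython.

```python
import importlib.util
import sys


def wheel_tags(filename):
    """Return (python_tag, abi_tag, platform_tag) parsed from a wheel filename."""
    stem = filename[:-len(".whl")] if filename.endswith(".whl") else filename
    # The last three dash-separated fields of a wheel name are the tags.
    return tuple(stem.split("-")[-3:])


def current_python_tag():
    """Tag of the running interpreter, e.g. 'cp37' for CPython 3.7 (CPython assumed)."""
    return "cp{}{}".format(sys.version_info.major, sys.version_info.minor)


def is_installed(package):
    """True if the import system can locate `package`."""
    return importlib.util.find_spec(package) is not None


wheel = "xgboost-0.90-cp37-cp37m-win_amd64.whl"
python_tag, abi_tag, platform_tag = wheel_tags(wheel)
print("wheel targets:", python_tag, abi_tag, platform_tag)
print("this interpreter:", current_python_tag())
print("python tag matches:", python_tag == current_python_tag())
print("xgboost importable:", is_installed("xgboost"))
```

If the Python tag does not match, pip refuses the wheel as unsupported on this platform, so it is worth checking before downloading.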
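3.2 Usage Example
This example uses the Melbourne housing dataset for a simple test: predicting house prices.
To download the dataset, follow the WeChat official account 开源前哨 and send it the message xgboost.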
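Before using the data, three common data-quality problems need to be handled:
1. Missing values (usually handled by dropping the columns that contain them or filling them with the column mean);
2. Outliers;
3. Duplicate values.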
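The three steps in the list above can be sketched with pandas on a small made-up table. This is only an illustration under assumptions: the toy values, the function name clean_column, and the 1.5 x IQR outlier rule are not from the original article. Duplicates and outliers are handled before filling so that the column mean used for filling is not skewed by them.

```python
import numpy as np
import pandas as pd

# Made-up rows: one exact duplicate, one missing area, one implausible area.
toy = pd.DataFrame({
    "Rooms":        [2, 3, 3, 4, 2, 3],
    "BuildingArea": [80.0, 120.0, 120.0, np.nan, 150.0, 5000.0],
    "Price":        [500000, 700000, 700000, 900000, 650000, 720000],
})


def clean_column(df, col):
    """Drop duplicate rows, drop IQR outliers in `col`, then mean-fill `col`."""
    df = df.drop_duplicates()

    # Outliers: keep values within 1.5 * IQR of the quartiles (assumed rule).
    q1, q3 = df[col].quantile(0.25), df[col].quantile(0.75)
    iqr = q3 - q1
    in_range = df[col].between(q1 - 1.5 * iqr, q3 + 1.5 * iqr)
    df = df[in_range | df[col].isna()].copy()   # keep missing rows for filling

    # Missing values: fill with the mean of the remaining values.
    df[col] = df[col].fillna(df[col].mean())
    return df


cleaned = clean_column(toy, "BuildingArea")
print(cleaned)
```

Dropping the affected column entirely, the other option mentioned in the list, would be `df.drop(columns=[col])`.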
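We will not cover the data loading and cleaning steps here (please study them on your own). Assume the training and validation data are already prepared (I have done this): X_train (training features), X_valid (validation features), y_train (training labels), and y_valid (validation labels).
The code below was run in Jupyter Notebook: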
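The script in this section calls train_test_split without a seed, so each run produces a different split and slightly different error values. As a rough illustration of what such a split does, here is a NumPy-only sketch of a shuffled 75/25 split with a fixed seed. The function name shuffled_split is hypothetical and this is not scikit-learn's implementation; in scikit-learn the equivalent knob is the random_state argument of train_test_split.

```python
import numpy as np


def shuffled_split(n_rows, valid_fraction=0.25, seed=0):
    """Return (train_idx, valid_idx): a reproducible shuffled split of row indices."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_rows)              # shuffled row positions
    n_valid = int(round(n_rows * valid_fraction))
    return order[n_valid:], order[:n_valid]


train_idx, valid_idx = shuffled_split(100)
print(len(train_idx), "training rows,", len(valid_idx), "validation rows")

# With a pandas DataFrame X and Series y, the pieces would be selected as
# X.iloc[train_idx], X.iloc[valid_idx], y.iloc[train_idx], y.iloc[valid_idx].
```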
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the data (raw string so the backslashes in the Windows path are kept as-is)
data = pd.read_csv(r'E:\数据集\melb_data.csv')

# Select the feature columns and the prediction target
cols_to_use = ['Rooms', 'Distance', 'Landsize', 'BuildingArea', 'YearBuilt']
X = data[cols_to_use]
y = data.Price

# Split the data into a training set and a validation set
X_train, X_valid, y_train, y_valid = train_test_split(X, y)

# This example uses the XGBoost library, so import it first
from xgboost import XGBRegressor
xg_model = XGBRegressor()
# Train the model on the training set; a single call is enough
xg_model.fit(X_train, y_train)
#Out
'''XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
             importance_type='gain', interaction_constraints='',
             learning_rate=0.300000012, max_delta_step=0, max_depth=6,
             min_child_weight=1, missing=nan, monotone_constraints='()',
             n_estimators=100, n_jobs=12, num_parallel_tree=1,
             objective='reg:squarederror', random_state=0, reg_alpha=0,
             reg_lambda=1, scale_pos_weight=1, subsample=1,
             tree_method='exact', validate_parameters=1, verbosity=None)'''

# Predict on the validation set and compute the mean absolute error
from sklearn.metrics import mean_absolute_error
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''Mean Absolute Error: 235552.95046713916'''

# A few XGBoost parameters matter a lot for the result; change them and compare.
# learning_rate: the output above shows a default of about 0.3 in this version.
# A smaller learning rate shrinks each tree's contribution, which helps against
# overfitting. Try 0.05 and 0.15 and compare the results.
xg_model = XGBRegressor(learning_rate=0.05)
xg_model.fit(X_train, y_train)
#Out
'''XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
             importance_type='gain', interaction_constraints='',
             learning_rate=0.05, max_delta_step=0, max_depth=6,
             min_child_weight=1, missing=nan, monotone_constraints='()',
             n_estimators=100, n_jobs=12, num_parallel_tree=1,
             objective='reg:squarederror', random_state=0, reg_alpha=0,
             reg_lambda=1, scale_pos_weight=1, subsample=1,
             tree_method='exact', validate_parameters=1, verbosity=None)'''
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''Mean Absolute Error: 248468.6911450663'''

xg_model = XGBRegressor(learning_rate=0.15)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''Mean Absolute Error: 234624.95226435937'''

# n_estimators: the number of weak estimators, i.e. the number of trees.
# Too few tends to underfit (large error on both the training and test sets);
# too many tends to overfit (good on the training set, poor on the test set).
xg_model = XGBRegressor(n_estimators=500)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''Mean Absolute Error: 247669.14656434095'''

xg_model = XGBRegressor(n_estimators=1000)
xg_model.fit(X_train, y_train)
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''Mean Absolute Error: 255654.5402660162'''

# early_stopping_rounds: ends training early. After enough iterations the
# validation error only fluctuates within a range, or even starts to rise,
# which is a sign of overfitting. With early_stopping_rounds=N, training stops
# once the validation error has not improved for N consecutive rounds.
# A value around 40 is a typical setting.
# Note: newer XGBoost releases expect early_stopping_rounds as an argument of
# the XGBRegressor constructor rather than of fit().
xg_model = XGBRegressor()
xg_model.fit(X_train, y_train, early_stopping_rounds=5, eval_set=[(X_valid, y_valid)])
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''[0] validation_0-rmse:941935.18750
[1] validation_0-rmse:732831.68750
[2] validation_0-rmse:599871.62500
[3] validation_0-rmse:516429.90625
[4] validation_0-rmse:467964.43750
[5] validation_0-rmse:441869.50000
[6] validation_0-rmse:426009.87500
[7] validation_0-rmse:416771.93750
[8] validation_0-rmse:411553.68750
[9] validation_0-rmse:408549.15625
[10] validation_0-rmse:406357.93750
[11] validation_0-rmse:403870.65625
[12] validation_0-rmse:402537.96875
[13] validation_0-rmse:402280.25000
[14] validation_0-rmse:400586.65625
[15] validation_0-rmse:399610.59375
[16] validation_0-rmse:398340.71875
[17] validation_0-rmse:397867.90625
[18] validation_0-rmse:397690.25000
[19] validation_0-rmse:397726.62500
[20] validation_0-rmse:396976.93750
[21] validation_0-rmse:396865.03125
[22] validation_0-rmse:395752.12500
[23] validation_0-rmse:392401.65625
[24] validation_0-rmse:393190.93750
[25] validation_0-rmse:392763.21875
[26] validation_0-rmse:392791.09375
[27] validation_0-rmse:391552.53125
[28] validation_0-rmse:391745.37500
[29] validation_0-rmse:391624.40625
[30] validation_0-rmse:391096.18750
[31] validation_0-rmse:391777.71875
[32] validation_0-rmse:392427.34375
[33] validation_0-rmse:391748.43750
[34] validation_0-rmse:391423.40625
[35] validation_0-rmse:391501.90625
Mean Absolute Error: 243436.07564893225'''

xg_model = XGBRegressor()
xg_model.fit(X_train, y_train, early_stopping_rounds=30, eval_set=[(X_valid, y_valid)])
predictions = xg_model.predict(X_valid)
print("Mean Absolute Error: " + str(mean_absolute_error(y_valid, predictions)))
#Out
'''[0] validation_0-rmse:941935.18750
[1] validation_0-rmse:732831.68750
[2] validation_0-rmse:599871.62500
[3] validation_0-rmse:516429.90625
[4] validation_0-rmse:467964.43750
[5] validation_0-rmse:441869.50000
[6] validation_0-rmse:426009.87500
[7] validation_0-rmse:416771.93750
[8] validation_0-rmse:411553.68750
[9] validation_0-rmse:408549.15625
[10] validation_0-rmse:406357.93750
[11] validation_0-rmse:403870.65625
[12] validation_0-rmse:402537.96875
[13] validation_0-rmse:402280.25000
[14] validation_0-rmse:400586.65625
[15] validation_0-rmse:399610.59375
[16] validation_0-rmse:398340.71875
[17] validation_0-rmse:397867.90625
[18] validation_0-rmse:397690.25000
[19] validation_0-rmse:397726.62500
[20] validation_0-rmse:396976.93750
[21] validation_0-rmse:396865.03125
[22] validation_0-rmse:395752.12500
[23] validation_0-rmse:392401.65625
[24] validation_0-rmse:393190.93750
[25] validation_0-rmse:392763.21875
[26] validation_0-rmse:392791.09375
[27] validation_0-rmse:391552.53125
[28] validation_0-rmse:391745.37500
[29] validation_0-rmse:391624.40625
[30] validation_0-rmse:391096.18750
[31] validation_0-rmse:391777.71875
[32] validation_0-rmse:392427.34375
[33] validation_0-rmse:391748.43750
[34] validation_0-rmse:391423.40625
[35] validation_0-rmse:391501.90625
[36] validation_0-rmse:390788.90625
[37] validation_0-rmse:390342.50000
[38] validation_0-rmse:388972.56250
[39] validation_0-rmse:388548.37500
[40] validation_0-rmse:389271.59375
[41] validation_0-rmse:388126.15625
[42] validation_0-rmse:387813.06250
[43] validation_0-rmse:387755.25000
[44] validation_0-rmse:388325.31250
[45] validation_0-rmse:388446.96875
[46] validation_0-rmse:388394.09375
[47] validation_0-rmse:388425.00000
[48] validation_0-rmse:388244.31250
[49] validation_0-rmse:387900.65625
[50] validation_0-rmse:387721.56250
[51] validation_0-rmse:387184.31250
[52] validation_0-rmse:386687.31250
[53] validation_0-rmse:386059.87500
[54] validation_0-rmse:386105.56250
[55] validation_0-rmse:386133.31250
[56] validation_0-rmse:385496.62500
[57] validation_0-rmse:385332.71875
[58] validation_0-rmse:385390.18750
[59] validation_0-rmse:385281.46875
[60] validation_0-rmse:385243.71875
[61] validation_0-rmse:385267.50000
[62] validation_0-rmse:385012.65625
[63] validation_0-rmse:385141.46875
[64] validation_0-rmse:384997.34375
[65] validation_0-rmse:385355.09375
[66] validation_0-rmse:385625.09375
[67] validation_0-rmse:385546.90625
[68] validation_0-rmse:385723.62500
[69] validation_0-rmse:385636.68750
[70] validation_0-rmse:385617.34375
[71] validation_0-rmse:385682.62500
[72] validation_0-rmse:385741.12500
[73] validation_0-rmse:385583.09375
[74] validation_0-rmse:385650.28125
[75] validation_0-rmse:385895.87500
[76] validation_0-rmse:385654.53125
[77] validation_0-rmse:385794.40625
[78] validation_0-rmse:385448.96875
[79] validation_0-rmse:385472.40625
[80] validation_0-rmse:385382.06250
[81] validation_0-rmse:385556.25000
[82] validation_0-rmse:385969.62500
[83] validation_0-rmse:385744.06250
[84] validation_0-rmse:385602.81250
[85] validation_0-rmse:385691.06250
[86] validation_0-rmse:385536.59375
[87] validation_0-rmse:385594.53125
[88] validation_0-rmse:385643.12500
[89] validation_0-rmse:385773.93750
[90] validation_0-rmse:385573.40625
[91] validation_0-rmse:385784.81250
[92] validation_0-rmse:385986.00000
[93] validation_0-rmse:386088.37500
Mean Absolute Error: 237037.17186349412'''
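To make the roles of learning_rate, n_estimators, and early_stopping_rounds more concrete, here is a deliberately simplified gradient boosting loop in plain NumPy. It is a sketch under assumptions, not XGBoost's algorithm: it boosts one-split trees ("stumps") on a single synthetic feature with squared error, and every name in it (fit_stump, predict_stump, boost) is made up for this illustration. It also reports both RMSE and MAE on the validation data, which shows why the training log (RMSE) and the printed score (MAE) in the run above are different numbers for the same model.

```python
import numpy as np


def fit_stump(x, residual):
    """Find the single-split regression tree that best fits `residual`."""
    best = None
    for threshold in np.unique(x)[:-1]:          # skip the max so both sides are non-empty
        left = x <= threshold
        left_val, right_val = residual[left].mean(), residual[~left].mean()
        sse = (((residual[left] - left_val) ** 2).sum()
               + ((residual[~left] - right_val) ** 2).sum())
        if best is None or sse < best[0]:
            best = (sse, threshold, left_val, right_val)
    return best[1:]


def predict_stump(stump, x):
    threshold, left_val, right_val = stump
    return np.where(x <= threshold, left_val, right_val)


def boost(x_tr, y_tr, x_va, y_va, learning_rate=0.3, n_estimators=100,
          early_stopping_rounds=None):
    """Boost stumps on squared error; optionally stop early on validation RMSE."""
    pred_tr = np.full(len(y_tr), y_tr.mean())    # start from the training mean
    pred_va = np.full(len(y_va), y_tr.mean())
    best_rmse, best_round, rounds_run = np.inf, -1, 0
    for i in range(n_estimators):
        # Each new tree fits the current residuals (the negative gradient of
        # squared error); learning_rate shrinks how much of it is added.
        stump = fit_stump(x_tr, y_tr - pred_tr)
        pred_tr += learning_rate * predict_stump(stump, x_tr)
        pred_va += learning_rate * predict_stump(stump, x_va)
        rounds_run = i + 1
        rmse = np.sqrt(np.mean((y_va - pred_va) ** 2))
        if rmse < best_rmse:
            best_rmse, best_round = rmse, i
        elif (early_stopping_rounds is not None
              and i - best_round >= early_stopping_rounds):
            break                                # no improvement for that many rounds
    mae = np.mean(np.abs(y_va - pred_va))
    return {"rounds_run": rounds_run, "best_round": best_round,
            "best_rmse": best_rmse, "final_rmse": rmse, "final_mae": mae}


# Synthetic data (made up): a noisy sine curve, 150 training and 50 validation points.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, 200)
y = np.sin(x) + rng.normal(0, 0.3, 200)
x_tr, y_tr, x_va, y_va = x[:150], y[:150], x[150:], y[150:]
baseline_rmse = np.sqrt(np.mean((y_va - y_tr.mean()) ** 2))

results = {}
for lr in (0.05, 0.3):
    results[lr] = boost(x_tr, y_tr, x_va, y_va, learning_rate=lr,
                        n_estimators=100, early_stopping_rounds=5)
    print("learning_rate =", lr, results[lr])
```

The stopping rule mirrors what the first log above shows: with early_stopping_rounds=5 the best validation RMSE is at round 30 and the log ends at round 35. MAE averages absolute errors while RMSE squares them first, so RMSE is never smaller than MAE and is pulled up more by a few very large misses.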
That covers installing XGBoost and a simple use of it. You can explore more of its API according to your own model training needs.