Extension of the awesome XGBoost to linear models at the leaves
XGBoost is often presented as the algorithm that wins every ML competition. Surprisingly, this holds even though its predictions are piecewise constant. That may be justified in high-dimensional input spaces, but when the number of features is low, a piecewise linear model is likely to perform better. LinXGBoost extends XGBoost by storing a linear model at each leaf; the resulting predictor is equivalent to piecewise regularized least-squares.
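The per-leaf fit can be illustrated with a small sketch. This is an assumption about the formulation described above (ridge-regularized least-squares on the samples falling into one leaf), not the repository's actual solver:

```python
import numpy as np

def fit_leaf(X, y, lam=1.0):
    """Ridge-regularized least-squares fit for one leaf.

    Solves (X^T X + lam * I) beta = X^T y, with X augmented by a
    bias column. Illustrative sketch only, not LinXGBoost's code.
    """
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])  # add bias column
    A = Xb.T @ Xb + lam * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ y)

def predict_leaf(X, beta):
    """Linear prediction with the leaf's coefficients."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    return Xb @ beta

# Toy leaf: samples lying near y = 2x + 1
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(50, 1))
y = 2.0 * X[:, 0] + 1.0 + 0.01 * rng.standard_normal(50)
beta = fit_leaf(X, y, lam=1e-3)
print(beta)  # approximately [2.0, 1.0]
```

With a constant leaf, the same samples would all receive the leaf mean; the linear leaf recovers the local slope as well, which is where the gain on low-dimensional smooth targets comes from.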
You will need python 3 with numpy. If you intend to run the tests, then sklearn, matplotlib, XGBoost and pandas must be installed.
The class `linxgb` is, unsurprisingly, defined in `linxgb.py`. It is the counterpart of `xgb.XGBRegressor`: XGBoost for regression with a scikit-learn-like API (see the Scikit-Learn API). As such, it implements two methods, `fit` and `predict`, so you can use scikit-learn utilities such as cross-validation with it.

The definition of a tree is in `node.py`. Normally, you should not have to instantiate a tree directly.
Suppose `train_X` (a numpy array) and `train_Y` (a numpy vector) hold the training data: inputs and labels, respectively. The following fits a LinXGBoost model with 3 estimators (trees):
```python
reg = linxgb(n_estimators=3)
reg.fit(train_X, train_Y)
```
Prediction is just as simple:

```python
pred_Y = reg.predict(test_X)
```
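Because `linxgb` follows the scikit-learn estimator API, it can be passed directly to utilities such as `sklearn.model_selection.cross_val_score`. The sketch below uses `GradientBoostingRegressor` as a stand-in so it runs without this repository installed; swapping in `linxgb(n_estimators=3)` is the intended usage:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_score

# Synthetic piecewise-linear data (hypothetical example)
rng = np.random.default_rng(0)
train_X = rng.uniform(-1, 1, size=(200, 2))
train_Y = np.where(train_X[:, 0] > 0,
                   2 * train_X[:, 0] + train_X[:, 1],
                   -train_X[:, 1])

# Stand-in for linxgb(n_estimators=3); any sklearn-compatible
# regressor, including linxgb, works the same way here.
reg = GradientBoostingRegressor(n_estimators=3)
scores = cross_val_score(reg, train_X, train_Y, cv=5,
                         scoring="neg_mean_squared_error")
print(scores.mean())
```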
The most significant parameters follow XGBoost's parameter definitions. They are:
Additionally, we have:
Several tests can be run:

- `test_heavysine.py`: a one-dimensional problem (see "Adapting to Unknown Smoothness via Wavelet Shrinkage" by Donoho and Johnstone).
- `test_jakeman1_av.py`: the f1 response surface from "Local and Dimension Adaptive Sparse Grid Interpolation and Quadrature".
- `test_jakeman4_av.py`: the f4 response surface from "Local and Dimension Adaptive Sparse Grid Interpolation and Quadrature", with w1=0.5 and w2=3.
- `test_friedman1_av.py`: the synthetic Friedman 1 dataset, previously used to evaluate MARS ("Multivariate Adaptive Regression Splines" by Friedman) and bagging (Breiman, 1996). It is particularly suited to examining a method's ability to uncover interaction effects present in the data.
- `test_ccpp_av.py`: a real-world dataset of 9568 data points collected from a Combined Cycle Power Plant over 6 years (2006-2011), while the plant was operating at full load. Features are hourly average ambient variables: temperature (T), ambient pressure (AP), relative humidity (RH) and exhaust vacuum (V); the target is the net hourly electrical energy output (EP) of the plant.
- `test_puma8_av.py`: a problem generated with a robot-arm simulation. The dataset is highly non-linear with very low noise; it contains 8192 samples with 8 attributes.

Laurent de Vito
All third-party libraries are subject to their own license.
This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.