-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from simon-hirsch/seperate_estimator_method
Seperate Estimator Class and EstimationMethod Class
- Loading branch information
Showing
34 changed files
with
2,087 additions
and
1,579 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,14 @@ | ||
# Estimators | ||
|
||
Estimator classes provide an `sklearn`-like API to fit, predict and update models with the accordingly named methods. | ||
|
||
## Online GAMLSS | ||
|
||
::: rolch.OnlineGamlss | ||
|
||
::: rolch.OnlineLasso | ||
## Linear Models | ||
|
||
::: rolch.OnlineLinearModel | ||
|
||
::: rolch.OnlineLasso | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# The `Estimator()` and `EstimationMethod()` classes | ||
|
||
## Overview | ||
|
||
Our package separates `Estimator` classes and `EstimationMethod` classes in the design. An `Estimator` is a python object that provides the user interface to set-up, fit, update and predict models. `EstimationMethod` classes are concerned with the estimation of the model coefficients (or weights). This page briefly explains the separation and options provided by it using the `OnlineLinearModel()` class. | ||
|
||
Estimators are your bread and butter partner for modelling. They provide the methods: | ||
|
||
- `Estimator().fit(X, y)` | ||
- `Estimator().update(X, y)` | ||
- `Estimator().predict(X)` | ||
|
||
which one commonly uses for modelling. | ||
|
||
Each estimator is initialized by choosing an estimation method passed to the `method` parameter, if the method is not explicitly called in the name of the estimator (like in the `OnlineLasso()`). The `method` accepts either a `string`, or an `EstimationMethod()` instance. | ||
|
||
## Example | ||
|
||
Let's return to the aforementioned example: We want to fit a simple linear model. We can estimate the parameters either using ordinary least squares (OLS) or using coordinate descent, minimizing the LASSO penalised loss. | ||
|
||
### Ordinary Least Squares | ||
|
||
First, we start with OLS: | ||
|
||
```python | ||
# Set up packages and | ||
from rolch.estimators.online_linear_model import OnlineLinearModel | ||
from rolch.methods import LassoPathMethod, OrdinaryLeastSquaresMethod | ||
from sklearn.datasets import load_diabetes | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
# Get data | ||
X, y = load_diabetes(return_X_y=True) | ||
|
||
fit_intercept = False | ||
scale_inputs = True | ||
|
||
# This is the Estimator Class | ||
model = OnlineLinearModel( | ||
method="ols", | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
model.fit(X[:-10, :], y[:-10]) | ||
model.update(X[-10:, :], y[-10:]) | ||
|
||
# This is equivalent | ||
model = OnlineLinearModel( | ||
method=OrdinaryLeastSquaresMethod(), | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
model.fit(X[:-10, :], y[:-10]) | ||
model.update(X[-10:, :], y[-10:]) | ||
``` | ||
|
||
Since ordinary least squares is a pretty simple method, it does not have a lot of parameters. However, if we look at LASSO, things change, because now we can actually play with the parameters. | ||
|
||
### LASSO and the `LassoPathMethod()` | ||
|
||
The `LassoPathMethod()` estimates the coefficients using coordinate descent along a path of decreasing regularization strength. In this example, we will change some of the parameters of the estimation. | ||
|
||
The `LassoPathMethod()` has for example the following parameters | ||
|
||
- `lambda_n` which defines the length of the regularization path. | ||
- `beta_lower_bounds` which provides the option to place a lower bound on the coefficients/weights. | ||
|
||
Let's have a look at a basic LASSO-estimated model: | ||
|
||
```python | ||
|
||
model = OnlineLinearModel( | ||
method="lasso", | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
model.fit(X[:-10, :], y[:-10]) | ||
plt.plot(model.beta_path) | ||
plt.show() | ||
print(model.beta) | ||
|
||
# Equivalent, we can do: | ||
|
||
model = OnlineLinearModel( | ||
method=LassoPathMethod(), | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
model.fit(X[:-10, :], y[:-10]) | ||
plt.plot(model.beta_path) | ||
plt.show() | ||
print(model.beta) | ||
|
||
``` | ||
|
||
Now we want to change the parameters: | ||
|
||
```python | ||
|
||
estimation_method = LassoPathMethod( | ||
lambda_n=10, # Only fit ten lambdas | ||
beta_lower_bound=np.zeros( | ||
X.shape[1] + fit_intercept | ||
), # all positive parameters | ||
) | ||
|
||
model = OnlineLinearModel( | ||
method=estimation_method, | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
model.fit(X[:-10, :], y[:-10]) | ||
plt.plot(model.beta_path) | ||
plt.show() | ||
print(model.beta) | ||
``` | ||
|
||
And we see that the coefficient path is both shorter and non-negative. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Estimation Methods | ||
|
||
## Overview | ||
|
||
`EstimationMethod()` classes do the actual hard lifting of fitting coefficients (or weights). They take more technical parameters like the length of the regularization path or upper bounds on certain coefficients. These parameters depend on the individual estimation method. In general, we aim to provide sensible out-of-the-box defaults. This [page](estimators_and_methods.md) explains the difference in detail. `Estimator` classes often take a method parameter, to which either a string or an instance of the `EstimationMethod()` can be passed, e.g. | ||
|
||
```python | ||
from rolch import OnlineLinearModel, LassoPathMethod | ||
|
||
fit_intercept = True | ||
scale_inputs = True | ||
|
||
model = OnlineLinearModel( | ||
method="lasso", # default parameters | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
# or equivalent | ||
model = OnlineLinearModel( | ||
method=LassoPathMethod(), # default parameters | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
# or with user-defined parameters | ||
model = OnlineLinearModel( | ||
method=LassoPathMethod( | ||
lambda_n=10 | ||
), # only 10 different regularization strengths | ||
fit_intercept=fit_intercept, | ||
scale_inputs=scale_inputs, | ||
) | ||
``` | ||
|
||
More information on coordinate descent can also be found on this [page](coordinate_descent.md) and in the API Reference below. | ||
|
||
## API Reference | ||
|
||
!!! note | ||
We don't document the classmethods of the `EstimationMethod` since these are only used internally. | ||
|
||
|
||
::: rolch.OrdinaryLeastSquaresMethod | ||
|
||
::: rolch.LassoPathMethod |
Oops, something went wrong.