Skip to content

Commit

Permalink
Merge pull request #38 from simon-hirsch/seperate_estimator_method
Browse files Browse the repository at this point in the history
Seperate Estimator Class and EstimationMethod Class
  • Loading branch information
simon-hirsch authored Jan 8, 2025
2 parents 443a249 + c9fdc8b commit ac689bf
Show file tree
Hide file tree
Showing 34 changed files with 2,087 additions and 1,579 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ online_gamlss_lasso = rolch.OnlineGamlss(
method="lasso",
equation=equation,
fit_intercept=True,
estimation_kwargs={"ic": {i: "bic" for i in range(dist.n_params)}},
ic="bic",
)

# Initial Fit
Expand All @@ -58,15 +58,15 @@ online_gamlss_lasso.fit(
y=y[:-11],
)
print("Coefficients for the first N-11 observations \n")
print(online_gamlss_lasso.betas)
print(online_gamlss_lasso.beta)

# Update call
online_gamlss_lasso.update(
X=X[[-11], :],
y=y[[-11]]
)
print("\nCoefficients after update call \n")
print(online_gamlss_lasso.betas)
print(online_gamlss_lasso.beta)

# Prediction for the last 10 observations
prediction = online_gamlss_lasso.predict(
Expand Down
2 changes: 1 addition & 1 deletion docs/distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ All distributions are based on `scipy.stats` distributions. We implement the pro

## Base Class

::: rolch.abc.Distribution
::: rolch.base.Distribution

## API Reference

Expand Down
9 changes: 8 additions & 1 deletion docs/estimators.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Estimators

Estimator classes provide an `sklearn`-like API to fit, predict and update models with the accordingly named methods.

## Online GAMLSS

::: rolch.OnlineGamlss

::: rolch.OnlineLasso
## Linear Models

::: rolch.OnlineLinearModel

::: rolch.OnlineLasso

120 changes: 120 additions & 0 deletions docs/estimators_and_methods.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# The `Estimator()` and `EstimationMethod()` classes

## Overview

Our package separates `Estimator` classes and `EstimationMethod` classes in the design. An `Estimator` is a python object that provides the user interface to set-up, fit, update and predict models. `EstimationMethod` classes are concerned with the estimation of the model coefficients (or weights). This page briefly explains the separation and options provided by it using the `OnlineLinearModel()` class.

Estimators are your bread and butter partner for modelling. They provide the methods:

- `Estimator().fit(X, y)`
- `Estimator().update(X, y)`
- `Estimator().predict(X)`

which one commonly uses for modelling.

Each estimator is initialized by choosing an estimation method passed to the `method` parameter, if the method is not explicitly called in the name of the estimator (like in the `OnlineLasso()`). The `method` accepts either a `string`, or an `EstimationMethod()` instance.

## Example

Let's return to the aforementioned example: We want to fit a simple linear model. We can estimate the parameters either using ordinary least squares (OLS) or using coordinate descent, minimizing the LASSO penalised loss.

### Ordinary Least Squares

First, we start with OLS:

```python
# Set up packages and
from rolch.estimators.online_linear_model import OnlineLinearModel
from rolch.methods import LassoPathMethod, OrdinaryLeastSquaresMethod
from sklearn.datasets import load_diabetes

import matplotlib.pyplot as plt
import numpy as np

# Get data
X, y = load_diabetes(return_X_y=True)

fit_intercept = False
scale_inputs = True

# This is the Estimator Class
model = OnlineLinearModel(
method="ols",
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
model.fit(X[:-10, :], y[:-10])
model.update(X[-10:, :], y[-10:])

# This is equivalent
model = OnlineLinearModel(
method=OrdinaryLeastSquaresMethod(),
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
model.fit(X[:-10, :], y[:-10])
model.update(X[-10:, :], y[-10:])
```

Since ordinary least squares is a pretty simple method, it does not have a lot of parameters. However, if we look at LASSO, things change, because now we can actually play with the parameters.

### LASSO and the `LassoPathMethod()`

The `LassoPathMethod()` estimates the coefficients using coordinate descent along a path of decreasing regularization strength. In this example, we will change some of the parameters of the estimation.

The `LassoPathMethod()` has for example the following parameters

- `lambda_n` which defines the length of the regularization path.
- `beta_lower_bounds` which provides the option to place a lower bound on the coefficients/weights.

Let's have a look at a basic LASSO-estimated model:

```python

model = OnlineLinearModel(
method="lasso",
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
model.fit(X[:-10, :], y[:-10])
plt.plot(model.beta_path)
plt.show()
print(model.beta)

# Equivalent, we can do:

model = OnlineLinearModel(
method=LassoPathMethod(),
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
model.fit(X[:-10, :], y[:-10])
plt.plot(model.beta_path)
plt.show()
print(model.beta)

```

Now we want to change the parameters:

```python

estimation_method = LassoPathMethod(
lambda_n=10, # Only fit ten lambdas
beta_lower_bound=np.zeros(
X.shape[1] + fit_intercept
), # all positive parameters
)

model = OnlineLinearModel(
method=estimation_method,
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
model.fit(X[:-10, :], y[:-10])
plt.plot(model.beta_path)
plt.show()
print(model.beta)
```

And we see that the coefficient path is both shorter and non-negative.
7 changes: 2 additions & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ where $g_k(\cdot)$ is a link function, which ensures that the predicted distribu

This allows us to specify very flexible models that consider the conditional behaviour of the variable's volatility, skewness and tail behaviour. A simple example for electricity markets is wind forecasts, which are skewed depending on the production level - intuitively, there is a higher risk of having lower production if the production level is already high since it cannot go much higher than "full load" and if, the turbines might cut-off. Modelling these conditional probabilistic behaviours is the key strength of distributional regression models.



## Installation

`ROLCH` is available on the [Python Package Index](https://pypi.org/project/rolch/) and can be installed via `pip`:
Expand All @@ -35,8 +33,7 @@ pip install rolch

## Example

The following few lines give an introduction. We use the `diabetes` data set and model the response variable \(Y\) as Student-\(t\) distributed, where all distribution parameters (location, scale and tail) are modelled conditional on the explanatory variables in \(X\).

The following few lines give an introduction. We use the `diabetes` data set and model the response variable \(Y\) as Student-\(t\) distributed, where all distribution parameters (location, scale and tail) are modelled conditional on the explanatory variables in \(X\). We use LASSO to estimate the coefficients and the Bayesian information criterion to select the best model along a grid of regularization strengths.

```python
import rolch
Expand All @@ -58,7 +55,7 @@ online_gamlss_lasso = rolch.OnlineGamlss(
method="lasso",
equation=equation,
fit_intercept=True,
estimation_kwargs={"ic": {i: "bic" for i in range(dist.n_params)}},
ic="bic",
)

# Initial Fit
Expand Down
2 changes: 1 addition & 1 deletion docs/links.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Some link functions implement _shifted_ versions. The shifted link function is i

## Base Class

::: rolch.abc.LinkFunction
::: rolch.base.LinkFunction

## API Reference

Expand Down
44 changes: 44 additions & 0 deletions docs/methods.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Estimation Methods

## Overview

`EstimationMethod()` classes do the actual hard lifting of fitting coefficients (or weights). They take more technical parameters like the length of the regularization path or upper bounds on certain coefficients. These parameters depend on the individual estimation method. In general, we aim to provide sensible out-of-the-box defaults. This [page](estimators_and_methods.md) explains the difference in detail. `Estimator` classes often take a method parameter, to which either a string or an instance of the `EstimationMethod()` can be passed, e.g.

```python
from rolch import OnlineLinearModel, LassoPathMethod

fit_intercept = True
scale_inputs = True

model = OnlineLinearModel(
method="lasso", # default parameters
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
# or equivalent
model = OnlineLinearModel(
method=LassoPathMethod(), # default parameters
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
# or with user-defined parameters
model = OnlineLinearModel(
method=LassoPathMethod(
lambda_n=10
), # only 10 different regularization strengths
fit_intercept=fit_intercept,
scale_inputs=scale_inputs,
)
```

More information on coordinate descent can also be found on this [page](coordinate_descent.md) and in the API Reference below.

## API Reference

!!! note
We don't document the classmethods of the `EstimationMethod` since these are only used internally.


::: rolch.OrdinaryLeastSquaresMethod

::: rolch.LassoPathMethod
Loading

0 comments on commit ac689bf

Please sign in to comment.