Fast anchor regularized local linear model
This work has been done at RIKEN AIP by Mathis Petrovich and Makoto Yamada.
git clone https://github.com/Mathux/Fall.git
cd Fall
This code needs Python 3 and the following packages:
- click
- numpy
- pandas
- optuna
- scikit-learn (imported as sklearn)
Optional packages:
- matplotlib (for the figures)
- xlrd (for loading Excel format datasets)
All the packages can be installed with:
pip install --user -r requirements.txt
The implementation of our method can be found in src/models/fall.py. Please import Fall from this file and use it as a sklearn model (model.fit and model.predict):
# Import the model
from src.models.fall import Fall
# Import some dataset
from src.data.concrete import Concrete
# Create the model
model = Fall(k=40, K_anchors=20, K_prediction=5, lam=10)
# Load the data
data = Concrete().datasets()
# Fit the model
model.fit(data["train"]["X"], data["train"]["Y"])
# Predict with the test data
y_pred = model.predict(data["test"]["X"])
# Compute the mean squared error
from sklearn.metrics import mean_squared_error
error = mean_squared_error(data["test"]["Y"], y_pred)
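For intuition about what the model computes: the paper's abstract describes training a local model for each sample, regularized toward precomputed anchor models, with a closed-form solution. The sketch below is a minimal illustration of that idea and is not the code in src/models/fall.py. It assumes a per-sample objective of the form (y - x·w)^2 + lam * sum_j ||w - a_j||^2, which may differ from the paper's exact formulation. The function name `local_model` and its arguments are hypothetical.

```python
import numpy as np


def local_model(x, y, anchors, lam):
    """Closed-form minimizer of (y - x @ w)**2 + lam * sum_j ||w - a_j||**2.

    Hypothetical helper for illustration only; not the repository's API.
    x: (d,) features, y: scalar target,
    anchors: (K, d) anchor model weights, lam > 0.
    """
    x = np.asarray(x, dtype=float)
    anchors = np.asarray(anchors, dtype=float)
    c = lam * anchors.shape[0]               # coefficient of the identity term
    b = y * x + lam * anchors.sum(axis=0)    # right-hand side of the normal equations
    # Sherman-Morrison inverse of (c * I + x x^T): only products, no linear solve.
    return (b - x * (x @ b) / (c + x @ x)) / c


rng = np.random.default_rng(0)
x = rng.normal(size=4)
anchors = rng.normal(size=(3, 4))
w = local_model(x, y=1.5, anchors=anchors, lam=10.0)

# Cross-check against a direct solve of the same normal equations.
A = np.outer(x, x) + 10.0 * 3 * np.eye(4)
w_direct = np.linalg.solve(A, 1.5 * x + 10.0 * anchors.sum(axis=0))
```

Under this assumed objective, a large `lam` pulls the local weights toward the mean of the anchor weights, while a small `lam` lets the single sample dominate.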
Usage: python main.py [OPTIONS]
Options:
--dataset [fishtoxicity|aquatictoxicity|concrete|superconduct|parkinson_updrs_motor|parkinson_updrs_total]
--model [all|fall|netlasso|ridge|lasso|knn|krr]
--cv INTEGER Number of splits for the cross-validation.
--nepochs INTEGER Number of times to compute everything.
--timeout INTEGER Timeout for computing the best hyper parameters.
--output TEXT Name of the latex output file to store the results.
--verbose / --no-verbose Verbosity for optuna.
--help Show this message and exit.
For example, to test our method on the fishtoxicity dataset for 1 epoch, using 3-fold cross-validation and a 3-second timeout for the hyperparameter search:
python main.py --dataset fishtoxicity --model fall --cv 3 --nepochs 1 --timeout 3
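As a rough illustration of what a `--cv 3` evaluation measures, the sketch below splits the data into folds, fits on all but one fold, and averages the test mean squared error. It is an assumption-laden stand-in: `kfold_mse`, `fit`, and `predict` are hypothetical names, ordinary least squares replaces the actual models, and main.py may organize its cross-validation and hyperparameter search differently.

```python
import numpy as np


def kfold_mse(X, y, n_splits, fit, predict, seed=0):
    """Mean test MSE over n_splits shuffled folds (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_splits)
    scores = []
    for i, test_idx in enumerate(folds):
        # Train on every fold except the i-th, which is held out for testing.
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        model = fit(X[train_idx], y[train_idx])
        residual = y[test_idx] - predict(model, X[test_idx])
        scores.append(float(np.mean(residual ** 2)))
    return float(np.mean(scores))


def fit(X, y):
    # Stand-in model: ordinary least squares weights.
    return np.linalg.lstsq(X, y, rcond=None)[0]


def predict(w, X):
    return X @ w


rng = np.random.default_rng(1)
X = rng.normal(size=(30, 3))
y = X @ np.array([1.0, -2.0, 0.5])  # noiseless linear target
score = kfold_mse(X, y, n_splits=3, fit=fit, predict=predict)
```

Any object exposing the same fit and predict behaviour could be substituted for the stand-in model here.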
You can recreate the figures of the paper by running figures.py:
Usage: python figures.py [OPTIONS]
Options:
--folder TEXT Folder to store the figures.
--show / --no-show Show the plots during the generation.
--help Show this message and exit.
If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper:
@techreport{PetM:2020,
author = "Petrovich, Mathis and Yamada, Makoto",
title = "Fast local linear regression with anchor regularization",
year = "2020",
month = "2",
url = "http://hdl.handle.net/2433/245860",
abstract = "Regression is an important task in machine learning and data mining. It has several applications in various domains, including finance, biomedical, and computer vision. Recently, network Lasso, which estimates local models by making clusters using the network information, was proposed and its superior performance was demonstrated. In this study, we propose a simple yet effective local model training algorithm called the fast anchor regularized local linear method (FALL). More specifically, we train a local model for each sample by regularizing it with precomputed anchor models. The key advantage of the proposed algorithm is that we can obtain a closed-form solution with only matrix multiplication; additionally, the proposed algorithm is easily interpretable, fast to compute and parallelizable. Through experiments on synthetic and real-world datasets, we demonstrate that FALL compares favorably in terms of accuracy with the state-of-the-art network Lasso algorithm with significantly smaller training time (two orders of magnitude)."
}