ENH Enables array_api for LinearDiscriminantAnalysis #102
Changes from 3 commits
5e0e6ea
5c9b7c2
222fc69
a730f92
8b57acc
9406008
b94f36a
1e0fdc7
02d4d44
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,15 +11,18 @@ | |
|
||
import warnings | ||
import numpy as np | ||
import scipy.linalg | ||
from scipy import linalg | ||
from scipy.special import expit | ||
import math | ||
|
||
from .base import BaseEstimator, TransformerMixin, ClassifierMixin | ||
from .base import _ClassNamePrefixFeaturesOutMixin | ||
from .linear_model._base import LinearClassifierMixin | ||
from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance | ||
from .utils.multiclass import unique_labels | ||
from .utils.validation import check_is_fitted | ||
from .utils._array_api import get_namespace | ||
from .utils.multiclass import check_classification_targets | ||
from .utils.extmath import softmax | ||
from .preprocessing import StandardScaler | ||
|
@@ -110,11 +113,17 @@ def _class_means(X, y): | |
means : array-like of shape (n_classes, n_features) | ||
Class means. | ||
""" | ||
np, is_array_api = get_namespace(X) | ||
classes, y = np.unique(y, return_inverse=True) | ||
cnt = np.bincount(y) | ||
means = np.zeros(shape=(len(classes), X.shape[1])) | ||
np.add.at(means, y, X) | ||
means /= cnt[:, None] | ||
means = np.zeros(shape=(classes.shape[0], X.shape[1])) | ||
|
||
if is_array_api: | ||
for i in range(classes.shape[0]): | ||
means[i, :] = np.mean(X[y == i], axis=0) | ||
Should this be +=?
Using mean here avoids needing to divide by the count. (In the end, I think it's the same computation. It's basically a groupby + mean aggregation.)
Oh I see, you moved the aggregation from I think you could get rid of the loop with something like
except Also, I think
Thanks for the information! Some functionality would be hard to reproduce without boolean indexing. Looking into "Data-dependent output shape" more, I see that (I wanted to use
Oh yeah of course unique wouldn't work in such APIs either. So there's no point in worrying about it here. |
||
else: | ||
cnt = np.bincount(y) | ||
np.add.at(means, y, X) | ||
means /= cnt[:, None] | ||
Since |
||
return means | ||
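The two branches of `_class_means` discussed in the comments above should produce the same result: a per-class mean over boolean masks versus a scatter-add followed by division by the counts. Below is a minimal sketch in plain NumPy, assuming small toy inputs; the function names `class_means_loop` and `class_means_scatter` are hypothetical and are not part of this PR.

```python
import numpy as np


def class_means_loop(X, y):
    """Per-class means via boolean masks (mirrors the array-API branch)."""
    classes, y_idx = np.unique(y, return_inverse=True)
    means = np.zeros((classes.shape[0], X.shape[1]), dtype=X.dtype)
    for i in range(classes.shape[0]):
        # Select the rows of class i and average them directly.
        means[i, :] = np.mean(X[y_idx == i], axis=0)
    return means


def class_means_scatter(X, y):
    """Per-class means via scatter-add, then divide (mirrors the NumPy branch)."""
    classes, y_idx = np.unique(y, return_inverse=True)
    cnt = np.bincount(y_idx)
    means = np.zeros((classes.shape[0], X.shape[1]), dtype=X.dtype)
    # Unbuffered in-place add: rows with the same label accumulate.
    np.add.at(means, y_idx, X)
    means /= cnt[:, None]
    return means


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y = np.array([0, 1, 0, 1])
means_loop = class_means_loop(X, y)
means_scatter = class_means_scatter(X, y)
```

Both variants are a group-by followed by a mean; the loop trades one Python iteration per class for not needing `np.add.at` or `bincount`.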
|
||
|
||
|
@@ -464,50 +473,53 @@ def _solve_svd(self, X, y): | |
y : array-like of shape (n_samples,) or (n_samples, n_targets) | ||
Target values. | ||
""" | ||
np, is_array_api = get_namespace(X) | ||
|
||
if is_array_api: | ||
svd = np.linalg.svd | ||
else: | ||
svd = scipy.linalg.svd | ||
This one is a bit interesting. Since I already have a wrapper for numpy: |
||
|
||
n_samples, n_features = X.shape | ||
n_classes = len(self.classes_) | ||
n_classes = self.classes_.shape[0] | ||
|
||
self.means_ = _class_means(X, y) | ||
if self.store_covariance: | ||
self.covariance_ = _class_cov(X, y, self.priors_) | ||
|
||
Xc = [] | ||
for idx, group in enumerate(self.classes_): | ||
Xg = X[y == group, :] | ||
Xc.append(Xg - self.means_[idx]) | ||
Xg = X[y == group] | ||
Xc.append(Xg - self.means_[idx, :]) | ||
|
||
self.xbar_ = np.dot(self.priors_, self.means_) | ||
self.xbar_ = self.priors_ @ self.means_ | ||
|
||
Xc = np.concatenate(Xc, axis=0) | ||
|
||
# 1) within (univariate) scaling by within-class std-dev | ||
std = Xc.std(axis=0) | ||
std = np.std(Xc, axis=0) | ||
# avoid division by zero in normalization | ||
std[std == 0] = 1.0 | ||
fac = 1.0 / (n_samples - n_classes) | ||
|
||
# 2) Within variance scaling | ||
X = np.sqrt(fac) * (Xc / std) | ||
X = math.sqrt(fac) * (Xc / std) | ||
# SVD of centered (within)scaled data | ||
U, S, Vt = linalg.svd(X, full_matrices=False) | ||
U, S, Vt = svd(X, full_matrices=False) | ||
|
||
rank = np.sum(S > self.tol) | ||
rank = np.sum(np.astype(S > self.tol, np.int32)) | ||
# Scaling of within covariance is: V' 1/S | ||
scalings = (Vt[:rank] / std).T / S[:rank] | ||
scalings = (Vt[:rank, :] / std).T / S[:rank] | ||
|
||
# 3) Between variance scaling | ||
# Scale weighted centers | ||
X = np.dot( | ||
( | ||
(np.sqrt((n_samples * self.priors_) * fac)) | ||
* (self.means_ - self.xbar_).T | ||
).T, | ||
scalings, | ||
) | ||
X = ( | ||
(np.sqrt((n_samples * self.priors_) * fac)) * (self.means_ - self.xbar_).T | ||
).T @ scalings | ||
# Centers are living in a space with n_classes-1 dim (maximum) | ||
# Use SVD to find projection in the space spanned by the | ||
# (n_classes) centers | ||
_, S, Vt = linalg.svd(X, full_matrices=0) | ||
_, S, Vt = svd(X, full_matrices=False) | ||
|
||
if self._max_components == 0: | ||
self.explained_variance_ratio_ = np.empty((0,), dtype=S.dtype) | ||
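On the `rank` change in this hunk: as far as I understand the array API standard, `sum` is only specified for numeric dtypes, which would explain casting the boolean mask to an integer dtype before summing. A small sketch in plain NumPy, where both forms are accepted and agree; the values of `S` and `tol` are made up for illustration.

```python
import numpy as np

# Made-up singular values and tolerance, for illustration only.
S = np.array([3.0, 1.5, 1e-6, 0.0])
tol = 1e-4

# NumPy accepts summing a boolean mask directly.
rank_bool_sum = int(np.sum(S > tol))

# Casting the mask first keeps the reduction on a numeric dtype.
rank_cast_sum = int(np.sum((S > tol).astype(np.int32)))
```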
|
@@ -516,12 +528,12 @@ def _solve_svd(self, X, y): | |
: self._max_components | ||
] | ||
|
||
rank = np.sum(S > self.tol * S[0]) | ||
self.scalings_ = np.dot(scalings, Vt.T[:, :rank]) | ||
coef = np.dot(self.means_ - self.xbar_, self.scalings_) | ||
rank = np.sum(np.astype(S > self.tol * S[0], np.int32)) | ||
self.scalings_ = scalings @ Vt.T[:, :rank] | ||
coef = (self.means_ - self.xbar_) @ self.scalings_ | ||
self.intercept_ = -0.5 * np.sum(coef ** 2, axis=1) + np.log(self.priors_) | ||
self.coef_ = np.dot(coef, self.scalings_.T) | ||
self.intercept_ -= np.dot(self.xbar_, self.coef_.T) | ||
self.coef_ = coef @ self.scalings_.T | ||
self.intercept_ -= self.xbar_ @ self.coef_.T | ||
|
||
def fit(self, X, y): | ||
"""Fit the Linear Discriminant Analysis model. | ||
|
@@ -545,12 +557,13 @@ def fit(self, X, y): | |
self : object | ||
Fitted estimator. | ||
""" | ||
np, _ = get_namespace(X) | ||
X, y = self._validate_data( | ||
X, y, ensure_min_samples=2, dtype=[np.float64, np.float32] | ||
) | ||
self.classes_ = unique_labels(y) | ||
n_samples, _ = X.shape | ||
n_classes = len(self.classes_) | ||
n_classes = self.classes_.shape[0] | ||
|
||
if n_samples == n_classes: | ||
raise ValueError( | ||
|
@@ -559,19 +572,21 @@ def fit(self, X, y): | |
|
||
if self.priors is None: # estimate priors from sample | ||
_, y_t = np.unique(y, return_inverse=True) # non-negative ints | ||
self.priors_ = np.bincount(y_t) / float(len(y)) | ||
self.priors_ = np.astype(np.bincount(y_t), np.float64) / float(y.shape[0]) | ||
else: | ||
self.priors_ = np.asarray(self.priors) | ||
|
||
if (self.priors_ < 0).any(): | ||
if np.any(self.priors_ < 0): | ||
raise ValueError("priors must be non-negative") | ||
if not np.isclose(self.priors_.sum(), 1.0): | ||
warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning) | ||
self.priors_ = self.priors_ / self.priors_.sum() | ||
|
||
# TODO: implement isclose in wrapper? | ||
Need to implement our own |
||
# if not np.isclose(np.sum(self.priors_), 1.0): | ||
# warnings.warn("The priors do not sum to 1. Renormalizing", UserWarning) | ||
# self.priors_ = self.priors_ / self.priors_.sum() | ||
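For the `isclose` TODO above, one possible shape for a namespace-agnostic helper is sketched below. It uses only `asarray`, `abs`, arithmetic and comparison, and follows the tolerance formula NumPy documents for `numpy.isclose`. The helper name and its `xp` parameter are assumptions for illustration, and NaN and infinity handling is deliberately left out.

```python
import numpy as np


def isclose(a, b, xp, rtol=1e-5, atol=1e-8):
    """Hypothetical elementwise closeness check: |a - b| <= atol + rtol * |b|.

    `xp` is the array namespace to use. NaN and inf are not special-cased.
    """
    a = xp.asarray(a)
    b = xp.asarray(b)
    return xp.abs(a - b) <= atol + rtol * xp.abs(b)


# Priors that do not sum to 1, so renormalization is needed.
priors = np.asarray([0.2, 0.3, 0.4])
needs_renormalizing = not bool(isclose(np.sum(priors), 1.0, xp=np))
if needs_renormalizing:
    priors = priors / np.sum(priors)
```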
|
||
# Maximum number of components no matter what n_components is | ||
# specified: | ||
max_components = min(len(self.classes_) - 1, X.shape[1]) | ||
max_components = min(n_classes - 1, X.shape[1]) | ||
|
||
if self.n_components is None: | ||
self._max_components = max_components | ||
|
@@ -612,12 +627,12 @@ def fit(self, X, y): | |
"'lsqr', and 'eigen').".format(self.solver) | ||
) | ||
if self.classes_.size == 2: # treat binary case as a special case | ||
self.coef_ = np.array( | ||
self.coef_[1, :] - self.coef_[0, :], ndmin=2, dtype=X.dtype | ||
) | ||
self.intercept_ = np.array( | ||
self.intercept_[1] - self.intercept_[0], ndmin=1, dtype=X.dtype | ||
coef_ = np.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype) | ||
self.coef_ = np.reshape(coef_, (1, -1)) | ||
intercept_ = np.asarray( | ||
self.intercept_[1] - self.intercept_[0], dtype=X.dtype | ||
) | ||
self.intercept_ = np.reshape(intercept_, 1) | ||
self._n_features_out = self._max_components | ||
return self | ||
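The binary-case change above swaps `np.array(..., ndmin=...)` for `asarray` followed by `reshape`. A quick check in plain NumPy that the two forms give the same shapes and values; the coefficient and intercept values are made up, and the target shape is written as the tuple `(1,)` here rather than the bare `1` used in the diff.

```python
import numpy as np

# Made-up two-class coefficients and intercepts, for illustration only.
coef = np.array([[0.5, -1.0, 2.0], [1.5, 0.0, 1.0]])
intercept = np.array([0.25, -0.75])

# Old form: rely on the NumPy-specific `ndmin` argument.
old_coef = np.array(coef[1, :] - coef[0, :], ndmin=2, dtype=coef.dtype)
old_intercept = np.array(intercept[1] - intercept[0], ndmin=1, dtype=coef.dtype)

# New form: asarray, then an explicit reshape.
new_coef = np.reshape(
    np.asarray(coef[1, :] - coef[0, :], dtype=coef.dtype), (1, -1)
)
new_intercept = np.reshape(
    np.asarray(intercept[1] - intercept[0], dtype=coef.dtype), (1,)
)
```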
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""Tools to support array_api.""" | ||
import numpy | ||
from .._config import get_config | ||
|
||
from contextlib import nullcontext | ||
|
||
|
||
# There are more clever ways to wrap the API to ignore kwargs, but I am writing them out | ||
# explicitly for demonstration purposes | ||
class _ArrayAPIWrapper: | ||
def __init__(self, array_namespace): | ||
self._array_namespace = array_namespace | ||
|
||
def __getattr__(self, name): | ||
return getattr(self._array_namespace, name) | ||
|
||
def errstate(self, *args, **kwargs): | ||
# errstate not in `array_api` | ||
return nullcontext() | ||
|
||
def astype(self, x, dtype, *, copy=True, **kwargs): | ||
# ignore parameters that are not supported by array-api | ||
f = self._array_namespace.astype | ||
return f(x, dtype, copy=copy) | ||
|
||
def asarray(self, obj, dtype=None, device=None, copy=True, **kwargs): | ||
f = self._array_namespace.asarray | ||
return f(obj, dtype=dtype, device=device, copy=copy) | ||
|
||
def may_share_memory(self, *args, **kwargs): | ||
# The safe choice is to return True all the time | ||
return True | ||
|
||
def asanyarray(self, array, *args, **kwargs): | ||
# noop | ||
return array | ||
|
||
def concatenate(self, arrays, *, axis=0, **kwargs): | ||
# ignore parameters that are not supported by array-api | ||
f = self._array_namespace.concat | ||
return f(arrays, axis=axis) | ||
It's either this or what we see in https://github.com/scipy/scipy/pull/15395/files:
def _concatenate(arrays, axis):
    xp = _get_namespace(*arrays)
    if xp is np:
        return xp.concatenate(arrays, axis=axis)
    else:
        return xp.concat(arrays, axis=axis)
And importing |
||
|
||
def unique(self, x, return_inverse=False): | ||
if return_inverse: | ||
f = self._array_namespace.unique_inverse | ||
else: | ||
f = self._array_namespace.unique_values | ||
return f(x) | ||
|
||
def bincount(self, x): | ||
f = self._array_namespace.unique_counts | ||
return f(x)[1] | ||
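A note on the `bincount` wrapper above: counting unique values matches `bincount` only when the input labels are contiguous integers starting at 0, which holds for the inverse indices returned by `unique`. A sketch in plain NumPy, using `np.unique(..., return_counts=True)` as a stand-in for the array API's `unique_counts` (an assumption made for illustration).

```python
import numpy as np

# Inverse indices from `unique` are contiguous: 0, 1, ..., n_classes - 1.
y = np.array(["b", "a", "b", "c", "b"])
_, y_idx = np.unique(y, return_inverse=True)
counts_bincount = np.bincount(y_idx)
counts_unique = np.unique(y_idx, return_counts=True)[1]

# With a gap in the labels the two differ: bincount reports a zero
# for the missing value, while counting unique values skips it.
sparse = np.array([0, 2, 2])
sparse_bincount = np.bincount(sparse)
sparse_unique = np.unique(sparse, return_counts=True)[1]
```

So the wrapper is safe for the calls in this diff, where the argument always comes from `return_inverse=True`, but it is not a general replacement for `bincount`.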
|
||
|
||
class _NumPyApiWrapper: | ||
def __getattr__(self, name): | ||
return getattr(numpy, name) | ||
|
||
def astype(self, x, dtype, *args, **kwargs): | ||
# astype is not defined in the top level numpy namespace | ||
return x.astype(dtype, *args, **kwargs) | ||
I wrap |
||
|
||
|
||
def get_namespace(*xs): | ||
# `xs` contains one or more arrays, or possibly Python scalars (accepting | ||
# those is a matter of taste, but doesn't seem unreasonable). | ||
# Returns a tuple: (array_namespace, is_array_api) | ||
|
||
if not get_config()["array_api_dispatch"]: | ||
return _NumPyApiWrapper(), False | ||
|
||
namespaces = { | ||
x.__array_namespace__() if hasattr(x, "__array_namespace__") else None | ||
for x in xs | ||
if not isinstance(x, (bool, int, float, complex)) | ||
} | ||
|
||
if not namespaces: | ||
# one could special-case np.ndarray above or use np.asarray here if | ||
# older numpy versions need to be supported. | ||
raise ValueError("Unrecognized array input") | ||
|
||
if len(namespaces) != 1: | ||
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}") | ||
|
||
(xp,) = namespaces | ||
if xp is None: | ||
# Use numpy as default | ||
return _NumPyApiWrapper(), False | ||
|
||
return _ArrayAPIWrapper(xp), True |
`unique` is not in the array API (this would be `unique_inverse`).
Oh I see, you are using a wrapper below. IMO it would be better to use the array API APIs in the wrapper and wrap non-compatible NumPy conventions to match them, not the other way around.
Yea, going the other direction also works. I'll try it the other way around and see how it compares. My guess is that it's fine.
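For reference, a rough sketch of what "the other direction" could look like: a NumPy wrapper that exposes array API names (`concat`, `astype`, `unique_values`, `unique_inverse`) so calling code is written against the standard. The class name `NumPyAsArrayAPI` and the result tuple are hypothetical, and only a few functions are covered.

```python
from collections import namedtuple

import numpy

# The array API returns a named tuple from unique_inverse; this mimics it.
UniqueInverseResult = namedtuple(
    "UniqueInverseResult", ["values", "inverse_indices"]
)


class NumPyAsArrayAPI:
    """Hypothetical wrapper exposing array-API names on top of NumPy."""

    def __getattr__(self, name):
        # Anything not overridden below is looked up on numpy itself.
        return getattr(numpy, name)

    def concat(self, arrays, *, axis=0):
        return numpy.concatenate(arrays, axis=axis)

    def astype(self, x, dtype, *, copy=True):
        return x.astype(dtype, copy=copy)

    def unique_values(self, x):
        return numpy.unique(x)

    def unique_inverse(self, x):
        values, inverse = numpy.unique(x, return_inverse=True)
        return UniqueInverseResult(values, inverse)


xp = NumPyAsArrayAPI()
labels = numpy.asarray([3, 1, 3, 2])
classes, y_idx = xp.unique_inverse(labels)
stacked = xp.concat([labels, labels], axis=0)
as_float = xp.astype(labels, numpy.float64)
```

With this layout the estimator code would call `xp.unique_inverse(y)` and `xp.concat(...)` everywhere, and only the NumPy side needs shims.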