Skip to content

Commit

Permalink
[DOC] Update type hints for similarity search (#1939)
Browse files Browse the repository at this point in the history
* Updated type hints

* Added an additional type hint

* Ignore flake8 errors for long lines
  • Loading branch information
phershbe authored Aug 11, 2024
1 parent 9b68ab6 commit d51ff5c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
13 changes: 7 additions & 6 deletions aeon/similarity_search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
__maintainer__ = ["baraline"]

from abc import ABC, abstractmethod
from typing import Union, final
from typing import Optional, final

import numpy as np
from numba import get_num_threads, set_num_threads
from numba.typed import List

Expand Down Expand Up @@ -39,7 +40,7 @@ class BaseSimilaritySearch(BaseCollectionEstimator, ABC):
Attributes
----------
X_ : array, shape (n_cases, n_channels, n_timepoints)
X_ : np.ndarray, 3D array of shape (n_cases, n_channels, n_timepoints)
The input time series stored during the fit method.
Notes
Expand All @@ -59,7 +60,7 @@ class BaseSimilaritySearch(BaseCollectionEstimator, ABC):
def __init__(
self,
distance: str = "euclidean",
distance_args: Union[None, dict] = None,
distance_args: Optional[dict] = None,
inverse_distance: bool = False,
normalize: bool = False,
speed_up: str = "fastest",
Expand All @@ -74,14 +75,14 @@ def __init__(
super().__init__()

@final
def fit(self, X, y=None):
def fit(self, X: np.ndarray, y=None):
"""
Fit method: data preprocessing and storage.
Parameters
----------
X : array, shape (n_cases, n_channels, n_timepoints)
Input array to used as database for the similarity search
X : np.ndarray, 3D array of shape (n_cases, n_channels, n_timepoints)
Input array to be used as database for the similarity search
y : optional
Not used.
Expand Down
62 changes: 34 additions & 28 deletions aeon/similarity_search/query_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import warnings
from collections.abc import Iterable
from typing import Union, final
from typing import Optional, final

import numpy as np
from numba import get_num_threads, set_num_threads
Expand Down Expand Up @@ -74,7 +74,7 @@ class QuerySearch(BaseSimilaritySearch):
Attributes
----------
X_ : array, shape (n_cases, n_channels, n_timepoints)
X_ : np.ndarray, 3D array of shape (n_cases, n_channels, n_timepoints)
The input time series stored during the fit method. This is the
database we search in when given a query.
distance_profile_function : function
Expand All @@ -94,7 +94,7 @@ def __init__(
k: int = 1,
threshold: float = np.inf,
distance: str = "euclidean",
distance_args: Union[None, dict] = None,
distance_args: Optional[dict] = None,
inverse_distance: bool = False,
normalize: bool = False,
speed_up: str = "fastest",
Expand All @@ -116,13 +116,13 @@ def __init__(
n_jobs=n_jobs,
)

def _fit(self, X, y=None):
def _fit(self, X: np.ndarray, y=None):
"""
Check input format and store it to be used as search space during predict.
Parameters
----------
X : array, shape (n_cases, n_channels, n_timepoints)
X : np.ndarray, 3D array of shape (n_cases, n_channels, n_timepoints)
Input array to be used as database for the similarity search
y : optional
Not used.
Expand All @@ -145,23 +145,23 @@ def _fit(self, X, y=None):
@final
def predict(
self,
X,
X: np.ndarray,
axis=1,
X_index=None,
exclusion_factor=2.0,
apply_exclusion_to_result=False,
):
) -> np.ndarray:
"""
Predict method: Check the shape of X and call _predict to perform the search.
Predict method : Check the shape of X and call _predict to perform the search.
If the distance profile function is normalized, it stores the mean and stds
from X and X_, with X_ the training data.
Parameters
----------
X : array, shape (n_channels, query_length)
X : np.ndarray, 2D array of shape (n_channels, query_length)
Input query used for similarity search.
axis: int
axis : int
The time point axis of the input series if it is 2D. If ``axis==0``, it is
assumed each column is a time series and each row is a time point. i.e. the
shape of the data is ``(n_timepoints,n_channels)``. ``axis==1`` indicates
Expand All @@ -182,7 +182,7 @@ def predict(
the matching conditions defined by child classes. For example, with
TopKSimilaritySearch, the k best matches are also subject to the exclusion
zone, but with :math:`id_timestamp` the index of one of the k matches.
apply_exclusion_to_result: bool, default=False
apply_exclusion_to_result : bool, default=False
Whether to apply the exclusion factor to the output of the similarity search.
This means that two matches of the query from the same sample must be at
least spaced by +/- :math:`query_length//exclusion_factor`.
Expand All @@ -200,7 +200,7 @@ def predict(
Returns
-------
array, shape (n_matches, 2)
np.ndarray, 2D array of shape (n_matches, 2)
An array containing the indexes of the matches between X and X_.
The decision of whether a candidate of size query_length from X_ is matched
with X depends on the subclasses that implement the _predict method
Expand Down Expand Up @@ -240,7 +240,9 @@ def predict(
set_num_threads(prev_threads)
return X_preds

def _predict(self, distance_profiles, exclusion_size=None):
def _predict(
self, distance_profiles: np.ndarray, exclusion_size: Optional[int] = None
) -> np.ndarray:
"""
Private predict method for QuerySearch.
Expand All @@ -249,7 +251,7 @@ def _predict(self, distance_profiles, exclusion_size=None):
Parameters
----------
distance_profiles : array, shape (n_cases, n_timepoints - query_length + 1)
distance_profiles : np.ndarray, 2D array of shape (n_cases, n_timepoints - query_length + 1) # noqa: E501
Precomputed distance profile.
exclusion_size : int, optional
The size of the exclusion zone used to prevent returning as top k candidates
Expand All @@ -262,7 +264,7 @@ def _predict(self, distance_profiles, exclusion_size=None):
Returns
-------
array
np.ndarray
An array containing the indexes of the best k matches between q and _X.
"""
Expand Down Expand Up @@ -347,15 +349,19 @@ def _predict(self, distance_profiles, exclusion_size=None):
return top_k[:n_inserted]

def _init_X_index_mask(
self, X_index, query_dim, query_length, exclusion_factor=2.0
):
self,
X_index: Optional[Iterable[int]],
query_dim: int,
query_length: int,
exclusion_factor: Optional[float] = 2.0,
) -> np.ndarray:
"""
Initialize the mask indicating the candidates to be evaluated in the search.
Parameters
----------
X_index : Iterable
An Interable (tuple, list, array) of length two used to specify the index of
An Iterable (tuple, list, array) of length two used to specify the index of
the query X if it was extracted from the input data X given during the fit
method. Given the tuple (id_sample, id_timestamp), the similarity search
will define an exclusion zone around the X_index in order to avoid matching
Expand All @@ -381,7 +387,7 @@ def _init_X_index_mask(
Returns
-------
mask : array, shape=(n_cases, n_timepoints - query_length + 1)
mask : np.ndarray, 2D array of shape (n_cases, n_timepoints - query_length + 1)
Boolean array which indicates the candidates that should be evaluated in the
similarity search.
Expand Down Expand Up @@ -505,21 +511,21 @@ def _get_distance_profile_function(self):
else:
return naive_distance_profile

def _call_distance_profile(self, X, mask):
def _call_distance_profile(self, X: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""
Obtain the distance profile function and call it with the query and the mask.
Parameters
----------
X : array, shape (n_channels, query_length)
X : np.ndarray, 2D array of shape (n_channels, query_length)
Input query used for similarity search.
mask : array, shape=(n_cases, n_timepoints - query_length + 1)
Boolean array which indicates the candidates that should be evaluated in
the similarity search.
mask : np.ndarray, 2D array of shape (n_cases, n_timepoints - query_length + 1)
Boolean array which indicates the candidates that should be evaluated in
the similarity search.
Returns
-------
distance_profiles : array, shape=(n_cases, n_timepoints - query_length + 1)
distance_profiles : np.ndarray, 2D array of shape (n_cases, n_timepoints - query_length + 1) # noqa: E501
The distance profiles between the input time series and the query.
"""
Expand Down Expand Up @@ -566,7 +572,7 @@ def _call_distance_profile(self, X, mask):
distance_profiles = distance_profiles.sum(axis=1)
return distance_profiles

def _store_mean_std_from_inputs(self, query_length):
def _store_mean_std_from_inputs(self, query_length: int) -> None:
"""
Store the mean and std of each subsequence of size query_length in X_.
Expand All @@ -577,7 +583,7 @@ def _store_mean_std_from_inputs(self, query_length):
Returns
-------
None.
None
"""
means = []
Expand All @@ -593,7 +599,7 @@ def _store_mean_std_from_inputs(self, query_length):
self.X_stds_ = List(stds)

@classmethod
def get_speedup_function_names(self):
def get_speedup_function_names(self) -> dict:
"""
Get available speedup for similarity search in aeon.
Expand Down

0 comments on commit d51ff5c

Please sign in to comment.